@@ -705,10 +705,15 @@ def __init__(self, proxy: CopilotProvider):
705
705
self .stream_queue : Optional [asyncio .Queue ] = None
706
706
self .processing_task : Optional [asyncio .Task ] = None
707
707
708
+ self .finish_stream = False
709
+
710
+ # For debugging only
711
+ # self.data_sent = []
712
+
708
713
def connection_made (self , transport : asyncio .Transport ) -> None :
709
714
"""Handle successful connection to target"""
710
715
self .transport = transport
711
- logger .debug (f"Target transport peer : { transport .get_extra_info ('peername' )} " )
716
+ logger .debug (f"Connection established to target : { transport .get_extra_info ('peername' )} " )
712
717
self .proxy .target_transport = transport
713
718
714
719
def _ensure_output_processor (self ) -> None :
@@ -737,7 +742,7 @@ async def _process_stream(self):
737
742
try :
738
743
739
744
async def stream_iterator ():
740
- while True :
745
+ while not self . stream_queue . empty () :
741
746
incoming_record = await self .stream_queue .get ()
742
747
743
748
record_content = incoming_record .get ("content" , {})
@@ -750,6 +755,9 @@ async def stream_iterator():
750
755
else :
751
756
content = choice .get ("delta" , {}).get ("content" )
752
757
758
+ if choice .get ("finish_reason" , None ) == "stop" :
759
+ self .finish_stream = True
760
+
753
761
streaming_choices .append (
754
762
StreamingChoices (
755
763
finish_reason = choice .get ("finish_reason" , None ),
@@ -771,22 +779,18 @@ async def stream_iterator():
771
779
)
772
780
yield mr
773
781
774
- async for record in self .output_pipeline_instance .process_stream (stream_iterator ()):
782
+ async for record in self .output_pipeline_instance .process_stream (
783
+ stream_iterator (), cleanup_sensitive = False
784
+ ):
775
785
chunk = record .model_dump_json (exclude_none = True , exclude_unset = True )
776
786
sse_data = f"data: { chunk } \n \n " .encode ("utf-8" )
777
787
chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
778
788
self ._proxy_transport_write (chunk_size .encode ())
779
789
self ._proxy_transport_write (sse_data )
780
790
self ._proxy_transport_write (b"\r \n " )
781
791
782
- sse_data = b"data: [DONE]\n \n "
783
- # Add chunk size for DONE message too
784
- chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
785
- self ._proxy_transport_write (chunk_size .encode ())
786
- self ._proxy_transport_write (sse_data )
787
- self ._proxy_transport_write (b"\r \n " )
788
- # Now send the final zero chunk
789
- self ._proxy_transport_write (b"0\r \n \r \n " )
792
+ if self .finish_stream :
793
+ self .finish_data ()
790
794
791
795
except asyncio .CancelledError :
792
796
logger .debug ("Stream processing cancelled" )
@@ -795,12 +799,37 @@ async def stream_iterator():
795
799
logger .error (f"Error processing stream: { e } " )
796
800
finally :
797
801
# Clean up
802
+ self .stream_queue = None
798
803
if self .processing_task and not self .processing_task .done ():
799
804
self .processing_task .cancel ()
800
- if self .proxy .context_tracking and self .proxy .context_tracking .sensitive :
801
- self .proxy .context_tracking .sensitive .secure_cleanup ()
805
+
806
+ def finish_data (self ):
807
+ logger .debug ("Finishing data stream" )
808
+ sse_data = b"data: [DONE]\n \n "
809
+ # Add chunk size for DONE message too
810
+ chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
811
+ self ._proxy_transport_write (chunk_size .encode ())
812
+ self ._proxy_transport_write (sse_data )
813
+ self ._proxy_transport_write (b"\r \n " )
814
+ # Now send the final zero chunk
815
+ self ._proxy_transport_write (b"0\r \n \r \n " )
816
+
817
+ # For debugging only
818
+ # print("===========START DATA SENT====================")
819
+ # for data in self.data_sent:
820
+ # print(data)
821
+ # self.data_sent = []
822
+ # print("===========START DATA SENT====================")
823
+
824
+ self .finish_stream = False
825
+ self .headers_sent = False
802
826
803
827
def _process_chunk (self , chunk : bytes ):
828
+ # For debugging only
829
+ # print("===========START DATA RECVD====================")
830
+ # print(chunk)
831
+ # print("===========END DATA RECVD======================")
832
+
804
833
records = self .sse_processor .process_chunk (chunk )
805
834
806
835
for record in records :
@@ -812,13 +841,12 @@ def _process_chunk(self, chunk: bytes):
812
841
self .stream_queue .put_nowait (record )
813
842
814
843
def _proxy_transport_write (self , data : bytes ):
844
+ # For debugging only
845
+ # self.data_sent.append(data)
815
846
if not self .proxy .transport or self .proxy .transport .is_closing ():
816
847
logger .error ("Proxy transport not available" )
817
848
return
818
849
self .proxy .transport .write (data )
819
- # print("DEBUG =================================")
820
- # print(data)
821
- # print("DEBUG =================================")
822
850
823
851
def data_received (self , data : bytes ) -> None :
824
852
"""Handle data received from target"""
@@ -848,15 +876,13 @@ def data_received(self, data: bytes) -> None:
848
876
logger .debug (f"Headers sent: { headers } " )
849
877
850
878
data = data [header_end + 4 :]
851
- # print("DEBUG =================================")
852
- # print(data)
853
- # print("DEBUG =================================")
854
879
855
880
self ._process_chunk (data )
856
881
857
882
def connection_lost (self , exc : Optional [Exception ]) -> None :
858
883
"""Handle connection loss to target"""
859
884
885
+ logger .debug ("Lost connection to target" )
860
886
if (
861
887
not self .proxy ._closing
862
888
and self .proxy .transport
0 commit comments