Skip to content

Commit 5eeb0cd

Browse files
authored
Merge pull request #520 from stacklok/fix-copilot-hang
Fix issue causing copilot to hang after creating multiple sessions
2 parents 1d0c5f5 + 7008330 commit 5eeb0cd

File tree

2 files changed

+47
-21
lines changed

2 files changed

+47
-21
lines changed

src/codegate/pipeline/output.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ async def _record_to_db(self):
115115
await self._db_recorder.record_context(self._input_context)
116116

117117
async def process_stream(
118-
self, stream: AsyncIterator[ModelResponse]
118+
self, stream: AsyncIterator[ModelResponse], cleanup_sensitive: bool = True
119119
) -> AsyncIterator[ModelResponse]:
120120
"""
121121
Process a stream through all pipeline steps
@@ -182,7 +182,7 @@ async def process_stream(
182182
self._context.buffer.clear()
183183

184184
# Cleanup sensitive data through the input context
185-
if self._input_context and self._input_context.sensitive:
185+
if cleanup_sensitive and self._input_context and self._input_context.sensitive:
186186
self._input_context.sensitive.secure_cleanup()
187187

188188

src/codegate/providers/copilot/provider.py

+45-19
Original file line numberDiff line numberDiff line change
@@ -705,10 +705,15 @@ def __init__(self, proxy: CopilotProvider):
705705
self.stream_queue: Optional[asyncio.Queue] = None
706706
self.processing_task: Optional[asyncio.Task] = None
707707

708+
self.finish_stream = False
709+
710+
# For debugging only
711+
# self.data_sent = []
712+
708713
def connection_made(self, transport: asyncio.Transport) -> None:
709714
"""Handle successful connection to target"""
710715
self.transport = transport
711-
logger.debug(f"Target transport peer: {transport.get_extra_info('peername')}")
716+
logger.debug(f"Connection established to target: {transport.get_extra_info('peername')}")
712717
self.proxy.target_transport = transport
713718

714719
def _ensure_output_processor(self) -> None:
@@ -737,7 +742,7 @@ async def _process_stream(self):
737742
try:
738743

739744
async def stream_iterator():
740-
while True:
745+
while not self.stream_queue.empty():
741746
incoming_record = await self.stream_queue.get()
742747

743748
record_content = incoming_record.get("content", {})
@@ -750,6 +755,9 @@ async def stream_iterator():
750755
else:
751756
content = choice.get("delta", {}).get("content")
752757

758+
if choice.get("finish_reason", None) == "stop":
759+
self.finish_stream = True
760+
753761
streaming_choices.append(
754762
StreamingChoices(
755763
finish_reason=choice.get("finish_reason", None),
@@ -771,22 +779,18 @@ async def stream_iterator():
771779
)
772780
yield mr
773781

774-
async for record in self.output_pipeline_instance.process_stream(stream_iterator()):
782+
async for record in self.output_pipeline_instance.process_stream(
783+
stream_iterator(), cleanup_sensitive=False
784+
):
775785
chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
776786
sse_data = f"data: {chunk}\n\n".encode("utf-8")
777787
chunk_size = hex(len(sse_data))[2:] + "\r\n"
778788
self._proxy_transport_write(chunk_size.encode())
779789
self._proxy_transport_write(sse_data)
780790
self._proxy_transport_write(b"\r\n")
781791

782-
sse_data = b"data: [DONE]\n\n"
783-
# Add chunk size for DONE message too
784-
chunk_size = hex(len(sse_data))[2:] + "\r\n"
785-
self._proxy_transport_write(chunk_size.encode())
786-
self._proxy_transport_write(sse_data)
787-
self._proxy_transport_write(b"\r\n")
788-
# Now send the final zero chunk
789-
self._proxy_transport_write(b"0\r\n\r\n")
792+
if self.finish_stream:
793+
self.finish_data()
790794

791795
except asyncio.CancelledError:
792796
logger.debug("Stream processing cancelled")
@@ -795,12 +799,37 @@ async def stream_iterator():
795799
logger.error(f"Error processing stream: {e}")
796800
finally:
797801
# Clean up
802+
self.stream_queue = None
798803
if self.processing_task and not self.processing_task.done():
799804
self.processing_task.cancel()
800-
if self.proxy.context_tracking and self.proxy.context_tracking.sensitive:
801-
self.proxy.context_tracking.sensitive.secure_cleanup()
805+
806+
def finish_data(self):
807+
logger.debug("Finishing data stream")
808+
sse_data = b"data: [DONE]\n\n"
809+
# Add chunk size for DONE message too
810+
chunk_size = hex(len(sse_data))[2:] + "\r\n"
811+
self._proxy_transport_write(chunk_size.encode())
812+
self._proxy_transport_write(sse_data)
813+
self._proxy_transport_write(b"\r\n")
814+
# Now send the final zero chunk
815+
self._proxy_transport_write(b"0\r\n\r\n")
816+
817+
# For debugging only
818+
# print("===========START DATA SENT====================")
819+
# for data in self.data_sent:
820+
# print(data)
821+
# self.data_sent = []
822+
# print("===========START DATA SENT====================")
823+
824+
self.finish_stream = False
825+
self.headers_sent = False
802826

803827
def _process_chunk(self, chunk: bytes):
828+
# For debugging only
829+
# print("===========START DATA RECVD====================")
830+
# print(chunk)
831+
# print("===========END DATA RECVD======================")
832+
804833
records = self.sse_processor.process_chunk(chunk)
805834

806835
for record in records:
@@ -812,13 +841,12 @@ def _process_chunk(self, chunk: bytes):
812841
self.stream_queue.put_nowait(record)
813842

814843
def _proxy_transport_write(self, data: bytes):
844+
# For debugging only
845+
# self.data_sent.append(data)
815846
if not self.proxy.transport or self.proxy.transport.is_closing():
816847
logger.error("Proxy transport not available")
817848
return
818849
self.proxy.transport.write(data)
819-
# print("DEBUG =================================")
820-
# print(data)
821-
# print("DEBUG =================================")
822850

823851
def data_received(self, data: bytes) -> None:
824852
"""Handle data received from target"""
@@ -848,15 +876,13 @@ def data_received(self, data: bytes) -> None:
848876
logger.debug(f"Headers sent: {headers}")
849877

850878
data = data[header_end + 4 :]
851-
# print("DEBUG =================================")
852-
# print(data)
853-
# print("DEBUG =================================")
854879

855880
self._process_chunk(data)
856881

857882
def connection_lost(self, exc: Optional[Exception]) -> None:
858883
"""Handle connection loss to target"""
859884

885+
logger.debug("Lost connection to target")
860886
if (
861887
not self.proxy._closing
862888
and self.proxy.transport

0 commit comments

Comments
 (0)