Skip to content

Commit 1d0c5f5

Browse files
authored
Merge pull request #516 from stacklok/debug_idle_connections
fix: add monitoring for idle connections and close them
2 parents c23a481 + e3829a6 commit 1d0c5f5

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

src/codegate/providers/copilot/provider.py

+49
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,20 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
151151
self._closing = False
152152
self.pipeline_factory = PipelineFactory(SecretsManager())
153153
self.context_tracking: Optional[PipelineContext] = None
154+
self.idle_timeout = 10
155+
self.idle_timer = None
156+
157+
def _reset_idle_timer(self) -> None:
158+
if self.idle_timer:
159+
self.idle_timer.cancel()
160+
self.idle_timer = asyncio.get_event_loop().call_later(
161+
self.idle_timeout, self._handle_idle_timeout
162+
)
163+
164+
def _handle_idle_timeout(self) -> None:
165+
logger.warning("Idle timeout reached, closing connection")
166+
if self.transport and not self.transport.is_closing():
167+
self.transport.close()
154168

155169
def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
156170
if method == "POST" and path == "v1/engines/copilot-codex/completions":
@@ -215,6 +229,7 @@ def connection_made(self, transport: asyncio.Transport) -> None:
215229
self.transport = transport
216230
self.peername = transport.get_extra_info("peername")
217231
logger.debug(f"Client connected from {self.peername}")
232+
self._reset_idle_timer()
218233

219234
def get_headers_dict(self) -> Dict[str, str]:
220235
"""Convert raw headers to dictionary format"""
@@ -350,8 +365,10 @@ async def _forward_data_to_target(self, data: bytes) -> None:
350365
pipeline_output = pipeline_output.reconstruct()
351366
self.target_transport.write(pipeline_output)
352367

368+
353369
def data_received(self, data: bytes) -> None:
354370
"""Handle received data from client"""
371+
self._reset_idle_timer()
355372
try:
356373
if not self._check_buffer_size(data):
357374
self.send_error_response(413, b"Request body too large")
@@ -556,6 +573,7 @@ async def connect_to_target(self) -> None:
556573
logger.error(f"Error during TLS handshake: {e}")
557574
self.send_error_response(502, b"TLS handshake failed")
558575

576+
559577
def send_error_response(self, status: int, message: bytes) -> None:
560578
"""Send error response to client"""
561579
if self._closing:
@@ -593,6 +611,37 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
593611
self.buffer.clear()
594612
self.ssl_context = None
595613

614+
if self.idle_timer:
615+
self.idle_timer.cancel()
616+
617+
def eof_received(self) -> None:
618+
print("in eof received")
619+
"""Handle connection loss"""
620+
if self._closing:
621+
return
622+
623+
self._closing = True
624+
logger.debug(f"EOF received from {self.peername}")
625+
626+
# Close target transport if it exists and isn't already closing
627+
if self.target_transport and not self.target_transport.is_closing():
628+
try:
629+
self.target_transport.close()
630+
except Exception as e:
631+
logger.error(f"Error closing target transport when EOF: {e}")
632+
633+
# Clear references to help with cleanup
634+
self.transport = None
635+
self.target_transport = None
636+
self.buffer.clear()
637+
self.ssl_context = None
638+
639+
def pause_writing(self) -> None:
640+
print("Transport buffer full, pausing writing")
641+
642+
def resume_writing(self) -> None:
643+
print("Transport buffer ready, resuming writing")
644+
596645
@classmethod
597646
async def create_proxy_server(
598647
cls, host: str, port: int, ssl_context: Optional[ssl.SSLContext] = None

0 commit comments

Comments
 (0)