@@ -151,6 +151,20 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
151
151
self ._closing = False
152
152
self .pipeline_factory = PipelineFactory (SecretsManager ())
153
153
self .context_tracking : Optional [PipelineContext ] = None
154
+ self .idle_timeout = 10
155
+ self .idle_timer = None
156
+
157
+ def _reset_idle_timer (self ) -> None :
158
+ if self .idle_timer :
159
+ self .idle_timer .cancel ()
160
+ self .idle_timer = asyncio .get_event_loop ().call_later (
161
+ self .idle_timeout , self ._handle_idle_timeout
162
+ )
163
+
164
+ def _handle_idle_timeout (self ) -> None :
165
+ logger .warning ("Idle timeout reached, closing connection" )
166
+ if self .transport and not self .transport .is_closing ():
167
+ self .transport .close ()
154
168
155
169
def _select_pipeline (self , method : str , path : str ) -> Optional [CopilotPipeline ]:
156
170
if method == "POST" and path == "v1/engines/copilot-codex/completions" :
@@ -215,6 +229,7 @@ def connection_made(self, transport: asyncio.Transport) -> None:
215
229
self .transport = transport
216
230
self .peername = transport .get_extra_info ("peername" )
217
231
logger .debug (f"Client connected from { self .peername } " )
232
+ self ._reset_idle_timer ()
218
233
219
234
def get_headers_dict (self ) -> Dict [str , str ]:
220
235
"""Convert raw headers to dictionary format"""
@@ -350,8 +365,10 @@ async def _forward_data_to_target(self, data: bytes) -> None:
350
365
pipeline_output = pipeline_output .reconstruct ()
351
366
self .target_transport .write (pipeline_output )
352
367
368
+
353
369
def data_received (self , data : bytes ) -> None :
354
370
"""Handle received data from client"""
371
+ self ._reset_idle_timer ()
355
372
try :
356
373
if not self ._check_buffer_size (data ):
357
374
self .send_error_response (413 , b"Request body too large" )
@@ -556,6 +573,7 @@ async def connect_to_target(self) -> None:
556
573
logger .error (f"Error during TLS handshake: { e } " )
557
574
self .send_error_response (502 , b"TLS handshake failed" )
558
575
576
+
559
577
def send_error_response (self , status : int , message : bytes ) -> None :
560
578
"""Send error response to client"""
561
579
if self ._closing :
@@ -593,6 +611,37 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
593
611
self .buffer .clear ()
594
612
self .ssl_context = None
595
613
614
+ if self .idle_timer :
615
+ self .idle_timer .cancel ()
616
+
617
+ def eof_received (self ) -> None :
618
+ print ("in eof received" )
619
+ """Handle connection loss"""
620
+ if self ._closing :
621
+ return
622
+
623
+ self ._closing = True
624
+ logger .debug (f"EOF received from { self .peername } " )
625
+
626
+ # Close target transport if it exists and isn't already closing
627
+ if self .target_transport and not self .target_transport .is_closing ():
628
+ try :
629
+ self .target_transport .close ()
630
+ except Exception as e :
631
+ logger .error (f"Error closing target transport when EOF: { e } " )
632
+
633
+ # Clear references to help with cleanup
634
+ self .transport = None
635
+ self .target_transport = None
636
+ self .buffer .clear ()
637
+ self .ssl_context = None
638
+
639
+ def pause_writing (self ) -> None :
640
+ print ("Transport buffer full, pausing writing" )
641
+
642
+ def resume_writing (self ) -> None :
643
+ print ("Transport buffer ready, resuming writing" )
644
+
596
645
@classmethod
597
646
async def create_proxy_server (
598
647
cls , host : str , port : int , ssl_context : Optional [ssl .SSLContext ] = None
0 commit comments