diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index 59c73981..eb92d8a0 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -277,6 +277,13 @@ def __init__( self.secret_manager = secret_manager self.is_fim = is_fim self.context = PipelineContext() + + # we create the sensitive context here so that it is not shared between individual requests + # TODO: could we get away with just generating the session ID for an instance? + self.context.sensitive = PipelineSensitiveData( + manager=self.secret_manager, + session_id=str(uuid.uuid4()), + ) self.context.metadata["is_fim"] = is_fim async def process_request( @@ -290,17 +297,14 @@ async def process_request( is_copilot: bool = False, ) -> PipelineResult: """Process a request through all pipeline steps""" - self.context.sensitive = PipelineSensitiveData( - manager=self.secret_manager, - session_id=str(uuid.uuid4()), - api_key=api_key, - model=model, - provider=provider, - api_base=api_base, - ) self.context.metadata["extra_headers"] = extra_headers current_request = request + self.context.sensitive.api_key = api_key + self.context.sensitive.model = model + self.context.sensitive.provider = provider + self.context.sensitive.api_base = api_base + # For Copilot provider=openai. Use a flag to not clash with other places that may use that. 
provider_db = "copilot" if is_copilot else provider @@ -336,8 +340,9 @@ def __init__( self.pipeline_steps = pipeline_steps self.secret_manager = secret_manager self.is_fim = is_fim + self.instance = self._create_instance() - def create_instance(self) -> InputPipelineInstance: + def _create_instance(self) -> InputPipelineInstance: """Create a new pipeline instance for processing a request""" return InputPipelineInstance(self.pipeline_steps, self.secret_manager, self.is_fim) @@ -352,7 +357,6 @@ async def process_request( is_copilot: bool = False, ) -> PipelineResult: """Create a new pipeline instance and process the request""" - instance = self.create_instance() - return await instance.process_request( + return await self.instance.process_request( request, provider, model, api_key, api_base, extra_headers, is_copilot ) diff --git a/src/codegate/pipeline/output.py b/src/codegate/pipeline/output.py index 89c31c59..f5bb716a 100644 --- a/src/codegate/pipeline/output.py +++ b/src/codegate/pipeline/output.py @@ -162,6 +162,10 @@ async def process_stream( logger.error(f"Error processing stream: {e}") raise e finally: + # Don't flush the buffer if we assume we'll call the pipeline again + if cleanup_sensitive is False: + return + # Process any remaining content in buffer when stream ends if self._context.buffer: final_content = "".join(self._context.buffer) diff --git a/src/codegate/pipeline/secrets/manager.py b/src/codegate/pipeline/secrets/manager.py index a7b32319..bef07c75 100644 --- a/src/codegate/pipeline/secrets/manager.py +++ b/src/codegate/pipeline/secrets/manager.py @@ -21,7 +21,7 @@ class SecretsManager: def __init__(self): self.crypto = CodeGateCrypto() - self._session_store: dict[str, SecretEntry] = {} + self._session_store: dict[str, dict[str, SecretEntry]] = {} self._encrypted_to_session: dict[str, str] = {} # Reverse lookup index def store_secret(self, value: str, service: str, secret_type: str, session_id: str) -> str: @@ -41,12 +41,14 @@ def 
store_secret(self, value: str, service: str, secret_type: str, session_id: s encrypted_value = self.crypto.encrypt_token(value, session_id) # Store mappings - self._session_store[session_id] = SecretEntry( + session_secrets = self._session_store.get(session_id, {}) + session_secrets[encrypted_value] = SecretEntry( original=value, encrypted=encrypted_value, service=service, secret_type=secret_type, ) + self._session_store[session_id] = session_secrets self._encrypted_to_session[encrypted_value] = session_id logger.debug("Stored secret", service=service, type=secret_type, encrypted=encrypted_value) @@ -58,7 +60,9 @@ def get_original_value(self, encrypted_value: str, session_id: str) -> Optional[ try: stored_session_id = self._encrypted_to_session.get(encrypted_value) if stored_session_id == session_id: - return self._session_store[session_id].original + session_secrets = self._session_store[session_id].get(encrypted_value) + if session_secrets: + return session_secrets.original except Exception as e: logger.error("Error retrieving secret", error=str(e)) return None @@ -71,9 +75,10 @@ def cleanup(self): """Securely wipe sensitive data""" try: # Convert and wipe original values - for entry in self._session_store.values(): - original_bytes = bytearray(entry.original.encode()) - self.crypto.wipe_bytearray(original_bytes) + for secrets in self._session_store.values(): + for entry in secrets.values(): + original_bytes = bytearray(entry.original.encode()) + self.crypto.wipe_bytearray(original_bytes) # Clear the dictionaries self._session_store.clear() @@ -92,9 +97,9 @@ def cleanup_session(self, session_id: str): """ try: # Get the secret entry for the session - entry = self._session_store.get(session_id) + secrets = self._session_store.get(session_id, {}) - if entry: + for entry in secrets.values(): # Securely wipe the original value original_bytes = bytearray(entry.original.encode()) self.crypto.wipe_bytearray(original_bytes) diff --git 
a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 606db5bb..43ec17a8 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -362,6 +362,7 @@ async def process_chunk( if match: # Found a complete marker, process it encrypted_value = match.group(1) + print("----> encrypted_value: ", encrypted_value) original_value = input_context.sensitive.manager.get_original_value( encrypted_value, input_context.sensitive.session_id, @@ -370,6 +371,8 @@ async def process_chunk( if original_value is None: # If value not found, leave as is original_value = match.group(0) # Keep the REDACTED marker + else: + print("----> original_value: ", original_value) # Post an alert with the redacted content input_context.add_alert(self.name, trigger_string=encrypted_value) diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 78c5df72..2daa5a8d 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -1,11 +1,9 @@ import json -from typing import Optional import structlog from fastapi import Header, HTTPException, Request -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer from codegate.providers.anthropic.completion_handler import AnthropicCompletion from codegate.providers.base import BaseProvider @@ -15,20 +13,14 @@ class AnthropicProvider(BaseProvider): def __init__( self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + 
pipeline_factory: PipelineFactory, ): completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator) super().__init__( AnthropicInputNormalizer(), AnthropicOutputNormalizer(), completion_handler, - pipeline_processor, - fim_pipeline_processor, - output_pipeline_processor, - fim_output_pipeline_processor, + pipeline_factory, ) @property diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index b529ad3b..dc45616e 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -10,9 +10,9 @@ from codegate.pipeline.base import ( PipelineContext, PipelineResult, - SequentialPipelineProcessor, ) -from codegate.pipeline.output import OutputPipelineInstance, OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory +from codegate.pipeline.output import OutputPipelineInstance from codegate.providers.completion.base import BaseCompletionHandler from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer @@ -34,19 +34,13 @@ def __init__( input_normalizer: ModelInputNormalizer, output_normalizer: ModelOutputNormalizer, completion_handler: BaseCompletionHandler, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): self.router = APIRouter() self._completion_handler = completion_handler self._input_normalizer = input_normalizer self._output_normalizer = output_normalizer - self._pipeline_processor = pipeline_processor - self._fim_pipelin_processor = fim_pipeline_processor - self._output_pipeline_processor = output_pipeline_processor - self._fim_output_pipeline_processor = fim_output_pipeline_processor + 
self._pipeline_factory = pipeline_factory self._db_recorder = DbRecorder() self._pipeline_response_formatter = PipelineResponseFormatter( output_normalizer, self._db_recorder @@ -73,10 +67,10 @@ async def _run_output_stream_pipeline( # Decide which pipeline processor to use out_pipeline_processor = None if is_fim_request: - out_pipeline_processor = self._fim_output_pipeline_processor + out_pipeline_processor = self._pipeline_factory.create_fim_output_pipeline() logger.info("FIM pipeline selected for output.") else: - out_pipeline_processor = self._output_pipeline_processor + out_pipeline_processor = self._pipeline_factory.create_output_pipeline() logger.info("Chat completion pipeline selected for output.") if out_pipeline_processor is None: logger.info("No output pipeline processor found, passing through") @@ -117,11 +111,11 @@ async def _run_input_pipeline( ) -> PipelineResult: # Decide which pipeline processor to use if is_fim_request: - pipeline_processor = self._fim_pipelin_processor + pipeline_processor = self._pipeline_factory.create_fim_pipeline() logger.info("FIM pipeline selected for execution.") normalized_request = self._fim_normalizer.normalize(normalized_request) else: - pipeline_processor = self._pipeline_processor + pipeline_processor = self._pipeline_factory.create_input_pipeline() logger.info("Chat completion pipeline selected for execution.") if pipeline_processor is None: return PipelineResult(request=normalized_request) diff --git a/src/codegate/providers/copilot/pipeline.py b/src/codegate/providers/copilot/pipeline.py index 5268aeaa..d1ef13da 100644 --- a/src/codegate/providers/copilot/pipeline.py +++ b/src/codegate/providers/copilot/pipeline.py @@ -24,6 +24,7 @@ class CopilotPipeline(ABC): def __init__(self, pipeline_factory: PipelineFactory): self.pipeline_factory = pipeline_factory + self.instance = self._create_pipeline() self.normalizer = self._create_normalizer() self.provider_name = "openai" @@ -33,7 +34,7 @@ def 
_create_normalizer(self): pass @abstractmethod - def create_pipeline(self) -> SequentialPipelineProcessor: + def _create_pipeline(self) -> SequentialPipelineProcessor: """Each strategy defines which pipeline to create""" pass @@ -84,7 +85,11 @@ def _create_shortcut_response(result: PipelineResult, model: str) -> bytes: body = response.model_dump_json(exclude_none=True, exclude_unset=True).encode() return body - async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]: + async def process_body( + self, + headers: list[str], + body: bytes, + ) -> Tuple[bytes, PipelineContext | None]: """Common processing logic for all strategies""" try: normalized_body = self.normalizer.normalize(body) @@ -97,8 +102,7 @@ async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, Pi except ValueError: continue - pipeline = self.create_pipeline() - result = await pipeline.process_request( + result = await self.instance.process_request( request=normalized_body, provider=self.provider_name, model=normalized_body.get("model", "gpt-4o-mini"), @@ -168,10 +172,13 @@ class CopilotFimPipeline(CopilotPipeline): format and the FIM pipeline used by all providers. """ + def __init__(self, pipeline_factory: PipelineFactory): + super().__init__(pipeline_factory) + def _create_normalizer(self): return CopilotFimNormalizer() - def create_pipeline(self) -> SequentialPipelineProcessor: + def _create_pipeline(self) -> SequentialPipelineProcessor: return self.pipeline_factory.create_fim_pipeline() @@ -181,8 +188,11 @@ class CopilotChatPipeline(CopilotPipeline): format and the FIM pipeline used by all providers. 
""" + def __init__(self, pipeline_factory: PipelineFactory): + super().__init__(pipeline_factory) + def _create_normalizer(self): return CopilotChatNormalizer() - def create_pipeline(self) -> SequentialPipelineProcessor: + def _create_pipeline(self) -> SequentialPipelineProcessor: return self.pipeline_factory.create_input_pipeline() diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index f8bfeb1e..6a3adf67 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -150,8 +150,16 @@ def __init__(self, loop: asyncio.AbstractEventLoop): self.cert_manager = TLSCertDomainManager(self.ca) self._closing = False self.pipeline_factory = PipelineFactory(SecretsManager()) + self.input_pipeline: Optional[CopilotPipeline] = None + self.fim_pipeline: Optional[CopilotPipeline] = None + # the context as provided by the pipeline self.context_tracking: Optional[PipelineContext] = None + def _ensure_pipelines(self): + if not self.input_pipeline or not self.fim_pipeline: + self.input_pipeline = CopilotChatPipeline(self.pipeline_factory) + self.fim_pipeline = CopilotFimPipeline(self.pipeline_factory) + def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]: if method != "POST": logger.debug("Not a POST request, no pipeline selected") @@ -161,10 +169,10 @@ def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]: if path == route.path: if route.pipeline_type == PipelineType.FIM: logger.debug("Selected FIM pipeline") - return CopilotFimPipeline(self.pipeline_factory) + return self.fim_pipeline elif route.pipeline_type == PipelineType.CHAT: logger.debug("Selected CHAT pipeline") - return CopilotChatPipeline(self.pipeline_factory) + return self.input_pipeline logger.debug("No pipeline selected") return None @@ -181,7 +189,6 @@ async def _body_through_pipeline( # if we didn't select any strategy that would change the request # let's just pass 
through the body as-is return body, None - logger.debug(f"Processing body through pipeline: {len(body)} bytes") return await strategy.process_body(headers, body) async def _request_to_target(self, headers: list[str], body: bytes): @@ -222,15 +229,15 @@ def connection_made(self, transport: asyncio.Transport) -> None: self.peername = transport.get_extra_info("peername") logger.debug(f"Client connected from {self.peername}") - def get_headers_dict(self) -> Dict[str, str]: + def get_headers_dict(self, complete_request) -> Dict[str, str]: """Convert raw headers to dictionary format""" headers_dict = {} try: - if b"\r\n\r\n" not in self.buffer: + if b"\r\n\r\n" not in complete_request: return {} - headers_end = self.buffer.index(b"\r\n\r\n") - headers = self.buffer[:headers_end].split(b"\r\n")[1:] + headers_end = complete_request.index(b"\r\n\r\n") + headers = complete_request[:headers_end].split(b"\r\n")[1:] for header in headers: try: @@ -288,6 +295,9 @@ async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest http_request.headers, http_request.body, ) + # TODO: it's weird that we're overwriting the context. + # Should we set the context once? Maybe when + # creating the pipeline instance? 
self.context_tracking = context if context and context.shortcut_response: @@ -439,31 +449,50 @@ def data_received(self, data: bytes) -> None: self.buffer.extend(data) - if not self.headers_parsed: - self.headers_parsed = self.parse_headers() - if self.headers_parsed: - if self.request.method == "CONNECT": - self.handle_connect() - self.buffer.clear() - else: - # Only process the request once we have the complete body - asyncio.create_task(self.handle_http_request()) - else: - if self._has_complete_body(): - # Process the complete request through the pipeline - complete_request = bytes(self.buffer) - # logger.debug(f"Complete request: {complete_request}") - self.buffer.clear() - asyncio.create_task(self._forward_data_to_target(complete_request)) + while self.buffer: # Process as many complete requests as we have + if not self.headers_parsed: + self.headers_parsed = self.parse_headers() + if self.headers_parsed: + self._ensure_pipelines() + if self.request.method == "CONNECT": + if self._has_complete_body(): + self.handle_connect() + self.buffer.clear() # CONNECT requests are handled differently + break # CONNECT handling complete + elif self._has_complete_body(): + # Find where this request ends + headers_end = self.buffer.index(b"\r\n\r\n") + headers = self.buffer[:headers_end].split(b"\r\n")[1:] + content_length = 0 + for header in headers: + if header.lower().startswith(b"content-length:"): + content_length = int(header.split(b":", 1)[1]) + break + + request_end = headers_end + 4 + content_length + complete_request = self.buffer[:request_end] + + self.buffer = self.buffer[request_end:] # Keep remaining data + + self.headers_parsed = False # Reset for next request + + asyncio.create_task(self.handle_http_request(complete_request)) + break # Either processing request or need more data + else: + if self._has_complete_body(): + complete_request = bytes(self.buffer) + self.buffer.clear() # Clear buffer for next request + 
asyncio.create_task(self._forward_data_to_target(complete_request)) + break # Either processing request or need more data except Exception as e: logger.error(f"Error processing received data: {e}") self.send_error_response(502, str(e).encode()) - async def handle_http_request(self) -> None: + async def handle_http_request(self, complete_request: bytes) -> None: """Handle standard HTTP request""" try: - target_url = await self._get_target_url() + target_url = await self._get_target_url(complete_request) except Exception as e: logger.error(f"Error getting target URL: {e}") self.send_error_response(404, b"Not Found") @@ -507,9 +536,9 @@ async def handle_http_request(self) -> None: new_headers.append(f"Host: {self.target_host}") if self.target_transport: - if self.buffer: - body_start = self.buffer.index(b"\r\n\r\n") + 4 - body = self.buffer[body_start:] + if complete_request: + body_start = complete_request.index(b"\r\n\r\n") + 4 + body = complete_request[body_start:] await self._request_to_target(new_headers, body) else: # just skip it @@ -521,9 +550,9 @@ async def handle_http_request(self) -> None: logger.error(f"Error preparing or sending request to target: {e}") self.send_error_response(502, b"Bad Gateway") - async def _get_target_url(self) -> Optional[str]: + async def _get_target_url(self, complete_request) -> Optional[str]: """Determine target URL based on request path and headers""" - headers_dict = self.get_headers_dict() + headers_dict = self.get_headers_dict(complete_request) auth_header = headers_dict.get("authorization", "") if auth_header: @@ -756,10 +785,12 @@ def connection_made(self, transport: asyncio.Transport) -> None: def _ensure_output_processor(self) -> None: if self.proxy.context_tracking is None: + logger.debug("No context tracking, no need to process pipeline") # No context tracking, no need to process pipeline return if self.sse_processor is not None: + logger.debug("Already initialized, no need to reinitialize") # Already initialized, no 
need to reinitialize return diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index 37dc64d1..7f90619e 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -1,11 +1,9 @@ import json -from typing import Optional import structlog from fastapi import HTTPException, Request -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer @@ -14,20 +12,14 @@ class LlamaCppProvider(BaseProvider): def __init__( self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): completion_handler = LlamaCppCompletionHandler() super().__init__( LLamaCppInputNormalizer(), LLamaCppOutputNormalizer(), completion_handler, - pipeline_processor, - fim_pipeline_processor, - output_pipeline_processor, - fim_output_pipeline_processor, + pipeline_factory, ) @property diff --git a/src/codegate/providers/ollama/completion_handler.py b/src/codegate/providers/ollama/completion_handler.py index 49fbc103..f569d988 100644 --- a/src/codegate/providers/ollama/completion_handler.py +++ b/src/codegate/providers/ollama/completion_handler.py @@ -16,6 +16,7 @@ async def ollama_stream_generator( """OpenAI-style SSE format""" try: async for chunk in stream: + print(chunk) try: yield f"{chunk.model_dump_json()}\n\n" except Exception as e: diff --git 
a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index f8e901d4..8307f7e0 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -1,13 +1,11 @@ import json -from typing import Optional import httpx import structlog from fastapi import HTTPException, Request from codegate.config import Config -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer from codegate.providers.ollama.completion_handler import OllamaShim @@ -16,10 +14,7 @@ class OllamaProvider(BaseProvider): def __init__( self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): config = Config.get_config() if config is None: @@ -32,9 +27,7 @@ def __init__( OllamaInputNormalizer(), OllamaOutputNormalizer(), completion_handler, - pipeline_processor, - fim_pipeline_processor, - output_pipeline_processor, + pipeline_factory, ) @property @@ -45,6 +38,7 @@ def _setup_routes(self): """ Sets up Ollama API routes. 
""" + @self.router.get(f"/{self.provider_route_name}/api/tags") async def get_tags(request: Request): """ diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 75c201da..53aa7db8 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -1,11 +1,9 @@ import json -from typing import Optional import structlog from fastapi import Header, HTTPException, Request -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer @@ -14,20 +12,14 @@ class OpenAIProvider(BaseProvider): def __init__( self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) super().__init__( OpenAIInputNormalizer(), OpenAIOutputNormalizer(), completion_handler, - pipeline_processor, - fim_pipeline_processor, - output_pipeline_processor, - fim_output_pipeline_processor, + pipeline_factory, ) @property diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py index 6d7b9bca..f39ed8d6 100644 --- a/src/codegate/providers/vllm/provider.py +++ b/src/codegate/providers/vllm/provider.py @@ -1,5 +1,4 @@ import json -from typing import Optional import httpx import structlog @@ -7,8 +6,7 @@ from litellm import atext_completion from codegate.config import Config -from codegate.pipeline.base import 
SequentialPipelineProcessor -from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer @@ -17,10 +15,7 @@ class VLLMProvider(BaseProvider): def __init__( self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): completion_handler = LiteLLmShim( stream_generator=sse_stream_generator, fim_completion_func=atext_completion @@ -29,10 +24,7 @@ def __init__( VLLMInputNormalizer(), VLLMOutputNormalizer(), completion_handler, - pipeline_processor, - fim_pipeline_processor, - output_pipeline_processor, - fim_output_pipeline_processor, + pipeline_factory, ) @property diff --git a/src/codegate/server.py b/src/codegate/server.py index 57206712..9ea9e569 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -51,47 +51,30 @@ def init_app(pipeline_factory: PipelineFactory) -> FastAPI: # Register all known providers registry.add_provider( "openai", - OpenAIProvider( - pipeline_processor=pipeline_factory.create_input_pipeline(), - fim_pipeline_processor=pipeline_factory.create_fim_pipeline(), - output_pipeline_processor=pipeline_factory.create_output_pipeline(), - fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(), - ), + OpenAIProvider(pipeline_factory), ) registry.add_provider( "anthropic", AnthropicProvider( - pipeline_processor=pipeline_factory.create_input_pipeline(), - fim_pipeline_processor=pipeline_factory.create_fim_pipeline(), - output_pipeline_processor=pipeline_factory.create_output_pipeline(), - 
fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(), + pipeline_factory, ), ) registry.add_provider( "llamacpp", LlamaCppProvider( - pipeline_processor=pipeline_factory.create_input_pipeline(), - fim_pipeline_processor=pipeline_factory.create_fim_pipeline(), - output_pipeline_processor=pipeline_factory.create_output_pipeline(), - fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(), + pipeline_factory, ), ) registry.add_provider( "vllm", VLLMProvider( - pipeline_processor=pipeline_factory.create_input_pipeline(), - fim_pipeline_processor=pipeline_factory.create_fim_pipeline(), - output_pipeline_processor=pipeline_factory.create_output_pipeline(), - fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(), + pipeline_factory, ), ) registry.add_provider( "ollama", OllamaProvider( - pipeline_processor=pipeline_factory.create_input_pipeline(), - fim_pipeline_processor=pipeline_factory.create_fim_pipeline(), - output_pipeline_processor=pipeline_factory.create_output_pipeline(), - fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(), + pipeline_factory, ), ) diff --git a/tests/pipeline/secrets/test_manager.py b/tests/pipeline/secrets/test_manager.py index 5cb06ade..177e8f3f 100644 --- a/tests/pipeline/secrets/test_manager.py +++ b/tests/pipeline/secrets/test_manager.py @@ -1,6 +1,6 @@ import pytest -from codegate.pipeline.secrets.manager import SecretEntry, SecretsManager +from codegate.pipeline.secrets.manager import SecretsManager class TestSecretsManager: @@ -21,11 +21,8 @@ def test_store_secret(self): # Verify the secret was stored stored = self.manager.get_by_session_id(self.test_session) - assert isinstance(stored, SecretEntry) - assert stored.original == self.test_value - assert stored.encrypted == encrypted - assert stored.service == self.test_service - assert stored.secret_type == self.test_type + assert isinstance(stored, dict) + assert stored[encrypted].original == 
self.test_value # Verify encrypted value can be retrieved retrieved = self.manager.get_original_value(encrypted, self.test_session) @@ -86,10 +83,15 @@ def test_multiple_secrets_same_session(self): encrypted1 = self.manager.store_secret("secret1", "service1", "type1", self.test_session) encrypted2 = self.manager.store_secret("secret2", "service2", "type2", self.test_session) - # Latest secret should be retrievable + # Latest secret should be retrievable in the session stored = self.manager.get_by_session_id(self.test_session) - assert stored.original == "secret2" - assert stored.encrypted == encrypted2 + assert isinstance(stored, dict) + assert stored[encrypted1].original == "secret1" + assert stored[encrypted2].original == "secret2" + + # Both secrets should be retrievable directly + assert self.manager.get_original_value(encrypted1, self.test_session) == "secret1" + assert self.manager.get_original_value(encrypted2, self.test_session) == "secret2" # Both encrypted values should map to the session assert self.manager._encrypted_to_session[encrypted1] == self.test_session @@ -119,7 +121,7 @@ def test_secure_cleanup(self): # Get reference to stored data before cleanup stored = self.manager.get_by_session_id(self.test_session) - original_value = stored.original + assert len(stored) == 1 # Perform cleanup self.manager.cleanup() @@ -127,7 +129,6 @@ def test_secure_cleanup(self): # Verify the original string was overwritten, not just removed # This test is a bit tricky since Python strings are immutable, # but we can at least verify the data is no longer accessible - assert original_value not in str(self.manager._session_store) assert self.test_value not in str(self.manager._session_store) def test_session_isolation(self): diff --git a/tests/test_provider.py b/tests/test_provider.py index f2c4011f..95361c97 100644 --- a/tests/test_provider.py +++ b/tests/test_provider.py @@ -11,14 +11,12 @@ def __init__(self): mocked_input_normalizer = MagicMock() 
mocked_output_normalizer = MagicMock() mocked_completion_handler = MagicMock() - mocked_pipepeline = MagicMock() - mocked_fim_pipeline = MagicMock() + mocked_factory = MagicMock() super().__init__( mocked_input_normalizer, mocked_output_normalizer, mocked_completion_handler, - mocked_pipepeline, - mocked_fim_pipeline, + mocked_factory, ) def _setup_routes(self) -> None: