From d2eabf7ed0f92b735183369c73479a5fc8b345f7 Mon Sep 17 00:00:00 2001
From: Aagam Dalal
Date: Mon, 3 Feb 2025 15:53:25 -0500
Subject: [PATCH 1/3] Add to BrokerType/BrokerName enums + add a ton of comments

---
 .../common/dtos/model_endpoints.py            |  2 +
 .../start_batch_job_orchestration.py          |  4 ++
 .../inference/async_inference/celery.py       |  1 +
 .../inference/forwarding/celery_forwarder.py  |  1 +
 .../gateways/celery_task_queue_gateway.py     | 46 +++++++++++++++----
 .../gateways/resources/k8s_resource_types.py  |  7 ++-
 .../service_builder/celery.py                 | 19 +++++++-
 .../inference/test_async_inference.py         |  3 ++
 model-engine/tests/unit/conftest.py           |  3 ++
 9 files changed, 75 insertions(+), 11 deletions(-)

diff --git a/model-engine/model_engine_server/common/dtos/model_endpoints.py b/model-engine/model_engine_server/common/dtos/model_endpoints.py
index 36a7c7f6..4f5a6092 100644
--- a/model-engine/model_engine_server/common/dtos/model_endpoints.py
+++ b/model-engine/model_engine_server/common/dtos/model_endpoints.py
@@ -34,6 +34,7 @@ class BrokerType(str, Enum):
     REDIS_24H = "redis_24h"
     SQS = "sqs"
     SERVICEBUS = "servicebus"
+    GCPPUBSUB = "gcppubsub"


 class BrokerName(str, Enum):
@@ -45,6 +46,7 @@ class BrokerName(str, Enum):
     REDIS = "redis-message-broker-master"
     SQS = "sqs-message-broker-master"
     SERVICEBUS = "servicebus-message-broker-master"
+    GCPPUBSUB = "gcppubsub-message-broker-master"


 class CreateModelEndpointV1Request(BaseModel):
diff --git a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py
index de1bd59b..ca592117 100644
--- a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py
+++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py
@@ -62,6 +62,7 @@ async def run_batch_job(
     pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url)
     redis = aioredis.Redis(connection_pool=pool)
     sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS)
+    gcppubsub_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.GCPPUBSUB)
     servicebus_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SERVICEBUS)
     monitoring_metrics_gateway = get_monitoring_metrics_gateway()
@@ -94,6 +95,9 @@ async def run_batch_job(
     if infra_config().cloud_provider == "azure":
         inference_task_queue_gateway = servicebus_task_queue_gateway
         infra_task_queue_gateway = servicebus_task_queue_gateway
+    elif infra_config().cloud_provider == "gcp":
+        inference_task_queue_gateway = gcppubsub_task_queue_gateway
+        infra_task_queue_gateway = gcppubsub_task_queue_gateway
     else:
         inference_task_queue_gateway = sqs_task_queue_gateway
         infra_task_queue_gateway = sqs_task_queue_gateway
diff --git a/model-engine/model_engine_server/inference/async_inference/celery.py b/model-engine/model_engine_server/inference/async_inference/celery.py
index 3ea5db6d..3a65ad4b 100644
--- a/model-engine/model_engine_server/inference/async_inference/celery.py
+++ b/model-engine/model_engine_server/inference/async_inference/celery.py
@@ -26,6 +26,7 @@
 celery_kwargs.update(
     dict(broker_transport_options={"predefined_queues": {queue_name: {"url": queue_url}}})
 )
+# TODO: is this unused? And why is there no ABS (Azure Service Bus) handling here?
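# A hedged sketch of what per-cloud handling could look like at this point,
# mirroring the broker selection done in service_builder/celery.py. Whether
# this module is actually exercised on Azure/GCP deployments is unverified:
#
#   elif infra_config().cloud_provider == "azure":
#       celery_kwargs.update(dict(broker_type=str(BrokerType.SERVICEBUS.value)))
#   elif infra_config().cloud_provider == "gcp":
#       celery_kwargs.update(dict(broker_type=str(BrokerType.GCPPUBSUB.value)))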
 async_inference_service = celery_app(**celery_kwargs)  # type: ignore
diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py
index 27007969..032b9614 100644
--- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py
+++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py
@@ -207,6 +207,7 @@ def entrypoint():
     if args.broker_type is None:
         args.broker_type = str(BrokerType.SQS.value if args.sqs_url else BrokerType.REDIS.value)
+    # TODO: how come this doesn't have an Azure (ASB) case?

     forwarder_config = load_named_config(args.config, args.set)
     forwarder_loader = LoadForwarder(**forwarder_config["async"])
diff --git a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py
index a1b761f0..9e34873f 100644
--- a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py
+++ b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py
@@ -14,7 +14,21 @@ from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway

 logger = make_logger(logger_name())
-backend_protocol = "abs" if infra_config().cloud_provider == "azure" else "s3"
+
+
+def get_backend_protocol():
+    cloud_provider = infra_config().cloud_provider
+    if cloud_provider == "azure":
+        return "abs"
+    elif cloud_provider == "aws":
+        return "s3"
+    elif cloud_provider == "gcp":
+        return "gcppubsub"  # TODO: double-check -- Pub/Sub is a broker; the GCP result *backend* may need to be object storage (e.g. GCS) instead
+    else:
+        return "s3"  # TODO: I feel like we should raise an error here.
+
+
+backend_protocol = get_backend_protocol()

 celery_redis = celery_app(
     None,
@@ -36,19 +50,33 @@
     backend_protocol=backend_protocol,
 )
 celery_servicebus = celery_app(
-    None, broker_type=str(BrokerType.SERVICEBUS.value), backend_protocol=backend_protocol
+    None,
+    broker_type=str(BrokerType.SERVICEBUS.value),
+    backend_protocol=backend_protocol,
+    # TODO: check which result backend Azure actually uses here (abs vs. s3)
+)
+
+# XXX: double-check this gcppubsub app configuration
+celery_gcppubsub = celery_app(
+    None,
+    broker_type=str(BrokerType.GCPPUBSUB.value),
+    backend_protocol=backend_protocol,
+)


 class CeleryTaskQueueGateway(TaskQueueGateway):
     def __init__(self, broker_type: BrokerType):
         self.broker_type = broker_type
-        assert self.broker_type in [
-            BrokerType.SQS,
-            BrokerType.REDIS,
-            BrokerType.REDIS_24H,
-            BrokerType.SERVICEBUS,
-        ]
+        assert (
+            self.broker_type
+            in [  # TODO: why do we have this assert? this is the same as the enum -- is it so we remember to implement it here?
+ BrokerType.SQS, + BrokerType.REDIS, + BrokerType.REDIS_24H, + BrokerType.SERVICEBUS, + BrokerType.GCPPUBSUB, + ] + ) def _get_celery_dest(self): if self.broker_type == BrokerType.SQS: @@ -57,6 +85,8 @@ def _get_celery_dest(self): return celery_redis_24h elif self.broker_type == BrokerType.REDIS: return celery_redis + elif self.broker_type == BrokerType.GCPPUBSUB: + return celery_gcppubsub else: return celery_servicebus diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 96d5fbb4..1e7cbc44 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -510,8 +510,8 @@ def container_start_triton_cmd( def get_endpoint_resource_arguments_from_request( k8s_resource_group_name: str, request: CreateOrUpdateResourcesRequest, - sqs_queue_name: str, - sqs_queue_url: str, + sqs_queue_name: str, # TODO: check how this is used + sqs_queue_url: str, # TODO: check how this is used endpoint_resource_name: str, api_version: str = "", service_name_override: Optional[str] = None, @@ -560,6 +560,9 @@ def get_endpoint_resource_arguments_from_request( elif infra_config().cloud_provider == "azure": broker_name = BrokerName.SERVICEBUS.value broker_type = BrokerType.SERVICEBUS.value + elif infra_config().cloud_provider == "gcp": # TODO: cloud provider should be an enum, right? + broker_name = BrokerName.GCPPUBSUB.value + broker_type = BrokerType.GCPPUBSUB.value else: broker_name = BrokerName.SQS.value broker_type = BrokerType.SQS.value diff --git a/model-engine/model_engine_server/service_builder/celery.py b/model-engine/model_engine_server/service_builder/celery.py index 06384c9e..b2fbe04d 100644 --- a/model-engine/model_engine_server/service_builder/celery.py +++ b/model-engine/model_engine_server/service_builder/celery.py @@ -3,11 +3,28 @@ from model_engine_server.core.celery import celery_app from model_engine_server.core.config import infra_config + +# TODO: this is copied from celery_task_queue_gateway.py +def get_backend_protocol(): + cloud_provider = infra_config().cloud_provider + if cloud_provider == "azure": + return "abs" + elif cloud_provider == "aws": + return "s3" + elif cloud_provider == "gcp": + return "gcppubsub" + else: + return "s3" # TODO: I feel like we should raise an error here. + + service_builder_broker_type: str +# TODO: this seems redundant? 
+# we definitely have other code doing this
 if CIRCLECI:
     service_builder_broker_type = str(BrokerType.REDIS.value)
 elif infra_config().cloud_provider == "azure":
     service_builder_broker_type = str(BrokerType.SERVICEBUS.value)
+elif infra_config().cloud_provider == "gcp":
+    service_builder_broker_type = str(BrokerType.GCPPUBSUB.value)
 else:
     service_builder_broker_type = str(BrokerType.SQS.value)

@@ -18,7 +35,7 @@
     ],
     s3_bucket=infra_config().s3_bucket,
     broker_type=service_builder_broker_type,
-    backend_protocol="abs" if infra_config().cloud_provider == "azure" else "s3",
+    backend_protocol=get_backend_protocol(),  # TODO: similarly, this has a big overlap with celery_task_queue_gateway.py
 )

 if __name__ == "__main__":
diff --git a/model-engine/tests/integration/inference/test_async_inference.py b/model-engine/tests/integration/inference/test_async_inference.py
index d1d7f7c5..049ba648 100644
--- a/model-engine/tests/integration/inference/test_async_inference.py
+++ b/model-engine/tests/integration/inference/test_async_inference.py
@@ -181,3 +181,6 @@ def test_async_callbacks_botocore_exception(
         queue_name=queue,
         args=[1, 2],
     )
+
+
+# XXX: probably need to add a test for GCP pubsub here
diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py
index 3cacdb5a..06197183 100644
--- a/model-engine/tests/unit/conftest.py
+++ b/model-engine/tests/unit/conftest.py
@@ -3412,6 +3412,9 @@ def build_endpoint_request_async_runnable_image(
     return build_endpoint_request


+# XXX: add a test for GCP pubsub here
+
+
 @pytest.fixture
 def build_endpoint_request_streaming_runnable_image(
     test_api_key: str, model_bundle_5: ModelBundle

From c4354688cdd5d4ca0ba6080751cca157f003b9d8 Mon Sep 17 00:00:00 2001
From: Aagam Dalal
Date: Tue, 4 Feb 2025 12:07:48 -0500
Subject: [PATCH 2/3] Stub gcppubsub_queue_endpoint_resource_delegate and GCPPubSubBroker + add comments

---
 .../model_engine_server/api/dependencies.py   |   9 ++
 .../model_engine_server/core/celery/app.py    |  13 +-
 .../core/celery/celery_autoscaler.py          |  11 +-
 ...pubsub_queue_endpoint_resource_delegate.py | 147 ++++++++++++++++++
 .../k8s_endpoint_resource_delegate.py         |   1 +
 .../live_endpoint_resource_gateway.py         |   1 +
 .../queue_endpoint_resource_delegate.py       |   2 +-
 .../service_builder/tasks_v1.py               |   5 +
 8 files changed, 182 insertions(+), 7 deletions(-)
 create mode 100644 model-engine/model_engine_server/infra/gateways/resources/gcppubsub_queue_endpoint_resource_delegate.py

diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py
index e120fbf0..f1c221b6 100644
--- a/model-engine/model_engine_server/api/dependencies.py
+++ b/model-engine/model_engine_server/api/dependencies.py
@@ -89,6 +89,9 @@
 from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import (
     FakeQueueEndpointResourceDelegate,
 )
+from model_engine_server.infra.gateways.resources.gcppubsub_queue_endpoint_resource_delegate import (
+    GCPPubSubQueueEndpointResourceDelegate,
+)
 from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
     LiveEndpointResourceGateway,
 )
@@ -201,6 +204,7 @@ def _get_external_interfaces(
     redis_24h_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS_24H)
     sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS)
     servicebus_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SERVICEBUS)
+    gcppubsub_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.GCPPUBSUB)
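    # A hedged sketch of one way to deduplicate the broker-selection chains
    # below and in start_batch_job_orchestration.py: a mapping keyed on cloud
    # provider instead of repeated if/elif blocks (the mapping name here is
    # hypothetical):
    #
    #   _GATEWAY_BY_PROVIDER = {
    #       "azure": servicebus_task_queue_gateway,
    #       "gcp": gcppubsub_task_queue_gateway,
    #   }
    #   inference_task_queue_gateway = _GATEWAY_BY_PROVIDER.get(
    #       infra_config().cloud_provider, sqs_task_queue_gateway
    #   )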
monitoring_metrics_gateway = get_monitoring_metrics_gateway() model_endpoint_record_repo = DbModelEndpointRecordRepository( monitoring_metrics_gateway=monitoring_metrics_gateway, @@ -213,6 +217,8 @@ def _get_external_interfaces( queue_delegate = FakeQueueEndpointResourceDelegate() elif infra_config().cloud_provider == "azure": queue_delegate = ASBQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "gcp": + queue_delegate = GCPPubSubQueueEndpointResourceDelegate() else: queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) @@ -226,6 +232,9 @@ def _get_external_interfaces( elif infra_config().cloud_provider == "azure": inference_task_queue_gateway = servicebus_task_queue_gateway infra_task_queue_gateway = servicebus_task_queue_gateway + elif infra_config().cloud_provider == "gcp": + inference_task_queue_gateway = gcppubsub_task_queue_gateway + infra_task_queue_gateway = gcppubsub_task_queue_gateway else: inference_task_queue_gateway = sqs_task_queue_gateway infra_task_queue_gateway = sqs_task_queue_gateway diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index af7790d1..10872aff 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -116,9 +116,7 @@ def seconds_to_visibility(timeout: int) -> "TaskVisibility": @staticmethod def from_name(name: str) -> "TaskVisibility": # pylint: disable=no-member,protected-access - lookup = { - x.name: x.value for x in TaskVisibility._value2member_map_.values() - } # type: ignore + lookup = {x.name: x.value for x in TaskVisibility._value2member_map_.values()} # type: ignore return TaskVisibility(lookup[name.upper()]) @@ -252,7 +250,7 @@ def celery_app( s3_bucket: Optional[str] = os.environ.get("S3_BUCKET"), s3_base_path: str = "tmp/celery/", backend_protocol: str = "s3", - broker_type: str = "redis", + broker_type: str = "redis", # TODO: should this be an enum aws_role: Optional[str] = None, broker_transport_options: Optional[Dict[str, Any]] = None, **extra_changes, @@ -372,7 +370,7 @@ def celery_app( (1 day by default) and run `celery beat` periodically to clear expired results from Redis. Visit https://docs.celeryproject.org/en/stable/userguide/periodic-tasks.html to learn more about celery beat - :param broker_type: [defaults to "redis"] The broker type. We currently support "redis", "sqs", and "servicebus". + :param broker_type: [defaults to "redis"] The broker type. We currently support "redis", "sqs", "servicebus", and "gcppubsub". :param aws_role: [optional] AWS role to use. 
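# On the "TODO" broker URL added in the next hunk: kombu >= 5.4 ships a Google
# Pub/Sub transport whose broker URL takes the form
# "gcpubsub://projects/<project-id>" (scheme spelled with a single "p").
# A hedged sketch of what the stub could eventually return; the infra_config()
# field name below is hypothetical, and the URL format should be verified
# against the kombu docs:
#
#   if broker_type == "gcppubsub":
#       return (
#           f"gcpubsub://projects/{infra_config().gcp_project_id}",
#           out_broker_transport_options,
#       )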
@@ -509,6 +507,11 @@ def _get_broker_endpoint_and_transport_options(
             f"azureservicebus://DefaultAzureCredential@{os.getenv('SERVICEBUS_NAMESPACE')}.servicebus.windows.net",
             out_broker_transport_options,
         )
+    if broker_type == "gcppubsub":
+        return (
+            "TODO",
+            out_broker_transport_options,  # XXX: implement this
+        )

     raise ValueError(
-        f"Only 'redis', 'sqs', and 'servicebus' are supported values for broker_type, got value {broker_type}"
+        f"Only 'redis', 'sqs', 'servicebus', and 'gcppubsub' are supported values for broker_type, got value {broker_type}"
diff --git a/model-engine/model_engine_server/core/celery/celery_autoscaler.py b/model-engine/model_engine_server/core/celery/celery_autoscaler.py
index 1b74f279..dcce2c3a 100644
--- a/model-engine/model_engine_server/core/celery/celery_autoscaler.py
+++ b/model-engine/model_engine_server/core/celery/celery_autoscaler.py
@@ -42,9 +42,11 @@ def excluded_namespaces():
     return []


-ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master"
+# TODO: what is the relationship between these brokers and the ones in model_endpoints.py?
+ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master"  # TODO: This one is different, the others are represented in model_endpoints.py
 SQS_BROKER = "sqs-message-broker-master"
 SERVICEBUS_BROKER = "servicebus-message-broker-master"
+GCPPUBSUB_BROKER = "gcppubsub-message-broker-master"

 UPDATE_DEPLOYMENT_MAX_RETRIES = 10

@@ -306,6 +308,7 @@ def emit_metrics(
         f"env:{env}",
     ]
     statsd.gauge("celery.max_connections", metrics.broker_metrics.max_connections, tags=tags)
+    # TODO: how does this work in VPCs where we don't have Datadog?


 def emit_health_metric(metric_name: str, env: str):
@@ -472,6 +475,11 @@ async def get_broker_metrics(
     )  # connection_count and max_connections are redis-specific metrics


+class GCPPubSubBroker(AutoscalerBroker):
+    # XXX: finish this
+    pass
+
+
 class ASBBroker(AutoscalerBroker):
     @staticmethod
     def _get_asb_queue_size(queue_name: str):
@@ -588,6 +596,7 @@ async def main():
         ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True),
         SQS_BROKER: SQSBroker(),
         SERVICEBUS_BROKER: ASBBroker(),
+        GCPPUBSUB_BROKER: GCPPubSubBroker(),
     }

     broker = BROKER_NAME_TO_CLASS[autoscaler_broker]
diff --git a/model-engine/model_engine_server/infra/gateways/resources/gcppubsub_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/gcppubsub_queue_endpoint_resource_delegate.py
new file mode 100644
index 00000000..26987dc5
--- /dev/null
+++ b/model-engine/model_engine_server/infra/gateways/resources/gcppubsub_queue_endpoint_resource_delegate.py
@@ -0,0 +1,147 @@
+# fmt: off

+import json
+from string import Template
+from typing import Any, Dict, Optional, Sequence

+# import botocore.exceptions
+from aioboto3 import Session as AioSession
+from aiobotocore.client import AioBaseClient
+from model_engine_server.common.config import hmi_config
+from model_engine_server.core.aws.roles import session
+from model_engine_server.core.config import infra_config
+from model_engine_server.core.loggers import logger_name, make_logger

+# from model_engine_server.domain.exceptions import EndpointResourceInfraException
+from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
+    QueueEndpointResourceDelegate,
+    QueueInfo,
+)

+logger = make_logger(logger_name())

+__all__: Sequence[str] = ("GCPPubSubQueueEndpointResourceDelegate",)


+def _create_async_sqs_client(sqs_profile: Optional[str]) -> AioBaseClient:
+    return session(role=sqs_profile, session_type=AioSession).client(
+        "sqs", region_name=infra_config().default_region
+    )


+def _get_queue_policy(queue_name: str) -> str:
+    queue_policy_template = Template(hmi_config.sqs_queue_policy_template)
+    return queue_policy_template.substitute(queue_name=queue_name)


+def _get_queue_tags(
+    team: str, endpoint_id: str, endpoint_name: str, endpoint_created_by: str
+) -> Dict[str, str]:
+    queue_tag_template = Template(hmi_config.sqs_queue_tag_template)
+    return json.loads(
+        queue_tag_template.substitute(
+            team=team,
+            endpoint_id=endpoint_id,
+            endpoint_name=endpoint_name,
+            endpoint_created_by=endpoint_created_by,
+        )
+    )


+class GCPPubSubQueueEndpointResourceDelegate(QueueEndpointResourceDelegate):
+    def __init__(self, sqs_profile: Optional[str] = None):
+        self.sqs_profile = sqs_profile

+    async def create_queue_if_not_exists(
+        self,
+        endpoint_id: str,
+        endpoint_name: str,
+        endpoint_created_by: str,
+        endpoint_labels: Dict[str, Any],
+    ) -> QueueInfo:
+        raise NotImplementedError("GCP pubsub is not implemented")
+        # async with _create_async_sqs_client(sqs_profile=self.sqs_profile) as sqs_client:
+        #     queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)

+        #     try:
+        #         get_queue_url_response = await sqs_client.get_queue_url(QueueName=queue_name)
+        #         return QueueInfo(
+        #             queue_name=queue_name,
+        #             queue_url=get_queue_url_response["QueueUrl"],
+        #         )
+        #     except botocore.exceptions.ClientError:
+        #         logger.info("Queue does not exist, creating it")
+        #         pass

+        #     try:
+        #         create_response = await sqs_client.create_queue(
+        #             QueueName=queue_name,
+        #             Attributes=dict(
+        #                 VisibilityTimeout="43200",
+        #                 # To match current hardcoded Celery timeout of 24hr
+        #                 # However, the max SQS visibility is 12hrs.
+        #                 Policy=_get_queue_policy(queue_name=queue_name),
+        #             ),
+        #             tags=_get_queue_tags(
+        #                 team=endpoint_labels["team"],
+        #                 endpoint_id=endpoint_id,
+        #                 endpoint_name=endpoint_name,
+        #                 endpoint_created_by=endpoint_created_by,
+        #             ),
+        #         )
+        #     except botocore.exceptions.ClientError as e:
+        #         raise EndpointResourceInfraException("Failed to create SQS queue") from e

+        #     if create_response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+        #         raise EndpointResourceInfraException(
+        #             f"Creating SQS queue got non-200 response: {create_response}"
+        #         )

+        #     return QueueInfo(queue_name, create_response["QueueUrl"])

+    async def delete_queue(self, endpoint_id: str) -> None:
+        raise NotImplementedError("GCP pubsub is not implemented")
+        # queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
+        # async with _create_async_sqs_client(self.sqs_profile) as sqs_client:
+        #     try:
+        #         queue_url = (await sqs_client.get_queue_url(QueueName=queue_name))["QueueUrl"]
+        #     except botocore.exceptions.ClientError:
+        #         logger.info(
+        #             f"Could not get queue url for queue_name={queue_name}, endpoint_id={endpoint_id}, "
+        #             "skipping delete"
+        #         )
+        #         return

+        #     try:
+        #         delete_response = await sqs_client.delete_queue(QueueUrl=queue_url)
+        #     except botocore.exceptions.ClientError as e:
+        #         raise EndpointResourceInfraException("Failed to delete SQS queue") from e

+        #     # Example failed delete:
+        #     # botocore.errorfactory.QueueDoesNotExist:
+        #     # An error occurred (AWS.SimpleQueueService.NonExistentQueue) when calling the GetQueueUrl operation:
+        #     # The specified queue does not exist for this wsdl version.
+ # if delete_response["ResponseMetadata"]["HTTPStatusCode"] != 200: + # raise EndpointResourceInfraException( + # f"Deleting SQS queue got non-200 response: {delete_response}" + # ) + + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: + raise NotImplementedError("GCP pubsub is not implemented") + # queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + # async with _create_async_sqs_client(self.sqs_profile) as sqs_client: + # try: + # queue_url = (await sqs_client.get_queue_url(QueueName=queue_name))["QueueUrl"] + # except botocore.exceptions.ClientError as e: + # raise EndpointResourceInfraException( + # f"Could not find queue {queue_name} for endpoint {endpoint_id}" + # ) from e + + # try: + # attributes_response = await sqs_client.get_queue_attributes( + # QueueUrl=queue_url, AttributeNames=["All"] + # ) + # except botocore.exceptions.ClientError as e: + # raise EndpointResourceInfraException("Failed to get SQS queue attributes") from e + + # return attributes_response diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 45ab0d73..d640d595 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -360,6 +360,7 @@ def add_lws_default_env_vars_to_container(container: Dict[str, Any]) -> None: container["env"] = container_envs +# TODO: figure out what this does class K8SEndpointResourceDelegate: async def create_or_update_resources( self, diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py index 4e775974..ed2b8419 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py @@ -85,6 +85,7 @@ async def get_resources( ) if endpoint_type == ModelEndpointType.ASYNC: + # TODO: this seems poorly named. this isn't just for SQS, right? sqs_attributes = await self.queue_delegate.get_queue_attributes(endpoint_id=endpoint_id) if ( "Attributes" in sqs_attributes diff --git a/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py index 76c77e64..aea26402 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py @@ -14,7 +14,7 @@ class QueueInfo(NamedTuple): class QueueEndpointResourceDelegate(ABC): """ - Base class for an interactor with SQS or ASB. This is used by the LiveEndpointResourceGateway. + Base class for an interactor with SQS, ASB, or GCP pubsub. This is used by the LiveEndpointResourceGateway. 
""" @abstractmethod diff --git a/model-engine/model_engine_server/service_builder/tasks_v1.py b/model-engine/model_engine_server/service_builder/tasks_v1.py index cd4ff63c..06dc8751 100644 --- a/model-engine/model_engine_server/service_builder/tasks_v1.py +++ b/model-engine/model_engine_server/service_builder/tasks_v1.py @@ -28,6 +28,9 @@ from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( FakeQueueEndpointResourceDelegate, ) +from model_engine_server.infra.gateways.resources.gcppubsub_queue_endpoint_resource_delegate import ( + GCPPubSubQueueEndpointResourceDelegate, +) from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( set_lazy_load_kubernetes_clients, ) @@ -65,6 +68,8 @@ def get_live_endpoint_builder_service( queue_delegate = FakeQueueEndpointResourceDelegate() elif infra_config().cloud_provider == "azure": queue_delegate = ASBQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "gcp": + queue_delegate = GCPPubSubQueueEndpointResourceDelegate() else: queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) From d8dc4d3cf2b4a8e879c69f3e04a16bdf4edb33ef Mon Sep 17 00:00:00 2001 From: Aagam Dalal Date: Mon, 10 Feb 2025 11:30:06 -0500 Subject: [PATCH 3/3] Update readme + add helper comment in env_vars.py --- model-engine/README.md | 6 +++++- model-engine/model_engine_server/common/env_vars.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/model-engine/README.md b/model-engine/README.md index 3f6b9579..017f22c7 100644 --- a/model-engine/README.md +++ b/model-engine/README.md @@ -37,6 +37,10 @@ pip install -r requirements.txt && \ Run `mypy . --install-types` to set up mypy. +## Running this + +After installing this package, you can view `setup.py` for the list of entrypoints (for instance, there's a command for how to run the api server) + ## Testing Most of the business logic in Model Engine should contain unit tests, located in @@ -46,4 +50,4 @@ Most of the business logic in Model Engine should contain unit tests, located in We've decided to make our V2 APIs OpenAI compatible. We generate the corresponding Pydantic models: 1. Fetch the OpenAPI spec from https://github.com/openai/openai-openapi/blob/master/openapi.yaml -2. Run scripts/generate-openai-types.sh +2. Run scripts/generate-openai-types.sh \ No newline at end of file diff --git a/model-engine/model_engine_server/common/env_vars.py b/model-engine/model_engine_server/common/env_vars.py index 2a69cbff..5cb29074 100644 --- a/model-engine/model_engine_server/common/env_vars.py +++ b/model-engine/model_engine_server/common/env_vars.py @@ -75,6 +75,7 @@ def get_boolean_env_var(name: str) -> bool: if LOCAL: logger.warning("LOCAL development & testing mode is ON") +# TODO: add a comment here once we understand what this does. GIT_TAG: str = os.environ.get("GIT_TAG", "GIT_TAG_NOT_FOUND") if GIT_TAG == "GIT_TAG_NOT_FOUND" and "pytest" not in sys.modules: raise ValueError("GIT_TAG environment variable must be set")