[GCP] Aagamdalal/sgp 3575 model engine update pubsub model engine code #686

Open · wants to merge 3 commits into base: main
6 changes: 5 additions & 1 deletion model-engine/README.md
@@ -37,6 +37,10 @@ pip install -r requirements.txt && \

Run `mypy . --install-types` to set up mypy.

## Running this

After installing this package, see `setup.py` for the list of entrypoints (for instance, there is a command that runs the API server).
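
For illustration, a console-script entry in `setup.py` generally has the following shape; the command name and module path below are hypothetical, not copied from this repo:

```python
# Hypothetical shape of a console-script entry point; the real list
# lives in model-engine/setup.py.
from setuptools import setup

setup(
    name="model-engine",
    entry_points={
        "console_scripts": [
            # e.g. a command that starts the API server (name and module
            # path are illustrative assumptions)
            "run-api-server=model_engine_server.api.app:entrypoint",
        ],
    },
)
```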

## Testing

Most of the business logic in Model Engine should be covered by unit tests, located in
@@ -46,4 +50,4 @@
We've decided to make our V2 APIs OpenAI-compatible. We generate the
corresponding Pydantic models:
1. Fetch the OpenAPI spec from https://github.com/openai/openai-openapi/blob/master/openapi.yaml
2. Run scripts/generate-openai-types.sh
9 changes: 9 additions & 0 deletions model-engine/model_engine_server/api/dependencies.py
@@ -89,6 +89,9 @@
from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import (
FakeQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcppubsub_queue_endpoint_resource_delegate import (
GCPPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
LiveEndpointResourceGateway,
)
@@ -201,6 +204,7 @@ def _get_external_interfaces(
redis_24h_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS_24H)
sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS)
servicebus_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SERVICEBUS)
gcppubsub_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.GCPPUBSUB)
monitoring_metrics_gateway = get_monitoring_metrics_gateway()
model_endpoint_record_repo = DbModelEndpointRecordRepository(
monitoring_metrics_gateway=monitoring_metrics_gateway,
@@ -213,6 +217,8 @@
queue_delegate = FakeQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
queue_delegate = GCPPubSubQueueEndpointResourceDelegate()
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
@@ -226,6 +232,9 @@
elif infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().cloud_provider == "gcp":
inference_task_queue_gateway = gcppubsub_task_queue_gateway
infra_task_queue_gateway = gcppubsub_task_queue_gateway
else:
inference_task_queue_gateway = sqs_task_queue_gateway
infra_task_queue_gateway = sqs_task_queue_gateway
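
The `GCPPubSubQueueEndpointResourceDelegate` wired in above is not shown in this diff. As a rough sketch only, assuming its interface mirrors the SQS delegate (one queue per endpoint, created and deleted alongside it), a Pub/Sub-backed delegate might pair a topic with a subscription; the class name, method names, and queue-naming scheme below are all assumptions:

```python
from google.cloud import pubsub_v1  # assumed dependency


class GCPPubSubQueueEndpointResourceDelegateSketch:
    """Illustrative only: models a queue as a Pub/Sub topic plus subscription."""

    def __init__(self, project_id: str):
        self.project_id = project_id
        self.publisher = pubsub_v1.PublisherClient()
        self.subscriber = pubsub_v1.SubscriberClient()

    def create_queue(self, endpoint_id: str) -> str:
        # Hypothetical naming scheme for per-endpoint queues.
        name = f"launch-endpoint-{endpoint_id}"
        topic_path = self.publisher.topic_path(self.project_id, name)
        subscription_path = self.subscriber.subscription_path(self.project_id, name)
        self.publisher.create_topic(request={"name": topic_path})
        self.subscriber.create_subscription(
            request={"name": subscription_path, "topic": topic_path}
        )
        return subscription_path

    def delete_queue(self, endpoint_id: str) -> None:
        name = f"launch-endpoint-{endpoint_id}"
        self.subscriber.delete_subscription(
            request={"subscription": self.subscriber.subscription_path(self.project_id, name)}
        )
        self.publisher.delete_topic(
            request={"topic": self.publisher.topic_path(self.project_id, name)}
        )
```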
@@ -34,6 +34,7 @@ class BrokerType(str, Enum):
REDIS_24H = "redis_24h"
SQS = "sqs"
SERVICEBUS = "servicebus"
GCPPUBSUB = "gcppubsub"


class BrokerName(str, Enum):
@@ -45,6 +46,7 @@ class BrokerName(str, Enum):
REDIS = "redis-message-broker-master"
SQS = "sqs-message-broker-master"
SERVICEBUS = "servicebus-message-broker-master"
GCPPUBSUB = "gcppubsub-message-broker-master"


class CreateModelEndpointV1Request(BaseModel):
1 change: 1 addition & 0 deletions model-engine/model_engine_server/common/env_vars.py
@@ -75,6 +75,7 @@ def get_boolean_env_var(name: str) -> bool:
if LOCAL:
logger.warning("LOCAL development & testing mode is ON")

# TODO: add a comment here once we understand what this does.
GIT_TAG: str = os.environ.get("GIT_TAG", "GIT_TAG_NOT_FOUND")
if GIT_TAG == "GIT_TAG_NOT_FOUND" and "pytest" not in sys.modules:
raise ValueError("GIT_TAG environment variable must be set")
13 changes: 8 additions & 5 deletions model-engine/model_engine_server/core/celery/app.py
@@ -116,9 +116,7 @@ def seconds_to_visibility(timeout: int) -> "TaskVisibility":
@staticmethod
def from_name(name: str) -> "TaskVisibility":
# pylint: disable=no-member,protected-access
lookup = {x.name: x.value for x in TaskVisibility._value2member_map_.values()}  # type: ignore
return TaskVisibility(lookup[name.upper()])


@@ -252,7 +250,7 @@ def celery_app(
s3_bucket: Optional[str] = os.environ.get("S3_BUCKET"),
s3_base_path: str = "tmp/celery/",
backend_protocol: str = "s3",
broker_type: str = "redis",  # TODO: should this be an enum?
aws_role: Optional[str] = None,
broker_transport_options: Optional[Dict[str, Any]] = None,
**extra_changes,
@@ -372,7 +370,7 @@ def celery_app(
(1 day by default) and run `celery beat` periodically to clear expired results from Redis. Visit
https://docs.celeryproject.org/en/stable/userguide/periodic-tasks.html to learn more about celery beat

:param broker_type: [defaults to "redis"] The broker type. We currently support "redis", "sqs", "servicebus", and "gcppubsub".

:param aws_role: [optional] AWS role to use.

@@ -509,6 +507,11 @@ def _get_broker_endpoint_and_transport_options(
f"azureservicebus://DefaultAzureCredential@{os.getenv('SERVICEBUS_NAMESPACE')}.servicebus.windows.net",
out_broker_transport_options,
)
if broker_type == "gcppubsub":
return (
"TODO",
out_broker_transport_options, # XXX: implement this
)

raise ValueError(
f"Only 'redis', 'sqs', 'servicebus', and 'gcppubsub' are supported values for broker_type, got value {broker_type}"
@@ -42,9 +42,11 @@ def excluded_namespaces():
return []


# TODO: what is the relationship between these brokers and the ones in model_endpoints.py?
ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master"  # TODO: this one is different; the others are represented in model_endpoints.py
SQS_BROKER = "sqs-message-broker-master"
SERVICEBUS_BROKER = "servicebus-message-broker-master"
GCPPUBSUB_BROKER = "gcppubsub-message-broker-master"

UPDATE_DEPLOYMENT_MAX_RETRIES = 10

@@ -306,6 +308,7 @@ def emit_metrics(
f"env:{env}",
]
statsd.gauge("celery.max_connections", metrics.broker_metrics.max_connections, tags=tags)
# TODO: how does this work in VPCs when we don't have datadog?


def emit_health_metric(metric_name: str, env: str):
@@ -472,6 +475,11 @@ async def get_broker_metrics(
) # connection_count and max_connections are redis-specific metrics


class GCPPubSubBroker(AutoscalerBroker):
# XXX: finish this
pass
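
A sketch of how the queue-depth side of `GCPPubSubBroker` might be filled in, assuming the autoscaler needs a backlog count per queue (as the SQS and ASB brokers appear to provide) and that backlog maps to Cloud Monitoring's `num_undelivered_messages` metric; the helper below is illustrative, not part of this PR:

```python
from datetime import datetime, timedelta, timezone

from google.cloud import monitoring_v3  # assumed dependency


def get_subscription_backlog(project_id: str, subscription_id: str) -> int:
    """Latest num_undelivered_messages for one Pub/Sub subscription."""
    client = monitoring_v3.MetricServiceClient()
    now = datetime.now(timezone.utc)
    interval = monitoring_v3.TimeInterval(
        start_time=now - timedelta(minutes=5), end_time=now
    )
    results = client.list_time_series(
        request={
            "name": f"projects/{project_id}",
            "filter": (
                'metric.type = "pubsub.googleapis.com/subscription/num_undelivered_messages"'
                f' AND resource.labels.subscription_id = "{subscription_id}"'
            ),
            "interval": interval,
            "view": monitoring_v3.ListTimeSeriesRequest.TimeSeriesView.FULL,
        }
    )
    for series in results:
        if series.points:
            # Cloud Monitoring returns points newest-first.
            return int(series.points[0].value.int64_value)
    return 0
```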


class ASBBroker(AutoscalerBroker):
@staticmethod
def _get_asb_queue_size(queue_name: str):
@@ -588,6 +596,7 @@ async def main():
ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True),
SQS_BROKER: SQSBroker(),
SERVICEBUS_BROKER: ASBBroker(),
GCPPUBSUB_BROKER: GCPPubSubBroker(),
}

broker = BROKER_NAME_TO_CLASS[autoscaler_broker]
@@ -62,6 +62,7 @@ async def run_batch_job(
pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url)
redis = aioredis.Redis(connection_pool=pool)
sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS)
gcppubsub_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.GCPPUBSUB)
servicebus_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SERVICEBUS)

monitoring_metrics_gateway = get_monitoring_metrics_gateway()
@@ -94,6 +95,9 @@
if infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().cloud_provider == "gcp":
inference_task_queue_gateway = gcppubsub_task_queue_gateway
infra_task_queue_gateway = gcppubsub_task_queue_gateway
else:
inference_task_queue_gateway = sqs_task_queue_gateway
infra_task_queue_gateway = sqs_task_queue_gateway
@@ -26,6 +26,7 @@
celery_kwargs.update(
dict(broker_transport_options={"predefined_queues": {queue_name: {"url": queue_url}}})
)
# TODO: is this unused? How come we don't have ABS here?

async_inference_service = celery_app(**celery_kwargs) # type: ignore

@@ -207,6 +207,7 @@ def entrypoint():

if args.broker_type is None:
args.broker_type = str(BrokerType.SQS.value if args.sqs_url else BrokerType.REDIS.value)
# TODO: how come this doesn't have azure (ASB)?

forwarder_config = load_named_config(args.config, args.set)
forwarder_loader = LoadForwarder(**forwarder_config["async"])
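
On the TODO above: one hedged way to cover Azure (and GCP) in the default, assuming `infra_config` is importable in this module as it is elsewhere in the codebase:

```python
# Sketch only -- mirrors the provider checks used in dependencies.py.
if args.broker_type is None:
    if args.sqs_url:
        args.broker_type = str(BrokerType.SQS.value)
    elif infra_config().cloud_provider == "azure":
        args.broker_type = str(BrokerType.SERVICEBUS.value)
    elif infra_config().cloud_provider == "gcp":
        args.broker_type = str(BrokerType.GCPPUBSUB.value)
    else:
        args.broker_type = str(BrokerType.REDIS.value)
```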
@@ -14,7 +14,21 @@
from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway

logger = make_logger(logger_name())
backend_protocol = "abs" if infra_config().cloud_provider == "azure" else "s3"


def get_backend_protocol():
cloud_provider = infra_config().cloud_provider
if cloud_provider == "azure":
return "abs"
elif cloud_provider == "aws":
return "s3"
elif cloud_provider == "gcp":
return "gcppubsub"
else:
return "s3" # TODO: I feel like we should raise an error here.


backend_protocol = get_backend_protocol()
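
Two notes on `get_backend_protocol`: the `"gcppubsub"` return value looks like a broker type rather than a result-backend storage protocol (the other branches return `"s3"`/`"abs"`), which may be worth double-checking; and, per the TODO, an unknown provider silently falls back to `"s3"`. A stricter variant might look like this sketch:

```python
def get_backend_protocol_strict() -> str:
    # Sketch: fail fast instead of defaulting to s3. Whether "gcppubsub"
    # is actually a valid *backend* (result-storage) protocol for
    # celery_app is unverified here.
    protocols = {"aws": "s3", "azure": "abs", "gcp": "gcppubsub"}
    cloud_provider = infra_config().cloud_provider
    if cloud_provider not in protocols:
        raise ValueError(f"Unsupported cloud provider for result backend: {cloud_provider}")
    return protocols[cloud_provider]
```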

celery_redis = celery_app(
None,
@@ -36,19 +50,33 @@
backend_protocol=backend_protocol,
)
celery_servicebus = celery_app(
None,
broker_type=str(BrokerType.SERVICEBUS.value),
backend_protocol=backend_protocol,
# TODO: check how Azure uses s3
)

# XXX: check the next line
celery_gcppubsub = celery_app(
None,
broker_type=str(BrokerType.GCPPUBSUB.value),
backend_protocol=backend_protocol,
)


class CeleryTaskQueueGateway(TaskQueueGateway):
def __init__(self, broker_type: BrokerType):
self.broker_type = broker_type
assert (
self.broker_type
in [  # TODO: why do we have this assert? This is the same as the enum -- is it so we remember to implement it here?
BrokerType.SQS,
BrokerType.REDIS,
BrokerType.REDIS_24H,
BrokerType.SERVICEBUS,
BrokerType.GCPPUBSUB,
]
)

def _get_celery_dest(self):
if self.broker_type == BrokerType.SQS:
@@ -57,6 +85,8 @@ def _get_celery_dest(self):
return celery_redis_24h
elif self.broker_type == BrokerType.REDIS:
return celery_redis
elif self.broker_type == BrokerType.GCPPUBSUB:
return celery_gcppubsub
else:
return celery_servicebus
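
On the TODO in the assert above: one reading is that the assert exists so that a new `BrokerType` member fails loudly until `_get_celery_dest` learns about it. A table-driven dispatch would make that explicit. A sketch, assuming a `celery_sqs` app exists for the SQS branch (its definition is elided from this diff):

```python
# Sketch only: an explicit mapping raises KeyError for any BrokerType
# without a configured Celery app, replacing both the assert and the
# if/elif chain.
_CELERY_DESTS = {
    BrokerType.SQS: celery_sqs,
    BrokerType.REDIS: celery_redis,
    BrokerType.REDIS_24H: celery_redis_24h,
    BrokerType.SERVICEBUS: celery_servicebus,
    BrokerType.GCPPUBSUB: celery_gcppubsub,
}


class CeleryTaskQueueGatewaySketch(TaskQueueGateway):
    def __init__(self, broker_type: BrokerType):
        self.broker_type = broker_type  # KeyError below replaces the assert

    def _get_celery_dest(self):
        return _CELERY_DESTS[self.broker_type]
```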
