[GCP] Update artifact registry/db code #685

Open · wants to merge 15 commits into base: main
6 changes: 3 additions & 3 deletions charts/model-engine/values_sample.yaml
@@ -157,17 +157,17 @@ serviceTemplate:
config:
values:
infra:
# cloud_provider [required]; either "aws" or "azure"
# cloud_provider [required]; either "aws", "azure", or "gcp"
cloud_provider: aws
# k8s_cluster_name [required] is the name of the k8s cluster
k8s_cluster_name: main_cluster
# dns_host_domain [required] is the domain name of the k8s cluster
dns_host_domain: llm-engine.domain.com
# default_region [required] is the default AWS region for various resources (e.g ECR)
default_region: us-east-1
# aws_account_id [required] is the AWS account ID for various resources (e.g ECR)
# ml_account_id [required] is the AWS account ID for various resources (e.g ECR) if cloud_provider is "aws", and the GCP project ID if cloud_provider is "gcp"
ml_account_id: "000000000000"
# docker_repo_prefix [required] is the prefix for AWS ECR repositories
# docker_repo_prefix [required] is the prefix for AWS ECR repositories, GCP Artifact Registry repositories, or Azure Container Registry repositories
docker_repo_prefix: "000000000000.dkr.ecr.us-east-1.amazonaws.com"
# redis_host [required if redis_aws_secret_name not present] is the hostname of the redis cluster you wish to connect
redis_host: llm-engine-prod-cache.use1.cache.amazonaws.com
7 changes: 6 additions & 1 deletion model-engine/README.md
@@ -42,8 +42,13 @@ Run `mypy . --install-types` to set up mypy.
Most of the business logic in Model Engine should contain unit tests, located in
[`tests/unit`](./tests/unit). To run the tests, run `pytest`.

## Building Docker Images

To build Docker images, change directories into the llm-engine repository root and run
`docker build -f model-engine/Dockerfile .`

## Generating OpenAI types
We've decided to make our V2 APIs OpenAI compatible. We generate the
corresponding Pydantic models:
1. Fetch the OpenAPI spec from https://github.com/openai/openai-openapi/blob/master/openapi.yaml
2. Run scripts/generate-openai-types.sh
14 changes: 13 additions & 1 deletion model-engine/model_engine_server/api/dependencies.py
@@ -110,6 +110,7 @@
DbTriggerRepository,
ECRDockerRepository,
FakeDockerRepository,
GCPArtifactRegistryDockerRepository,
LiveTokenizerRepository,
LLMFineTuneRepository,
RedisModelEndpointCacheRepository,
@@ -226,6 +227,10 @@ def _get_external_interfaces(
elif infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().cloud_provider == "gcp":
Author comment:
@kovben95scale had a good question about this -- is there a reason we don't use redis on azure etc? I think we're just using this as a celery task queue here, which seems like it would fit with redis.

# we use redis for gcp (instead of using servicebus or the like)
inference_task_queue_gateway = redis_24h_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
else:
inference_task_queue_gateway = sqs_task_queue_gateway
infra_task_queue_gateway = sqs_task_queue_gateway
@@ -345,6 +350,12 @@ def _get_external_interfaces(
docker_repository = FakeDockerRepository()
elif infra_config().docker_repo_prefix.endswith("azurecr.io"):
docker_repository = ACRDockerRepository()
elif "pkg.dev" in infra_config().docker_repo_prefix:
assert (
infra_config().docker_repo_prefix
== f"{infra_config().default_region}-docker.pkg.dev/{infra_config().ml_account_id}"  # ml_account_id holds the GCP project ID when cloud_provider is "gcp"
)
docker_repository = GCPArtifactRegistryDockerRepository()
else:
docker_repository = ECRDockerRepository()
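
For readers tracing the docker repository selection above, here is a condensed, illustrative sketch of the branching. The standalone helper, its string return values, and the sample prefix are assumptions for demonstration only; the real logic lives inline in `_get_external_interfaces` and instantiates the repository classes directly.

```python
def select_docker_repository(docker_repo_prefix: str, default_region: str, ml_account_id: str) -> str:
    """Condensed sketch of the branching above; omits the FakeDockerRepository case."""
    if docker_repo_prefix.endswith("azurecr.io"):
        return "ACRDockerRepository"
    if "pkg.dev" in docker_repo_prefix:
        # For GCP, ml_account_id holds the project ID and the prefix must match
        # "<default_region>-docker.pkg.dev/<project-id>".
        assert docker_repo_prefix == f"{default_region}-docker.pkg.dev/{ml_account_id}"
        return "GCPArtifactRegistryDockerRepository"
    return "ECRDockerRepository"


# Hypothetical GCP values:
# select_docker_repository(
#     "us-central1-docker.pkg.dev/my-gcp-project", "us-central1", "my-gcp-project"
# )  # -> "GCPArtifactRegistryDockerRepository"
```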

@@ -387,7 +398,8 @@ def get_default_external_interfaces() -> ExternalInterfaces:

def get_default_external_interfaces_read_only() -> ExternalInterfaces:
session = async_scoped_session(
get_session_read_only_async(), scopefunc=asyncio.current_task # type: ignore
get_session_read_only_async(),
scopefunc=asyncio.current_task, # type: ignore
)
return _get_external_interfaces(read_only=True, session=session)

7 changes: 6 additions & 1 deletion model-engine/model_engine_server/common/config.py
@@ -70,12 +70,14 @@ class HostedModelInferenceServiceConfig:
user_inference_tensorflow_repository: str
docker_image_layer_cache_repository: str
sensitive_log_mode: bool
# Exactly one of the following three must be specified
# Exactly one of the following four must be specified
cache_redis_aws_url: Optional[str] = None # also using this to store sync autoscaling metrics
cache_redis_azure_host: Optional[str] = None
cache_redis_aws_secret_name: Optional[str] = (
None # Not an env var because the redis cache info is already here
)
cache_redis_gcp_host: Optional[str] = None

sglang_repository: Optional[str] = None

@classmethod
@@ -103,6 +105,9 @@ def cache_redis_url(self) -> str:
), "cache_redis_aws_secret_name is only for AWS"
creds = get_key_file(self.cache_redis_aws_secret_name) # Use default role
return creds["cache-url"]
elif self.cache_redis_gcp_host:
assert infra_config().cloud_provider == "gcp"
return f"rediss://{self.cache_redis_gcp_host}"

assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure"
username = os.getenv("AZURE_OBJECT_ID")
@@ -43,6 +43,7 @@ class BrokerName(str, Enum):
"""

REDIS = "redis-message-broker-master"
REDIS_GCP = "redis-gcp-memorystore-message-broker-master"
SQS = "sqs-message-broker-master"
SERVICEBUS = "servicebus-message-broker-master"

2 changes: 2 additions & 0 deletions model-engine/model_engine_server/core/celery/__init__.py
@@ -5,12 +5,14 @@
TaskVisibility,
celery_app,
get_all_db_indexes,
get_default_backend_protocol,
get_redis_host_port,
inspect_app,
)

__all__: Sequence[str] = (
"celery_app",
"get_default_backend_protocol",
"get_all_db_indexes",
"get_redis_host_port",
"inspect_app",
14 changes: 11 additions & 3 deletions model-engine/model_engine_server/core/celery/app.py
@@ -116,9 +116,7 @@ def seconds_to_visibility(timeout: int) -> "TaskVisibility":
@staticmethod
def from_name(name: str) -> "TaskVisibility":
# pylint: disable=no-member,protected-access
lookup = {
x.name: x.value for x in TaskVisibility._value2member_map_.values()
} # type: ignore
lookup = {x.name: x.value for x in TaskVisibility._value2member_map_.values()} # type: ignore
return TaskVisibility(lookup[name.upper()])


@@ -595,3 +593,13 @@ async def get_num_unclaimed_tasks_async(
if redis_instance is None:
await _redis_instance.close() # type: ignore
return num_unclaimed


def get_default_backend_protocol():
logger.info("CLOUD PROVIDER: %s", infra_config().cloud_provider)
if infra_config().cloud_provider == "azure":
return "abs"
elif infra_config().cloud_provider == "gcp":
return "redis" # TODO gcp: replace with cloud storage
else:
return "s3"
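
A minimal usage sketch of the new helper; the broker choice below is illustrative, and other `celery_app` arguments such as `task_visibility` and `aws_role` are omitted.

```python
from model_engine_server.core.celery import celery_app, get_default_backend_protocol

# Resolves to "s3" on AWS, "abs" on Azure, and "redis" on GCP (until cloud storage support lands).
app = celery_app(
    None,
    broker_type="redis",  # illustrative; SQS / Service Bus are used on AWS / Azure
    backend_protocol=get_default_backend_protocol(),
)
```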
@@ -45,6 +45,7 @@ def excluded_namespaces():
ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master"
SQS_BROKER = "sqs-message-broker-master"
SERVICEBUS_BROKER = "servicebus-message-broker-master"
GCP_REDIS_BROKER = "redis-gcp-memorystore-message-broker-master"

UPDATE_DEPLOYMENT_MAX_RETRIES = 10

@@ -588,6 +589,7 @@ async def main():
ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True),
SQS_BROKER: SQSBroker(),
SERVICEBUS_BROKER: ASBBroker(),
GCP_REDIS_BROKER: RedisBroker(use_elasticache=False),
}

broker = BROKER_NAME_TO_CLASS[autoscaler_broker]
@@ -598,10 +600,18 @@
)

if broker_type == "redis":
# TODO gcp: change this to use cloud storage
Author comment (AaDalal, Feb 19, 2025):
This should be covered in the cloud storage PR @anishxyz was working on. I think we can merge this first, then merge that.

# NOTE: the infra config is not available in the autoscaler (for some reason), so we have
# to use the autoscaler_broker to determine the infra.
backend_protocol = "redis" if "gcp" in autoscaler_broker else "s3"
inspect = {
db_index: inspect_app(
app=celery_app(
None, broker_type=broker_type, task_visibility=db_index, aws_role=aws_profile
None,
broker_type=broker_type,
task_visibility=db_index,
aws_role=aws_profile,
backend_protocol=backend_protocol,
)
)
for db_index in get_all_db_indexes()
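Since the autoscaler keys the backend protocol off the broker name rather than `infra_config()`, the rule reduces to something like the helper below; the function name and example values are illustrative only.

```python
def infer_backend_protocol(autoscaler_broker: str) -> str:
    # Mirrors the one-liner above: the autoscaler cannot read infra_config(),
    # so it infers the cloud from the broker name instead.
    return "redis" if "gcp" in autoscaler_broker else "s3"


# infer_backend_protocol("redis-gcp-memorystore-message-broker-master")  # -> "redis"
# infer_backend_protocol("sqs-message-broker-master")                    # -> "s3"
```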
2 changes: 1 addition & 1 deletion model-engine/model_engine_server/core/config.py
@@ -38,7 +38,7 @@ class _InfraConfig:
k8s_cluster_name: str
dns_host_domain: str
default_region: str
ml_account_id: str
ml_account_id: str  # NOTE: this stores the AWS account ID if cloud_provider is "aws", and the GCP project ID if cloud_provider is "gcp"
docker_repo_prefix: str
s3_bucket: str
redis_host: Optional[str] = None
32 changes: 29 additions & 3 deletions model-engine/model_engine_server/db/base.py
@@ -1,3 +1,4 @@
import json
import os
import sys
import time
@@ -7,6 +8,7 @@
import sqlalchemy
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from google.cloud.secretmanager_v1 import SecretManagerServiceClient
from model_engine_server.core.aws.secrets import get_key_file
from model_engine_server.core.config import InfraConfig, infra_config
from model_engine_server.core.loggers import logger_name, make_logger
@@ -20,8 +22,12 @@


def get_key_file_name(environment: str) -> str:
if infra_config().cloud_provider == "azure":
# azure and gcp don't support "/" in the key file secret name
# so we use dashes
if infra_config().cloud_provider == "azure" or infra_config().cloud_provider == "gcp":
return f"{environment}-ml-infra-pg".replace("training", "prod").replace("-new", "")

# aws does support "/" in the key file secret name
return f"{environment}/ml_infra_pg".replace("training", "prod").replace("-new", "")
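
As a quick illustration of the two naming schemes, here is a pure-function sketch; the real function reads the provider from `infra_config()` rather than taking it as an argument.

```python
def key_file_name(environment: str, cloud_provider: str) -> str:
    if cloud_provider in ("azure", "gcp"):
        # Azure Key Vault and GCP Secret Manager secret names cannot contain "/".
        return f"{environment}-ml-infra-pg".replace("training", "prod").replace("-new", "")
    return f"{environment}/ml_infra_pg".replace("training", "prod").replace("-new", "")


# key_file_name("training", "gcp")  # -> "prod-ml-infra-pg"
# key_file_name("training", "aws")  # -> "prod/ml_infra_pg"
```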


@@ -60,11 +66,11 @@ def get_engine_url(
logger.debug(f"Using key file {key_file}")

if infra_config().cloud_provider == "azure":
client = SecretClient(
az_secret_client = SecretClient(
vault_url=f"https://{os.environ.get('KEYVAULT_NAME')}.vault.azure.net",
credential=DefaultAzureCredential(),
)
db = client.get_secret(key_file).value
db = az_secret_client.get_secret(key_file).value
user = os.environ.get("AZURE_IDENTITY_NAME")
token = DefaultAzureCredential().get_token(
"https://ossrdbms-aad.database.windows.net/.default"
@@ -76,6 +82,26 @@
# for recommendations on how to work with rotating auth credentials
engine_url = f"postgresql://{user}:{password}@{db}?sslmode=require"
expiry_in_sec = token.expires_on
elif infra_config().cloud_provider == "gcp":
gcp_secret_manager_client = (
SecretManagerServiceClient()
) # uses application default credentials (see: https://cloud.google.com/secret-manager/docs/reference/libraries#client-libraries-usage-python)
secret_version = gcp_secret_manager_client.access_secret_version(
request={
"name": f"projects/{infra_config().ml_account_id}/secrets/{key_file}/versions/latest"
}
)
creds = json.loads(secret_version.payload.data.decode("utf-8"))

user = creds.get("username")
password = creds.get("password")
host = creds.get("host")
port = str(creds.get("port"))
dbname = creds.get("dbname")

logger.info(f"Connecting to db {host}:{port}, name {dbname}")

engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
else:
db_secret_aws_profile = os.environ.get("DB_SECRET_AWS_PROFILE")
creds = get_key_file(key_file, db_secret_aws_profile)
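The GCP branch assumes the secret payload is a JSON object with the keys read above; a hypothetical payload and the engine URL it produces might look like the sketch below, where every value is a placeholder.

```python
import json

# Placeholder Secret Manager payload; the keys match the creds.get(...) calls above.
payload = json.dumps(
    {
        "username": "llm_engine",
        "password": "example-password",
        "host": "10.1.2.3",
        "port": 5432,
        "dbname": "llm_engine",
    }
)

creds = json.loads(payload)
engine_url = (
    f"postgresql://{creds['username']}:{creds['password']}"
    f"@{creds['host']}:{creds['port']}/{creds['dbname']}"
)
# -> "postgresql://llm_engine:example-password@10.1.2.3:5432/llm_engine"
```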
@@ -7,14 +7,15 @@
GetAsyncTaskV1Response,
TaskStatus,
)
from model_engine_server.core.celery import TaskVisibility, celery_app
from model_engine_server.core.celery import TaskVisibility, celery_app, get_default_backend_protocol
from model_engine_server.core.config import infra_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import InvalidRequestException
from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway

logger = make_logger(logger_name())
backend_protocol = "abs" if infra_config().cloud_provider == "azure" else "s3"

backend_protocol = get_default_backend_protocol()

celery_redis = celery_app(
None,
@@ -553,10 +553,13 @@ def get_endpoint_resource_arguments_from_request(

image_hash = compute_image_hash(request.image)

# In Circle CI, we use Redis on localhost instead of SQS
# In Circle CI/GCP, we use Redis on localhost instead of SQS
if CIRCLECI:
broker_name = BrokerName.REDIS.value
broker_type = BrokerType.REDIS.value
elif infra_config().cloud_provider == "gcp":
broker_name = BrokerName.REDIS_GCP.value
broker_type = BrokerType.REDIS.value
elif infra_config().cloud_provider == "azure":
broker_name = BrokerName.SERVICEBUS.value
broker_type = BrokerType.SERVICEBUS.value
@@ -576,6 +579,7 @@
abs_account_name = os.getenv("ABS_ACCOUNT_NAME")
if abs_account_name is not None:
main_env.append({"name": "ABS_ACCOUNT_NAME", "value": abs_account_name})
# TODO (gcp): determine what env vars, if any, need to be added here for GCP

# LeaderWorkerSet exclusive
worker_env = None
@@ -12,6 +12,7 @@
from .ecr_docker_repository import ECRDockerRepository
from .fake_docker_repository import FakeDockerRepository
from .feature_flag_repository import FeatureFlagRepository
from .gcp_artifact_registry_docker_repository import GCPArtifactRegistryDockerRepository
from .live_tokenizer_repository import LiveTokenizerRepository
from .llm_fine_tune_repository import LLMFineTuneRepository
from .model_endpoint_cache_repository import ModelEndpointCacheRepository
@@ -42,4 +43,5 @@
"RedisModelEndpointCacheRepository",
"S3FileLLMFineTuneRepository",
"S3FileLLMFineTuneEventsRepository",
"GCPArtifactRegistryDockerRepository",
]
New file: gcp_artifact_registry_docker_repository.py
@@ -0,0 +1,70 @@
from typing import Optional

from google.api_core.client_options import ClientOptions
from google.api_core.exceptions import NotFound
from google.cloud import artifactregistry_v1 as artifactregistry
from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse
from model_engine_server.core.config import infra_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException
from model_engine_server.domain.repositories import DockerRepository

logger = make_logger(logger_name())


class GCPArtifactRegistryDockerRepository(DockerRepository):
def _get_client(self):
client = artifactregistry.ArtifactRegistryClient(
client_options=ClientOptions()
# NOTE: uses default auth credentials for GCP. Read `google.auth.default` function for more details
)
return client

def _get_repository_prefix(self) -> str:
# GCP is verbose and so has a long prefix for the repository
return f"projects/{infra_config().ml_account_id}/locations/{infra_config().default_region}/repositories"

def image_exists(
self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None
) -> bool:
client = self._get_client()

try:
client.get_docker_image(
artifactregistry.GetDockerImageRequest(
# This is the google cloud naming convention: https://cloud.google.com/artifact-registry/docs/docker/names
name=f"{self._get_repository_prefix()}/{repository_name}/dockerImages/{image_tag}"
)
)
except NotFound:
return False
return True

def get_image_url(self, image_tag: str, repository_name: str) -> str:
return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}"

def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
raise NotImplementedError("GCP image build not supported yet")

def get_latest_image_tag(self, repository_name: str) -> str:
client = self._get_client()
parent = f"{self._get_repository_prefix()}/{repository_name}"
try:
images_pager = client.list_docker_images(
artifactregistry.ListDockerImagesRequest(
parent=parent,
order_by="update_time_desc", # NOTE: we expect that the artifact registry is immutable, so there should not be any updates after upload
page_size=1,
)
)

docker_image_page = next(images_pager.pages, None)
if docker_image_page is None:
raise DockerRepositoryNotFoundException
if (
len(docker_image_page.docker_images) == 0
): # This condition shouldn't happen since we're asking for 1 image per page
raise DockerRepositoryNotFoundException
return docker_image_page.docker_images[0].tags[0]
except NotFound:
raise DockerRepositoryNotFoundException
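
A minimal usage sketch, assuming application default credentials are configured; the import path, repository name, and image tag below are placeholders rather than values from this PR.

```python
# Import path assumed to match the repositories package updated in this PR.
from model_engine_server.infra.repositories import GCPArtifactRegistryDockerRepository

repo = GCPArtifactRegistryDockerRepository()

# Looks up projects/<project-id>/locations/<region>/repositories/<repo>/dockerImages/<tag>
# under the hood, per the Artifact Registry naming convention referenced above.
if repo.image_exists(image_tag="abc123", repository_name="model-engine"):
    print(repo.get_image_url(image_tag="abc123", repository_name="model-engine"))

# Most recently updated tag (raises DockerRepositoryNotFoundException if the repo is empty):
latest = repo.get_latest_image_tag("model-engine")
```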