[GCP] Update artifact registry/db code #685
base: main
```diff
@@ -45,6 +45,7 @@ def excluded_namespaces():
 ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master"
 SQS_BROKER = "sqs-message-broker-master"
 SERVICEBUS_BROKER = "servicebus-message-broker-master"
+GCP_REDIS_BROKER = "redis-gcp-memorystore-message-broker-master"

 UPDATE_DEPLOYMENT_MAX_RETRIES = 10

@@ -588,6 +589,7 @@ async def main():
     ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True),
     SQS_BROKER: SQSBroker(),
     SERVICEBUS_BROKER: ASBBroker(),
+    GCP_REDIS_BROKER: RedisBroker(use_elasticache=False),
 }

 broker = BROKER_NAME_TO_CLASS[autoscaler_broker]

@@ -598,10 +600,18 @@ async def main():
     )

     if broker_type == "redis":
+        # TODO gcp: change this to use cloud storage
+        # NOTE: the infra config is not available in the autoscaler (for some reason), so we have
+        # to use the autoscaler_broker to determine the infra.
+        backend_protocol = "redis" if "gcp" in autoscaler_broker else "s3"
         inspect = {
             db_index: inspect_app(
                 app=celery_app(
-                    None, broker_type=broker_type, task_visibility=db_index, aws_role=aws_profile
+                    None,
+                    broker_type=broker_type,
+                    task_visibility=db_index,
+                    aws_role=aws_profile,
+                    backend_protocol=backend_protocol,
                 )
             )
             for db_index in get_all_db_indexes()
```

Review comment (on the `# TODO gcp: change this to use cloud storage` line): This should be covered in the cloud storage PR @anishxyz was working on. I think we can merge this first, then merge that.
```diff
@@ -1,3 +1,4 @@
+import json
 import os
 import sys
 import time

@@ -7,6 +8,7 @@
 import sqlalchemy
 from azure.identity import DefaultAzureCredential
 from azure.keyvault.secrets import SecretClient
+from google.cloud.secretmanager_v1 import SecretManagerServiceClient
 from model_engine_server.core.aws.secrets import get_key_file
 from model_engine_server.core.config import InfraConfig, infra_config
 from model_engine_server.core.loggers import logger_name, make_logger

@@ -20,8 +22,12 @@
 def get_key_file_name(environment: str) -> str:
-    if infra_config().cloud_provider == "azure":
+    # azure and gcp don't support "/" in the key file secret name
+    # so we use dashes
+    if infra_config().cloud_provider == "azure" or infra_config().cloud_provider == "gcp":
         return f"{environment}-ml-infra-pg".replace("training", "prod").replace("-new", "")

+    # aws does support "/" in the key file secret name
     return f"{environment}/ml_infra_pg".replace("training", "prod").replace("-new", "")
```
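As a concrete illustration of the naming scheme, here is a hypothetical standalone re-implementation of `get_key_file_name` with the cloud provider passed in explicitly (the real function reads it from `infra_config()`); the string substitutions are copied from the diff above:

```python
# Hypothetical re-implementation of get_key_file_name for illustration only.
def key_file_name(environment: str, cloud_provider: str) -> str:
    if cloud_provider in ("azure", "gcp"):
        # dashes: azure/gcp secret names can't contain "/"
        return f"{environment}-ml-infra-pg".replace("training", "prod").replace("-new", "")
    # aws secret names may contain "/"
    return f"{environment}/ml_infra_pg".replace("training", "prod").replace("-new", "")

assert key_file_name("training", "gcp") == "prod-ml-infra-pg"
assert key_file_name("prod-new", "gcp") == "prod-ml-infra-pg"
assert key_file_name("training", "aws") == "prod/ml_infra_pg"
```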
```diff
@@ -55,16 +61,17 @@ def get_engine_url(
     key_file = os.environ.get("DB_SECRET_NAME")
     if env is None:
         env = infra_config().env
+    # TODO: what are the values of env?
     if key_file is None:
         key_file = get_key_file_name(env)  # type: ignore
     logger.debug(f"Using key file {key_file}")

     if infra_config().cloud_provider == "azure":
-        client = SecretClient(
+        az_secret_client = SecretClient(
             vault_url=f"https://{os.environ.get('KEYVAULT_NAME')}.vault.azure.net",
             credential=DefaultAzureCredential(),
         )
-        db = client.get_secret(key_file).value
+        db = az_secret_client.get_secret(key_file).value
         user = os.environ.get("AZURE_IDENTITY_NAME")
         token = DefaultAzureCredential().get_token(
             "https://ossrdbms-aad.database.windows.net/.default"
```

Review comment (on `env = infra_config().env`): What are the values of env / where is it used?
```diff
@@ -76,6 +83,26 @@ def get_engine_url(
         # for recommendations on how to work with rotating auth credentials
         engine_url = f"postgresql://{user}:{password}@{db}?sslmode=require"
         expiry_in_sec = token.expires_on
+    elif infra_config().cloud_provider == "gcp":
+        gcp_secret_manager_client = (
+            SecretManagerServiceClient()
+        )  # uses application default credentials (see: https://cloud.google.com/secret-manager/docs/reference/libraries#client-libraries-usage-python)
+        secret_version = gcp_secret_manager_client.access_secret_version(
+            request={
+                "name": f"projects/{infra_config().ml_account_id}/secrets/{key_file}/versions/latest"
+            }
+        )
+        creds = json.loads(secret_version.payload.data.decode("utf-8"))
+
+        user = creds.get("username")
+        password = creds.get("password")
+        host = creds.get("host")
+        port = str(creds.get("port"))
+        dbname = creds.get("dbname")
+
+        logger.info(f"Connecting to db {host}:{port}, name {dbname}")
+
+        engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
     else:
         db_secret_aws_profile = os.environ.get("DB_SECRET_AWS_PROFILE")
         creds = get_key_file(key_file, db_secret_aws_profile)
```
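For reference, here is a hypothetical example of the JSON document the GCP secret is expected to hold, inferred from the `creds.get(...)` keys in the diff above; all values are placeholders:

```python
# Hypothetical secret payload and the engine URL it would produce; the key
# names come from the diff above, the values are made up for illustration.
import json

payload = json.dumps(
    {
        "username": "ml_user",   # placeholder
        "password": "s3cret",    # placeholder
        "host": "10.0.0.5",      # placeholder DB host
        "port": 5432,
        "dbname": "ml_infra",    # placeholder
    }
).encode("utf-8")

creds = json.loads(payload.decode("utf-8"))
engine_url = (
    f"postgresql://{creds['username']}:{creds['password']}"
    f"@{creds['host']}:{creds['port']}/{creds['dbname']}"
)
print(engine_url)  # postgresql://ml_user:s3cret@10.0.0.5:5432/ml_infra
```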
New file (+72 lines):

```python
from typing import Optional

from google.api_core.client_options import ClientOptions
from google.api_core.exceptions import NotFound
from google.cloud import artifactregistry_v1 as artifactregistry
from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse
from model_engine_server.core.config import infra_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException
from model_engine_server.domain.repositories import DockerRepository

logger = make_logger(logger_name())


class GCPArtifactRegistryDockerRepository(DockerRepository):
    def _get_client(self):
        client = artifactregistry.ArtifactRegistryClient(
            # NOTE: uses default auth credentials for GCP. Read the
            # `google.auth.default` function for more details
            client_options=ClientOptions()
        )
        return client

    def _get_repository_prefix(self) -> str:
        # GCP is verbose and so has a long prefix for the repository
        return f"projects/{infra_config().ml_account_id}/locations/{infra_config().default_region}/repositories"

    def image_exists(
        self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None
    ) -> bool:
        client = self._get_client()

        try:
            client.get_docker_image(
                artifactregistry.GetDockerImageRequest(
                    # This is the google cloud naming convention:
                    # https://cloud.google.com/artifact-registry/docs/docker/names
                    name=f"{self._get_repository_prefix()}/{repository_name}/dockerImages/{image_tag}"
                )
            )
        except NotFound:
            return False
        return True

    def get_image_url(self, image_tag: str, repository_name: str) -> str:
        return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}"

    def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
        raise NotImplementedError("GCP image build not supported yet")

    def get_latest_image_tag(self, repository_name: str) -> str:
        client = self._get_client()
        parent = f"{self._get_repository_prefix()}/{repository_name}"
        try:
            images_pager = client.list_docker_images(
                artifactregistry.ListDockerImagesRequest(
                    parent=parent,
                    # NOTE: we expect that the artifact registry is immutable,
                    # so there should not be any updates after upload
                    order_by="update_time_desc",
                    page_size=1,
                )
            )

            docker_image_page = next(images_pager.pages, None)
            if docker_image_page is None:
                raise DockerRepositoryNotFoundException
            # This condition shouldn't happen since we're asking for 1 image per page
            if len(docker_image_page.docker_images) == 0:
                raise DockerRepositoryNotFoundException
            # TODO: is the return as expected? it's a big string, not just the tag
            return docker_image_page.docker_images[0].name
        except NotFound:
            raise DockerRepositoryNotFoundException
```
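On the TODO above: per the Artifact Registry naming convention linked in `image_exists`, `DockerImage.name` is the full digest-qualified resource name, not just a tag. A hedged sketch of one way to recover only the tag, assuming the `artifactregistry_v1` `DockerImage` message's repeated `tags` field is populated for the image:

```python
# Sketch only, not part of the PR: extract a plain tag from a DockerImage.
def extract_latest_tag(docker_image) -> str:
    # Prefer the explicit tags field when present...
    if docker_image.tags:
        return docker_image.tags[0]
    # ...otherwise fall back to the last path segment of the resource name,
    # e.g. "projects/p/locations/l/repositories/r/dockerImages/img@sha256:abc"
    return docker_image.name.rsplit("/", 1)[-1]


class _FakeImage:  # stand-in for artifactregistry_v1.DockerImage in this sketch
    tags = ["v1.2.3"]
    name = "projects/p/locations/l/repositories/r/dockerImages/img@sha256:abc"


assert extract_latest_tag(_FakeImage()) == "v1.2.3"
```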
Review comment: @kovben95scale had a good question about this -- is there a reason we don't use redis on azure etc.? I think we're just using this as a celery task queue here, which seems like it would fit with redis.
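Finally, a hypothetical usage sketch of the new repository class, assuming `infra_config()` is configured for a GCP project; the import path, repository name, and tag are placeholders, since the PR does not show where the class is registered:

```python
# Hypothetical import path -- the PR does not show the module location.
from model_engine_server.infra.repositories import GCPArtifactRegistryDockerRepository

repo = GCPArtifactRegistryDockerRepository()
if repo.image_exists(image_tag="abc1234", repository_name="model-engine"):
    # e.g. "<docker_repo_prefix>/model-engine:abc1234"
    print(repo.get_image_url(image_tag="abc1234", repository_name="model-engine"))
    # Full resource name of the most recently uploaded image (see TODO above).
    print(repo.get_latest_image_tag(repository_name="model-engine"))
```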