Skip to content

Updated code to support ssh into sagemaker space apps - JupyterLab #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 61 additions & 45 deletions sagemaker_ssh_helper/ide.py
Original file line number Diff line number Diff line change
@@ -46,12 +46,13 @@ def __init__(self, arn, version_arn) -> None:
class SSHIDE:
logger = logging.getLogger('sagemaker-ssh-helper:SSHIDE')

def __init__(self, domain_id: str, user: str, region_name: str = None):
self.user = user
def __init__(self, domain_id: str, user_or_space: str = None, region_name: str = None, is_user_profile: bool = True):
self.user_or_space = user_or_space
self.domain_id = domain_id
self.current_region = region_name or boto3.session.Session().region_name
self.client = boto3.client('sagemaker', region_name=self.current_region)
self.ssh_log = SSHLog(region_name=self.current_region)
self.is_user_profile = is_user_profile

def create_ssh_kernel_app(self, app_name: str,
image_name_or_arn='sagemaker-datascience-38',
@@ -108,13 +109,18 @@ def get_app_status(self, app_name: str, app_type: str = 'KernelGateway') -> IDEA
:return: None | 'InService' | 'Deleted' | 'Deleting' | 'Failed' | 'Pending'
"""
response = None

describe_app_request_params = {
"DomainId": self.domain_id,
"AppType": app_type,
"AppName": app_name,
}

describe_app_request_params.update(
{"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space})

try:
response = self.client.describe_app(
DomainId=self.domain_id,
AppType=app_type,
UserProfileName=self.user,
AppName=app_name,
)
response = self.client.describe_app(**describe_app_request_params)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == 'ResourceNotFound':
@@ -137,12 +143,16 @@ def delete_app(self, app_name, app_type, wait: bool = True):
self.logger.info(f"Deleting app {app_name}")

try:
_ = self.client.delete_app(
DomainId=self.domain_id,
AppType=app_type,
UserProfileName=self.user,
AppName=app_name,
)
delete_app_request_params = {
"DomainId": self.domain_id,
"AppType": app_type,
"AppName": app_name,
}

delete_app_request_params.update(
{"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space})

_ = self.client.delete_app(**delete_app_request_params)
except ClientError as e:
# probably, already deleted
code = e.response.get("Error", {}).get("Code")
@@ -173,13 +183,17 @@ def create_app(self, app_name, app_type, instance_type, image_arn,
if lifecycle_arn:
resource_spec['LifecycleConfigArn'] = lifecycle_arn

_ = self.client.create_app(
DomainId=self.domain_id,
AppType=app_type,
AppName=app_name,
UserProfileName=self.user,
ResourceSpec=resource_spec,
)
create_app_request_params = {
"DomainId": self.domain_id,
"AppType": app_type,
"AppName": app_name,
"ResourceSpec": resource_spec,
}

create_app_request_params.update(
{"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space})

_ = self.client.create_app(**create_app_request_params)
status = self.get_app_status(app_name)
while status.is_pending():
self.logger.info(f"Waiting for the InService status. Current status: {status}")
@@ -199,45 +213,47 @@ def resolve_sagemaker_kernel_image_arn(self, image_name):
sagemaker_account_id = "470317259841" # eu-west-1, TODO: check all images
return f"arn:aws:sagemaker:{self.current_region}:{sagemaker_account_id}:image/{image_name}"

def print_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0):
print(self.get_kernel_instance_id(app_name, timeout_in_sec, index))
def print_instance_id(self, app_name, timeout_in_sec, index: int = 0):
print(self.get_instance_id(app_name, timeout_in_sec, index))

def get_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0,
not_earlier_than_timestamp: int = 0):
ids = self.get_kernel_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp)
def get_instance_id(self, app_name, timeout_in_sec, index: int = 0,
not_earlier_than_timestamp: int = 0):
ids = self.get_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp)
if len(ids) == 0:
raise ValueError(f"No kernel instances found for app {app_name}")
raise ValueError(f"No instances found for app {app_name}")
return ids[index]

def get_kernel_instance_ids(self, app_name: str, timeout_in_sec: int, not_earlier_than_timestamp: int = 0):
self.logger.info(f"Resolving IDE instance IDs for app '{app_name}' through SSM tags "
f"in domain '{self.domain_id}' for user '{self.user}'")
def get_instance_ids(self, app_name: str, timeout_in_sec: int, not_earlier_than_timestamp: int = 0):
self.logger.info(f"Resolving IDE instance IDs for app '{app_name}' through SSM tags in domain '{self.domain_id}' "
f"for {f'user' if self.is_user_profile else f'space'} '{self.user_or_space}'")
self.log_urls(app_name)
if self.domain_id and self.user:
result = SSMManager().get_studio_user_kgw_instance_ids(self.domain_id, self.user, app_name,
timeout_in_sec, not_earlier_than_timestamp)
elif self.user:

if self.domain_id and self.user_or_space:
result = SSMManager().get_studio_instance_ids(self.domain_id, self.user_or_space, app_name,
timeout_in_sec, not_earlier_than_timestamp, is_user_profile=self.is_user_profile)
elif self.user_or_space:
self.logger.warning(f"Domain ID is not set. Will attempt to connect to the latest "
f"active kernel gateway with the name {app_name} in the region {self.current_region} "
f"for user profile {self.user}")
result = SSMManager().get_studio_user_kgw_instance_ids("", self.user, app_name,
timeout_in_sec, not_earlier_than_timestamp)
f"active {app_name} in the region {self.current_region} "
f"for {'user' if self.is_user_profile else 'space'} {self.user_or_space}")
result = SSMManager().get_studio_instance_ids("", self.user_or_space, app_name,
timeout_in_sec, not_earlier_than_timestamp, is_user_profile=self.is_user_profile)
else:
self.logger.warning(f"Domain ID or user profile name are not set. Will attempt to connect to the latest "
f"active kernel gateway with the name {app_name} in the region {self.current_region}")
self.logger.warning(
f"Domain ID or {'user' if self.is_user_profile else 'space'} are not set. Will attempt to connect to the latest "
f"active {app_name} in the region {self.current_region}")
result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp)
return result

def log_urls(self, app_name):
self.logger.info(f"Remote logs are at {self.get_cloudwatch_url(app_name)}")
if self.domain_id and self.user:
self.logger.info(f"Remote apps metadata is at {self.get_user_metadata_url()}")
if self.domain_id:
self.logger.info(f"Remote apps metadata is at {self.get_user_or_space_metadata_url()}")

def get_cloudwatch_url(self, app_name):
return self.ssh_log.get_ide_cloudwatch_url(self.domain_id, self.user, app_name)
return self.ssh_log.get_ide_cloudwatch_url(self.domain_id, self.user_or_space, app_name, self.is_user_profile)

def get_user_metadata_url(self):
return self.ssh_log.get_ide_metadata_url(self.domain_id, self.user)
def get_user_or_space_metadata_url(self):
return self.ssh_log.get_ide_metadata_url(self.domain_id, self.user_or_space, self.is_user_profile)

def create_and_attach_image(self, image_name, ecr_image_name,
role_arn,
42 changes: 21 additions & 21 deletions sagemaker_ssh_helper/interactive_sagemaker.py
Original file line number Diff line number Diff line change
@@ -31,23 +31,23 @@ def set_ping_status(self, ping_status):


class SageMakerStudioApp(SageMakerCoreApp):
def __init__(self, domain_id: str, user_profile_name: str, app_name: str, app_type: str,
app_status: IDEAppStatus) -> None:
def __init__(self, domain_id: str, user_profile_or_space_name: str, app_name: str, app_type: str, app_status: IDEAppStatus,
is_user_profile: bool = True) -> None:
super().__init__()
self.app_status = app_status
self.app_type = app_type
self.app_name = app_name
self.user_profile_name = user_profile_name
self.user_profile_or_space_name = user_profile_or_space_name
self.domain_id = domain_id
self.resource_type = "ide"
self.resource_type = "ide" if is_user_profile else "space-ide"

def __str__(self) -> str:
return "{0:<16} {1:<18} {2:<12} {5}.{4}.{3}.{6}".format(
self.ping_status if self.ssm_instance_id else self.NO_SSH_FLAG,
self.app_type,
str(self.app_status),
self.domain_id,
self.user_profile_name,
self.user_profile_or_space_name,
self.app_name,
SageMakerSecureShellHelper.type_to_fqdn(self.resource_type)
)
@@ -160,17 +160,17 @@ def list_ide_apps(self) -> List[SageMakerStudioApp]:
app_name = app_dict['AppName']
app_type = app_dict['AppType']
if 'SpaceName' in app_dict:
logging.info("Don't support spaces: skipping app %s of type %s" % (app_name, app_type))
pass
space_name = app_dict['SpaceName']
logging.info("Found app %s of type %s for space %s" % (app_name, app_type, space_name))
app_status = SSHIDE(domain_id, space_name, self.region, is_user_profile=False).get_app_status(app_name, app_type)
result.append(SageMakerStudioApp(domain_id, user_profile_or_space_name=space_name, app_name=app_dict['AppName'],
app_type=app_dict['AppType'], app_status=app_status, is_user_profile=False))
elif app_type in ['JupyterServer', 'KernelGateway']:
user_profile_name = app_dict['UserProfileName']
logging.info("Found app %s of type %s for user %s" % (app_name, app_type, user_profile_name))
app_status = SSHIDE(domain_id, user_profile_name, self.region).get_app_status(app_name, app_type)
result.append(SageMakerStudioApp(
domain_id, user_profile_name,
app_dict['AppName'], app_dict['AppType'],
app_status
))
app_status = SSHIDE(domain_id, user_profile_name, self.region, is_user_profile=True).get_app_status(app_name, app_type)
result.append(SageMakerStudioApp(domain_id, user_profile_or_space_name=user_profile_name, app_name=app_dict['AppName'],
app_type=app_dict['AppType'], app_status=app_status, is_user_profile=True))
else:
logging.info("Unsupported app type %s" % app_type)
pass # We don't support other types like 'DetailedProfiler'
@@ -290,14 +290,14 @@ def __init__(self, sagemaker: SageMaker, manager: SSMManager,
self.manager = manager
self.log = log

def list_studio_ide_apps_for_user_and_domain(self, domain_id: Optional[str], user_profile_name: Optional[str]):
def list_studio_ide_apps_for_user_or_space_and_domain(self, domain_id: Optional[str], user_profile_or_space_name: Optional[str]):
managed_instances = self.manager.list_all_instances_and_fetch_tags()
sagemaker_apps = self.sagemaker.list_ide_apps()
result = []
for sagemaker_app in sagemaker_apps:
if (sagemaker_app.domain_id == domain_id or domain_id is None or domain_id == "") \
and (sagemaker_app.user_profile_name == user_profile_name or user_profile_name is None
or user_profile_name == ""):
and (sagemaker_app.user_profile_or_space_name == user_profile_or_space_name or user_profile_or_space_name is None
or user_profile_or_space_name == ""):
instance_id = self._find_latest_app_instance_id(managed_instances, sagemaker_app)
if instance_id:
tags = managed_instances[instance_id]
@@ -308,15 +308,15 @@ def list_studio_ide_apps_for_user_and_domain(self, domain_id: Optional[str], use
return result

def print_studio_ide_apps_for_user_and_domain(self, domain_id: str, user_profile_name: str):
apps: List[SageMakerStudioApp] = self.list_studio_ide_apps_for_user_and_domain(domain_id, user_profile_name)
apps: List[SageMakerStudioApp] = self.list_studio_ide_apps_for_user_or_space_and_domain(domain_id, user_profile_name)
for app in apps:
print(app)

def list_studio_ide_apps_for_user(self, user_profile_name: str):
return self.list_studio_ide_apps_for_user_and_domain(None, user_profile_name)
def list_studio_ide_apps_for_user_or_space(self, user_profile_or_space_name: str):
return self.list_studio_ide_apps_for_user_or_space_and_domain(None, user_profile_or_space_name)

def list_studio_ide_apps(self):
return self.list_studio_ide_apps_for_user_and_domain(None, None)
return self.list_studio_ide_apps_for_user_or_space_and_domain(None, None)

@staticmethod
def _find_latest_instance_id(managed_instances: Dict[str, Dict[str, str]],
@@ -342,7 +342,7 @@ def _find_latest_app_instance_id(managed_instances: Dict[str, Dict[str, str]], s
arn = tags['SSHResourceArn'] if 'SSHResourceArn' in tags else ''
timestamp = int(tags['SSHTimestamp']) if 'SSHTimestamp' in tags else 0
if (':app/' in arn and arn.endswith(f"/{sagemaker_app.app_name}")
and f"/{sagemaker_app.user_profile_name}/" in arn
and f"/{sagemaker_app.user_profile_or_space_name}/" in arn
and f"/{sagemaker_app.domain_id}/" in arn
and timestamp > max_timestamp):
result = managed_instance_id
16 changes: 10 additions & 6 deletions sagemaker_ssh_helper/log.py
Original file line number Diff line number Diff line change
@@ -197,22 +197,26 @@ def get_transform_metadata_url(self, transform_job_name):
f"sagemaker/home?region={self.region_name}#" \
f"/transform-jobs/{transform_job_name}"

def get_ide_cloudwatch_url(self, domain, user, app_name):
app_type = 'JupyterServer' if app_name == 'default' else 'KernelGateway'
if user:
def get_ide_cloudwatch_url(self, domain, user_or_space, app_name, is_user_profile=True):
if is_user_profile:
app_type = 'JupyterServer' if app_name == 'default' else 'KernelGateway'
else:
app_type = 'JupyterLab'
if user_or_space:
return f"https://{self.aws_console.get_console_domain()}/" \
f"cloudwatch/home?region={self.region_name}#" \
f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252Fstudio" \
f"$3FlogStreamNameFilter$3D{domain}$252F{user}$252F{app_type}$252F{app_name}"
f"$3FlogStreamNameFilter$3D{domain}$252F{user_or_space}$252F{app_type}$252F{app_name}"
return f"https://{self.aws_console.get_console_domain()}/" \
f"cloudwatch/home?region={self.region_name}#" \
f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252Fstudio" \
f"$3FlogStreamNameFilter$3D{app_type}$252F{app_name}"

def get_ide_metadata_url(self, domain, user):
def get_ide_metadata_url(self, domain, user_or_space, is_user_profile=True):
scope = 'user' if is_user_profile else 'space'
return f"https://{self.aws_console.get_console_domain()}/" \
f"sagemaker/home?region={self.region_name}#" \
f"/studio/{domain}/user/{user}"
f"/studio/{domain}/{scope}/{user_or_space}"

def count_sns_notifications(self, topic_name: str, period: timedelta):
cloudwatch_resource = boto3.resource('cloudwatch', region_name=self.region_name)
12 changes: 6 additions & 6 deletions sagemaker_ssh_helper/manager.py
Original file line number Diff line number Diff line change
@@ -113,14 +113,14 @@ def get_transformer_instance_ids(self, transform_job_name, timeout_in_sec=0):
self.logger.info(f"Querying SSM instance IDs for transform job {transform_job_name}")
return self.get_instance_ids('transform-job', transform_job_name, timeout_in_sec)

def get_studio_user_kgw_instance_ids(self, domain_id, user_profile_name, kgw_name, timeout_in_sec=0,
not_earlier_than_timestamp: int = 0):
self.logger.info(f"Querying SSM instance IDs for SageMaker Studio kernel gateway: '{kgw_name}'")
def get_studio_instance_ids(self, domain_id, user_profile_or_space_name, app_name, timeout_in_sec=0, not_earlier_than_timestamp: int = 0, is_user_profile=False):
self.logger.info(f"Querying SSM instance IDs for app '{app_name}' in SageMaker Studio {'kernel gateway' if is_user_profile else 'space'}: '{user_profile_or_space_name}'")
if not domain_id:
arn_filter = f":app/.*/{user_profile_name}/"
arn_filter = f":app/.*/{user_profile_or_space_name}/"
else:
arn_filter = f":app/{domain_id}/{user_profile_name}/"
return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec,
arn_filter = f":app/{domain_id}/{user_profile_or_space_name}/"

return self.get_instance_ids('app', f"{app_name}", timeout_in_sec,
arn_filter_regex=arn_filter,
not_earlier_than_timestamp=not_earlier_than_timestamp)

6 changes: 3 additions & 3 deletions sagemaker_ssh_helper/sm-connect-ssh-proxy
Original file line number Diff line number Diff line change
@@ -94,7 +94,7 @@ send_command=$(aws ssm send-command \
'cat /etc/ssh/authorized_keys.d/* > /etc/ssh/authorized_keys',
'ls -la /etc/ssh/authorized_keys'
]" \
--no-cli-pager --no-paginate \
--no-paginate \
--output json)

json_value_regexp='s/^[^"]*".*": \"\(.*\)\"[^"]*/\1/'
@@ -114,7 +114,7 @@ for i in $(seq 1 15); do
command_output=$(aws ssm get-command-invocation \
--instance-id "${INSTANCE_ID}" \
--command-id "${command_id}" \
--no-cli-pager --no-paginate \
--no-paginate \
--output json)
command_output=$(echo "$command_output" | $(_python) -m json.tool)
command_status=$(echo "$command_output" | grep '"Status":' | sed -e "$json_value_regexp")
@@ -166,7 +166,7 @@ proxy_command="aws ssm start-session\
--parameters portNumber=%p"

# shellcheck disable=SC2086
ssh -4 -o User=root -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \
ssh -4 -o User=sagemaker-user -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \
-o ProxyCommand="$proxy_command" \
-o ConnectTimeout=90 \
-o ServerAliveInterval=15 -o ServerAliveCountMax=3 \
Loading