[SageMaker] Add launcher and support for SageMaker HPO jobs (#1133)
*Issue #, if available:*

Fixes #1072

*Description of changes:*

* We change `train_entry` and `sagemaker_train` so that when the user does
not provide a value for the `--model-artifact-s3` argument, we rely on
the SageMaker service to upload model artifacts. This makes it possible
to run HPO and other jobs that rely on using the "official" SageMaker
model paths.
* Add example launch script for HPO jobs and documentation.
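The fallback described above amounts to a small branch on the launch argument; a minimal sketch, mirroring (in simplified form) the logic this PR adds to `sagemaker_train.py`:

```python
def choose_output_path(model_artifact_s3):
    """Pick where training saves the model, per the fallback described above."""
    if model_artifact_s3:
        # The caller gave an S3 destination: the script uploads artifacts
        # itself after training, so any local scratch path works.
        return "/tmp/gsgnn_model/"
    # No destination given: save to SageMaker's reserved model directory;
    # the service uploads /opt/ml/model to S3 when the job completes.
    return "/opt/ml/model"
```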

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: xiang song(charlie.song) <[email protected]>
thvasilo and classicsong authored Feb 18, 2025
1 parent 4289130 commit fdb2575
Showing 5 changed files with 467 additions and 21 deletions.
99 changes: 99 additions & 0 deletions docs/source/cli/model-training-inference/distributed/sagemaker.rst
@@ -213,6 +213,105 @@ from ``${DATASET_S3_PATH}`` as input and create a DistDGL graph with
``${NUM_PARTITIONS}`` under the output path, ``${OUTPUT_PATH}``.
Currently we only support ``random`` as the partitioning algorithm.

Launch hyper-parameter optimization task
````````````````````````````````````````

GraphStorm supports `automatic model tuning <https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html>`_
with SageMaker AI,
which allows you to optimize the hyper-parameters
of your model with an easy-to-use interface.

The ``sagemaker/launch/launch_hyperparameter_tuning.py`` script can act as a thin
wrapper for SageMaker's `HyperParameterTuner <https://sagemaker.readthedocs.io/en/stable/api/training/tuner.html>`_.

You define the hyper-parameters of interest by passing a path to a JSON file,
or a Python dictionary as a string,
where the structure of the dictionary is the same as for SageMaker's
`Dynamic hyper-parameters <https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html#automatic-model-tuning-define-ranges-dynamic>`_.
For example, your JSON file can look like this:

.. code:: python

    # Content of my_param_ranges.json
    {
        "ParameterRanges": {
            "CategoricalParameterRanges": [
                {
                    "Name": "model_encoder_type",
                    "Values": ["rgcn", "hgt"]
                }
            ],
            "ContinuousParameterRanges": [
                {
                    "Name": "lr",
                    "MinValue": "1e-5",
                    "MaxValue": "1e-2",
                    "ScalingType": "Auto"
                }
            ],
            "IntegerParameterRanges": [
                {
                    "Name": "hidden_size",
                    "MinValue": "64",
                    "MaxValue": "256",
                    "ScalingType": "Auto"
                }
            ]
        }
    }

You can then use this file to launch an HPO job:

.. code:: bash

    # Pass the hyper-parameter ranges as a JSON file
    python launch/launch_hyperparameter_tuning.py \
        --hyperparameter-ranges my_param_ranges.json
        # Other launch parameters...

For continuous and integer parameters you can provide a ``ScalingType``
string that corresponds directly to one of SageMaker's
`scaling types <https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html#scaling-type>`_.
By default, the scaling type is ``Auto``.
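Before launching, it can help to sanity-check a ranges file. The snippet below is a hypothetical standalone check (not part of the launcher) that walks the same ``ParameterRanges`` structure shown above:

```python
import json

# SageMaker's accepted scaling types for numeric parameter ranges
VALID_SCALING = {"Auto", "Linear", "Logarithmic", "ReverseLogarithmic"}

def check_param_ranges(ranges_json):
    """Return the tunable parameter names, raising on a bad ScalingType."""
    ranges = json.loads(ranges_json)["ParameterRanges"]
    names = [p["Name"] for p in ranges.get("CategoricalParameterRanges", [])]
    for key in ("ContinuousParameterRanges", "IntegerParameterRanges"):
        for param in ranges.get(key, []):
            scaling = param.get("ScalingType", "Auto")
            if scaling not in VALID_SCALING:
                raise ValueError(f"Bad ScalingType for {param['Name']}: {scaling}")
            names.append(param["Name"])
    return names
```

With the example file above, ``check_param_ranges`` returns ``["model_encoder_type", "lr", "hidden_size"]``.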

Use ``--metric-name`` to provide the name of a GraphStorm metric to use as a tuning objective,
e.g. ``"accuracy"``. See the entry for ``eval_metric`` in :ref:`Evaluation Metrics <eval_metrics>`
for a full list of supported metrics.

``--eval-mask`` defines which dataset to collect metrics from, and
can be either ``"test"`` or ``"val"`` to collect metrics from the test or
validation set, respectively. Finally, use ``--objective-type`` to set the type
of the objective, which can be either ``"Maximize"`` or ``"Minimize"``.
See the `SageMaker documentation <https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html>`_
for more details.

You can also use ``--strategy`` to select the optimization strategy
from one of ``"Bayesian"``, ``"Random"``, ``"Hyperband"``, or ``"Grid"``. See the
`SageMaker documentation <https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html>`_
for more details on each strategy.

Example HPO call:

.. code:: bash

    python launch/launch_hyperparameter_tuning.py \
        --task-name my-gnn-hpo-job \
        --role arn:aws:iam::123456789012:role/SageMakerRole \
        --region us-west-2 \
        --image-url 123456789012.dkr.ecr.us-west-2.amazonaws.com/graphstorm:sagemaker-gpu \
        --graph-name my-graph \
        --task-type node_classification \
        --graph-data-s3 s3://my-bucket/graph-data/ \
        --yaml-s3 s3://my-bucket/train.yaml \
        --model-artifact-s3 s3://my-bucket/model-artifacts/ \
        --max-jobs 20 \
        --max-parallel-jobs 4 \
        --hyperparameter-ranges my_param_ranges.json \
        --metric-name "accuracy" \
        --eval-mask "val" \
        --objective-type "Maximize" \
        --strategy "Bayesian"

Passing additional arguments to the SageMaker Estimator
```````````````````````````````````````````````````````
Sometimes you might want to pass additional arguments to the constructor
47 changes: 32 additions & 15 deletions python/graphstorm/sagemaker/sagemaker_train.py
@@ -15,16 +15,16 @@
Training entry point.
"""
# Install additional requirements
import json
import logging
import os
import queue
import socket
import subprocess
import sys
import time
from threading import Thread, Event

import boto3
import sagemaker
@@ -36,6 +36,7 @@
BUILTIN_TASK_MULTI_TASK)
from .utils import (download_yaml_config,
download_graph,
get_job_port,
keep_alive,
barrier_master,
barrier,
@@ -164,7 +165,16 @@ def run_train(args, unknownargs):
os.makedirs(restore_model_path, exist_ok=True)
else:
restore_model_path = None
if args.model_artifact_s3:
    # If a user provides an S3 output destination as an input arg, the script itself
    # will upload the model artifacts after training, so we save under /tmp.
    output_path = "/tmp/gsgnn_model/"
else:
    # If the user does not provide an output destination as an arg, we rely on SageMaker
    # to do the model upload, so we save the model to the pre-determined path /opt/ml/model.
    output_path = "/opt/ml/model"

os.makedirs(output_path, exist_ok=True)

# start the ssh server
@@ -195,12 +205,14 @@
raise RuntimeError(f"Can not get host name of {hosts}")

master_addr = args.master_addr
master_port = get_job_port(train_env['job_name'])
# sync with all instances in the cluster
if host_rank == 0:
# sync with workers
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((master_addr, master_port))
sock.listen(world_size)
logging.info("Master listening on %s:%s", master_addr, master_port)

client_list = [None] * world_size
for i in range(1, world_size):
@@ -211,12 +223,12 @@
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
for _ in range(30):
try:
sock.connect((master_addr, master_port))
break
except: # pylint: disable=bare-except
logging.info("Trying to connect to %s:%s...", master_addr, master_port)
time.sleep(10)
logging.info("Connected to %s:%s", master_addr, master_port)

# write ip list info into disk
ip_list_path = os.path.join(data_path, 'ip_list.txt')
@@ -229,7 +241,11 @@
graph_data_s3 = args.graph_data_s3
task_type = args.task_type
train_yaml_s3 = args.train_yaml_s3
# If the user provided an output destination, trim any trailing '/'
if args.model_artifact_s3:
    gs_model_artifact_s3 = args.model_artifact_s3.rstrip('/')
else:
    gs_model_artifact_s3 = None
custom_script = args.custom_script

boto_session = boto3.session.Session(region_name=args.region)
@@ -292,6 +308,7 @@
logging.error("Task failed")
sys.exit(-1)

# We upload models only when the user explicitly set the model_artifact_s3
# argument. Otherwise we can rely on the SageMaker service to do the upload.
if gs_model_artifact_s3 and os.path.exists(save_model_path):
    upload_model_artifacts(gs_model_artifact_s3, save_model_path, sagemaker_session)
73 changes: 71 additions & 2 deletions python/graphstorm/sagemaker/utils.py
@@ -15,18 +15,24 @@
sagemaker script utilities
"""
import hashlib
import logging
import os
import shutil
import socket
import subprocess
import time
from typing import Optional
from urllib.parse import urlparse

import boto3
from botocore.errorfactory import ClientError
from sagemaker.s3 import S3Downloader
from sagemaker.s3 import S3Uploader

PORT_MIN = 10000 # Avoid privileged ports
PORT_MAX = 65535 # Maximum TCP port number

def run(launch_cmd, state_q, env=None):
""" Running cmd using shell
@@ -455,3 +461,66 @@ def remove_embs(emb_path):
Local embedding path
"""
remove_data(emb_path)

def is_port_available(port):
    """Check if a port is available."""
    try:
        # Try to bind to all interfaces with a timeout
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            # Add one second timeout
            s.settimeout(1)
            s.bind(('', port))
            # Also try listening to ensure port is fully available
            s.listen(1)
            return True
    except (OSError, socket.timeout):
        return False

def find_free_port(start_port: Optional[int]=None):
    """Find next available port, starting from start_port."""
    if start_port is None:
        # Let the OS choose
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(('', 0))
                return s.getsockname()[1]
        except OSError:
            # Fall back to manual search if OS assignment fails
            start_port = PORT_MIN

    # Try ports sequentially starting from start_port
    port = start_port
    while port <= PORT_MAX:
        if is_port_available(port):
            return port
        port += 1

    raise RuntimeError("No available ports found")

def get_job_port(job_str_identifier: Optional[str] = None):
    """Get a port number based on a per-job unique ID.

    Parameters
    ----------
    job_str_identifier : str, optional
        An identifier that should be unique to each SageMaker job, by default None

    Returns
    -------
    int
        A common port number that master and workers will use
    """
    if not job_str_identifier:
        job_str_identifier = os.getenv('SM_USER_ARGS', '')

    # Create a hash of the unique identifier
    hash_object = hashlib.md5(job_str_identifier.encode())
    hash_hex = hash_object.hexdigest()

    # Convert the first 4 chars of the hash to an int and scale it to the
    # valid port range (10000-65535), avoiding privileged and common ports
    base_port = PORT_MIN + (int(hash_hex[:4], 16) % (PORT_MAX - PORT_MIN))

    # Ensure we return an open port, starting the search at base_port
    return find_free_port(base_port)
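The hashing above lets every instance of the same job derive the same base port without any coordination. A quick standalone check of that property, using the same constants and scheme:

```python
import hashlib

PORT_MIN = 10000  # Avoid privileged ports
PORT_MAX = 65535  # Maximum TCP port number

def base_port_for(job_id):
    # Same scheme as get_job_port: first 4 hex chars of the MD5 digest,
    # scaled into [PORT_MIN, PORT_MAX)
    digest = hashlib.md5(job_id.encode()).hexdigest()
    return PORT_MIN + (int(digest[:4], 16) % (PORT_MAX - PORT_MIN))
```

Because the port is a pure function of the job string, master and workers agree on it before any of them is reachable; ``find_free_port`` then only has to probe forward if that port happens to be taken.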
