diff --git a/docs/source/cli/model-training-inference/distributed/sagemaker.rst b/docs/source/cli/model-training-inference/distributed/sagemaker.rst
index 603b529d7d..518e36a789 100644
--- a/docs/source/cli/model-training-inference/distributed/sagemaker.rst
+++ b/docs/source/cli/model-training-inference/distributed/sagemaker.rst
@@ -213,6 +213,105 @@
 from ``${DATASET_S3_PATH}`` as input and create a DistDGL graph with
 ``${NUM_PARTITIONS}`` under the output path, ``${OUTPUT_PATH}``.
 Currently we only support ``random`` as the partitioning algorithm.
+
+Launch hyper-parameter optimization task
+````````````````````````````````````````
+
+GraphStorm supports `automatic model tuning `_ with SageMaker AI,
+which allows you to optimize the hyper-parameters of your model
+with an easy-to-use interface.
+
+The ``sagemaker/launch/launch_hyperparameter_tuning.py`` script acts as a thin
+wrapper around SageMaker's `HyperParameterTuner `_.
+
+You define the hyper-parameters of interest by passing either the path to a JSON file
+or a dictionary encoded as a JSON string,
+where the structure of the dictionary is the same as for SageMaker's
+`Dynamic hyper-parameters <https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html>`_.
+For example, your JSON file could look like this:
+
+.. code:: python
+
+    # Content of my_param_ranges.json
+    {
+        "ParameterRanges": {
+            "CategoricalParameterRanges": [
+                {
+                    "Name": "model_encoder_type",
+                    "Values": ["rgcn", "hgt"]
+                }
+            ],
+            "ContinuousParameterRanges": [
+                {
+                    "Name": "lr",
+                    "MinValue": "1e-5",
+                    "MaxValue": "1e-2",
+                    "ScalingType": "Auto"
+                }
+            ],
+            "IntegerParameterRanges": [
+                {
+                    "Name": "hidden_size",
+                    "MinValue": "64",
+                    "MaxValue": "256",
+                    "ScalingType": "Auto"
+                }
+            ]
+        }
+    }
+
+You can then use this file to launch an HPO job:
+
+.. code:: bash
+
+    # Pass the hyper-parameter ranges to the launch script
+    python launch/launch_hyperparameter_tuning.py \
+        --hyperparameter-ranges my_param_ranges.json
+        # Other launch parameters...
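+
+Because the ranges can also be passed as a JSON string, you can skip the file and
+provide them inline. For example, to tune only the learning rate
+(a sketch based on the ranges above):
+
+.. code:: bash
+
+    python launch/launch_hyperparameter_tuning.py \
+        --hyperparameter-ranges '{"ParameterRanges": {"ContinuousParameterRanges": [{"Name": "lr", "MinValue": "1e-5", "MaxValue": "1e-2"}]}}'
+        # Other launch parameters...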
+
+For continuous and integer parameters you can provide a ``ScalingType``
+string that directly corresponds to one of SageMaker's
+`scaling types `_.
+By default, the scaling type is ``'Auto'``.
+
+Use ``--metric-name`` to provide the name of a GraphStorm metric to use as the tuning objective,
+e.g. ``"accuracy"``. See the entry for ``eval_metric`` in :ref:`Evaluation Metrics `
+for a full list of supported metrics.
+
+``--eval-mask`` defines which dataset to collect metrics from, and
+can be either ``"test"`` or ``"val"`` to collect metrics from the test or validation set,
+respectively. Use ``--objective-type`` to set the type of the objective,
+which can be either ``"Maximize"`` or ``"Minimize"``.
+See the `SageMaker documentation `_
+for more details.
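+
+Internally, ``launch_hyperparameter_tuning.py`` combines ``--eval-mask`` and
+``--metric-name`` into the objective metric name it registers with SageMaker:
+
+.. code:: python
+
+    # e.g. --eval-mask "val" --metric-name "accuracy" -> "best_val_score:accuracy"
+    full_metric_name = f"best_{eval_mask.lower()}_score:{metric_name.lower()}"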
+
+Finally, you can use ``--strategy`` to select the optimization strategy,
+one of ``"Bayesian"``, ``"Random"``, ``"Hyperband"``, or ``"Grid"``. See the
+`SageMaker documentation `_
+for more details on each strategy.
+
+Example HPO call:
+
+.. code:: bash
+
+    python launch/launch_hyperparameter_tuning.py \
+        --task-name my-gnn-hpo-job \
+        --role arn:aws:iam::123456789012:role/SageMakerRole \
+        --region us-west-2 \
+        --image-url 123456789012.dkr.ecr.us-west-2.amazonaws.com/graphstorm:sagemaker-gpu \
+        --graph-name my-graph \
+        --task-type node_classification \
+        --graph-data-s3 s3://my-bucket/graph-data/ \
+        --yaml-s3 s3://my-bucket/train.yaml \
+        --model-artifact-s3 s3://my-bucket/model-artifacts/ \
+        --max-jobs 20 \
+        --max-parallel-jobs 4 \
+        --hyperparameter-ranges my_param_ranges.json \
+        --metric-name "accuracy" \
+        --eval-mask "val" \
+        --objective-type "Maximize" \
+        --strategy "Bayesian"
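+
+Once the tuning job finishes, you can inspect its results with the SageMaker
+Python SDK. The snippet below is a sketch; the tuning job name is a placeholder
+that you can copy from the SageMaker console:
+
+.. code:: python
+
+    from sagemaker.tuner import HyperparameterTuner
+
+    tuner = HyperparameterTuner.attach("my-gnn-hpo-tuning-job-name")
+    # Name of the training job that achieved the best objective value
+    print(tuner.best_training_job())
+    # Per-job hyper-parameters and objective values as a pandas DataFrame
+    print(tuner.analytics().dataframe())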
+
 Passing additional arguments to the SageMaker Estimator
 ```````````````````````````````````````````````````````
 Sometimes you might want to pass additional arguments to the constructor
diff --git a/python/graphstorm/sagemaker/sagemaker_train.py b/python/graphstorm/sagemaker/sagemaker_train.py
index fe8aafd146..21efa1988c 100644
--- a/python/graphstorm/sagemaker/sagemaker_train.py
+++ b/python/graphstorm/sagemaker/sagemaker_train.py
@@ -15,16 +15,16 @@
     Training entry point.
 """
-# Install additional requirements
-import os
+
+import json
 import logging
+import os
+import queue
 import socket
-import time
-import json
 import subprocess
-from threading import Thread, Event
 import sys
-import queue
+import time
+from threading import Thread, Event
 
 import boto3
 import sagemaker
@@ -36,6 +36,7 @@
                            BUILTIN_TASK_MULTI_TASK)
 from .utils import (download_yaml_config,
                     download_graph,
+                    get_job_port,
                     keep_alive,
                     barrier_master,
                     barrier,
@@ -164,7 +165,16 @@ def run_train(args, unknownargs):
         os.makedirs(restore_model_path, exist_ok=True)
     else:
         restore_model_path = None
-    output_path = "/tmp/gsgnn_model/"
+
+    if args.model_artifact_s3:
+        # If a user provides an S3 output destination as an input arg, the script itself
+        # will upload the model artifacts after training, so we save under /tmp.
+        output_path = "/tmp/gsgnn_model/"
+    else:
+        # If the user does not provide an output destination as an arg, we rely on SageMaker to
+        # do the model upload so we save the model to the pre-determined path /opt/ml/model
+        output_path = "/opt/ml/model"
+
     os.makedirs(output_path, exist_ok=True)
 
     # start the ssh server
@@ -195,12 +205,14 @@
         raise RuntimeError(f"Can not get host name of {hosts}")
 
     master_addr = args.master_addr
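+    # Derive the rendezvous port from the job name so the leader and workers
+    # agree on it without extra coordination (see get_job_port in graphstorm.sagemaker.utils)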
+    master_port = get_job_port(train_env['job_name'])
     # sync with all instances in the cluster
     if host_rank == 0:
         # sync with workers
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.bind((master_addr, 12345))
+        sock.bind((master_addr, master_port))
         sock.listen(world_size)
+        logging.info("Master listening on %s:%s", master_addr, master_port)
 
         client_list = [None] * world_size
         for i in range(1, world_size):
@@ -211,12 +223,12 @@
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         for _ in range(30):
             try:
-                sock.connect((master_addr, 12345))
+                sock.connect((master_addr, master_port))
                 break
             except: # pylint: disable=bare-except
-                logging.info("Try to connect %s", master_addr)
+                logging.info("Trying to connect to %s:%s...", master_addr, master_port)
                 time.sleep(10)
-        logging.info("Connected")
+        logging.info("Connected to %s:%s", master_addr, master_port)
 
     # write ip list info into disk
     ip_list_path = os.path.join(data_path, 'ip_list.txt')
@@ -229,7 +241,11 @@
     graph_data_s3 = args.graph_data_s3
     task_type = args.task_type
     train_yaml_s3 = args.train_yaml_s3
-    model_artifact_s3 = args.model_artifact_s3.rstrip('/')
+    # If the user provided an output destination, trim any trailing '/'
+    if args.model_artifact_s3:
+        gs_model_artifact_s3 = args.model_artifact_s3.rstrip('/')
+    else:
+        gs_model_artifact_s3 = None
     custom_script = args.custom_script
 
     boto_session = boto3.session.Session(region_name=args.region)
@@ -292,6 +308,7 @@
         logging.error("Task failed")
         sys.exit(-1)
 
-    # If there are saved models
-    if os.path.exists(save_model_path):
-        upload_model_artifacts(model_artifact_s3, save_model_path, sagemaker_session)
+    # We upload models only when the user explicitly set the model_artifact_s3
+    # argument. Otherwise we can rely on the SageMaker service to do the upload.
+    if gs_model_artifact_s3 and os.path.exists(save_model_path):
+        upload_model_artifacts(gs_model_artifact_s3, save_model_path, sagemaker_session)
diff --git a/python/graphstorm/sagemaker/utils.py b/python/graphstorm/sagemaker/utils.py
index 42e46f334c..71d51fbe5e 100644
--- a/python/graphstorm/sagemaker/utils.py
+++ b/python/graphstorm/sagemaker/utils.py
@@ -15,11 +15,15 @@
     sagemaker script utilities
 """
-import subprocess
+
+import hashlib
 import logging
 import os
-import time
 import shutil
+import socket
+import subprocess
+import time
+from typing import Optional
 from urllib.parse import urlparse
 
 import boto3
@@ -27,6 +31,8 @@
 from sagemaker.s3 import S3Downloader
 from sagemaker.s3 import S3Uploader
 
+PORT_MIN = 10000  # Avoid privileged ports
+PORT_MAX = 65535  # Maximum TCP port number
 
 def run(launch_cmd, state_q, env=None):
     """ Running cmd using shell
@@ -455,3 +461,66 @@ def remove_embs(emb_path):
         Local embedding path
     """
     remove_data(emb_path)
+
+def is_port_available(port):
+    """Check if a port is available."""
+    try:
+        # Try to bind to all interfaces with a timeout
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            # Add one second timeout
+            s.settimeout(1)
+            s.bind(('', port))
+            # Also try listening to ensure port is fully available
+            s.listen(1)
+            return True
+    except (OSError, socket.timeout):
+        return False
+
+def find_free_port(start_port: Optional[int]=None):
+    """Find next available port, starting from start_port."""
+    if start_port is None:
+        # Let OS choose
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(('', 0))
+                return s.getsockname()[1]
+        except OSError:
+            # Fall back to manual search if OS assignment fails
+            start_port = PORT_MIN
+
+    # Try ports sequentially starting from start_port
+    port = start_port
+    while port <= PORT_MAX:
+        if is_port_available(port):
+            return port
+        port += 1
+
+    raise RuntimeError("No available ports found")
+
+def get_job_port(job_str_identifier: Optional[str] = None):
+    """Get port number based on per-job unique ID
+
+    Parameters
+    ----------
+    job_str_identifier : str, optional
+        An identifier that should be unique to each SM job, by default None
+
+    Returns
+    -------
+    int
+        A common port number that master and workers will use
+    """
+    if not job_str_identifier:
+        job_str_identifier = os.getenv('SM_USER_ARGS', '')
+
+    # Create a hash of the unique identifier
+    hash_object = hashlib.md5(job_str_identifier.encode())
+    hash_hex = hash_object.hexdigest()
+
+    # Convert first 4 chars of hash to int and scale into [PORT_MIN, PORT_MAX)
+    # to avoid privileged and well-known low ports
+    base_port = PORT_MIN + (int(hash_hex[:4], 16) % (PORT_MAX - PORT_MIN))
+
+    # Ensure we return an open port, starting at base_port
+    port = find_free_port(base_port)
+    return port
diff --git a/sagemaker/launch/launch_hyperparameter_tuning.py b/sagemaker/launch/launch_hyperparameter_tuning.py
new file mode 100644
index 0000000000..4524e8b7e1
--- /dev/null
+++ b/sagemaker/launch/launch_hyperparameter_tuning.py
@@ -0,0 +1,260 @@
+r"""
+    Copyright 2023 Contributors
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+
+    Launch SageMaker HPO jobs.
+
+    Example:
+
+        python launch_hyperparameter_tuning.py \
+            --task-name my-gnn-hpo-job \
+            --role arn:aws:iam::123456789012:role/SageMakerRole \
+            --region us-west-2 \
+            --image-url 123456789012.dkr.ecr.us-west-2.amazonaws.com/my-graphstorm-image:latest \
+            --graph-name my-graph \
+            --task-type node_classification \
+            --graph-data-s3 s3://my-bucket/graph-data/ \
+            --yaml-s3 s3://my-bucket/train.yaml \
+            --model-artifact-s3 s3://my-bucket/model-artifacts/ \
+            --max-jobs 20 \
+            --max-parallel-jobs 4 \
+            --hyperparameter-ranges ./my-parameter-ranges.json \
+            --metric-name "accuracy" \
+            --eval-mask "val" \
+            --objective-type Maximize
+"""
+import os
+import json
+import logging
+
+from sagemaker.pytorch.estimator import PyTorch
+from sagemaker.tuner import (
+    ParameterRange,
+    HyperparameterTuner,
+    ContinuousParameter,
+    IntegerParameter,
+    CategoricalParameter,
+)
+
+from common_parser import (
+    parse_estimator_kwargs,
+    parse_unknown_gs_args,
+)
+from launch_train import get_train_parser
+
+INSTANCE_TYPE = "ml.g4dn.12xlarge"
+
+
+def parse_hyperparameter_ranges(hyperparameter_ranges_json: str) -> dict[str, ParameterRange]:
+    """Parse the hyperparameter ranges from a JSON file or JSON string.
+
+    Expected JSON structure is at
+    https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html
+
+    Parameters
+    ----------
+    hyperparameter_ranges_json : str
+        Path to a JSON file or JSON as a string.
+
+    Returns
+    -------
+    dict
+        Hyperparameter dict that can be used to create a HyperparameterTuner
+        object.
+
+    Raises
+    ------
+    ValueError
+        If the JSON dict contains an invalid parameter type.
+
+    .. versionadded:: 0.4.1
+    """
+    if os.path.exists(hyperparameter_ranges_json):
+        with open(hyperparameter_ranges_json, "r", encoding="utf-8") as f:
+            str_ranges: dict[str, list[dict]] = json.load(f)["ParameterRanges"]
+    else:
+        str_ranges = json.loads(hyperparameter_ranges_json)["ParameterRanges"]
+
+    hyperparameter_ranges: dict[str, ParameterRange] = {}
+    for params_type, config_list in str_ranges.items():
+        if params_type == "ContinuousParameterRanges":
+            for config in config_list:
+                param_name = config["Name"]
+                hyperparameter_ranges[param_name] = ContinuousParameter(
+                    config["MinValue"],
+                    config["MaxValue"],
+                    config.get("ScalingType", "Auto"),
+                )
+        elif params_type == "IntegerParameterRanges":
+            for config in config_list:
+                param_name = config["Name"]
+                hyperparameter_ranges[param_name] = IntegerParameter(
+                    config["MinValue"],
+                    config["MaxValue"],
+                    config.get("ScalingType", "Auto"),
+                )
+        elif params_type == "CategoricalParameterRanges":
+            for config in config_list:
+                param_name = config["Name"]
+                hyperparameter_ranges[param_name] = CategoricalParameter(
+                    config["Values"]
+                )
+        else:
+            raise ValueError(
+                f"Unknown parameter type {params_type}. "
+                "Expect one of 'CategoricalParameterRanges', 'ContinuousParameterRanges', "
+                "'IntegerParameterRanges'"
+            )
+    return hyperparameter_ranges
+
+
+def run_hyperparameter_tuning_job(args, image, unknownargs):
+    """Run hyperparameter tuning job using SageMaker HyperparameterTuner"""
+
+    container_image_uri = image
+
+    prefix = f"gs-hpo-{args.graph_name}"
+
+    params = {
+        "eval-metric": args.metric_name,
+        "graph-data-s3": args.graph_data_s3,
+        "graph-name": args.graph_name,
+        "log-level": args.log_level,
+        "task-type": args.task_type,
+        "train-yaml-s3": args.yaml_s3,
+        "topk-model-to-save": "1",
+    }
+    if args.custom_script is not None:
+        params["custom-script"] = args.custom_script
+    if args.model_checkpoint_to_load is not None:
+        params["model-checkpoint-to-load"] = args.model_checkpoint_to_load
+
+    unknown_args_dict = parse_unknown_gs_args(unknownargs)
+    params.update(unknown_args_dict)
+
+    logging.info("SageMaker launch parameters %s", params)
+    logging.info("Parameters forwarded to GraphStorm %s", unknown_args_dict)
+
+    estimator_kwargs = parse_estimator_kwargs(args.sm_estimator_parameters)
+
+    est = PyTorch(
+        entry_point=os.path.basename(args.entry_point),
+        source_dir=os.path.dirname(args.entry_point),
+        image_uri=container_image_uri,
+        role=args.role,
+        instance_count=args.instance_count,
+        instance_type=args.instance_type,
+        output_path=args.model_artifact_s3,
+        py_version="py3",
+        base_job_name=prefix,
+        hyperparameters=params,
+        tags=[
+            {"Key": "GraphStorm", "Value": "oss"},
+            {"Key": "GraphStorm_Task", "Value": "HPO"},
+        ],
+        **estimator_kwargs,
+    )
+
+    hyperparameter_ranges = parse_hyperparameter_ranges(args.hyperparameter_ranges)
+
+    # Construct the full metric name based on user input
+    full_metric_name = f"best_{args.eval_mask.lower()}_score:{args.metric_name.lower()}"
+
+    tuner = HyperparameterTuner(
+        estimator=est,
+        objective_metric_name=full_metric_name,
+        hyperparameter_ranges=hyperparameter_ranges,
+        objective_type=args.objective_type,
+        max_jobs=args.max_jobs,
+        max_parallel_jobs=args.max_parallel_jobs,
+        metric_definitions=[
+            {
+                "Name": full_metric_name,
+                "Regex": (
+                    f"INFO:root:best_{args.eval_mask.lower()}_score: "
+                    f"{{'{args.metric_name}': ([0-9\\.]+)}}"
+                ),
+            }
+        ],
+        strategy=args.strategy,
+    )
+
+    tuner.fit({"train": args.yaml_s3}, wait=not args.async_execution)
+
+
+def get_hpo_parser():
+    """Return a parser for GraphStorm hyperparameter tuning task."""
+    parser = get_train_parser()
+
+    hpo_group = parser.add_argument_group("Hyperparameter tuning arguments")
+
+    hpo_group.add_argument(
+        "--max-jobs",
+        type=int,
+        default=10,
+        help="Maximum number of training jobs to run",
+    )
+    hpo_group.add_argument(
+        "--max-parallel-jobs",
+        type=int,
+        default=2,
+        help="Maximum number of parallel training jobs",
+    )
+    hpo_group.add_argument(
+        "--hyperparameter-ranges",
+        type=str,
+        required=True,
+        help="Path to a JSON file, or a JSON string defining hyperparameter ranges. "
+        "For syntax see 'Dynamic hyperparameters' in "
+        "https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html",
+    )
+    hpo_group.add_argument(
+        "--metric-name",
+        type=str,
+        required=True,
+        help="Name of the metric to optimize for (e.g., 'accuracy', 'amri')",
+    )
+    hpo_group.add_argument(
+        "--eval-mask",
+        type=str,
+        required=True,
+        choices=["test", "val"],
+        help="Whether to use test or validation metrics for HPO.",
+    )
+    hpo_group.add_argument(
+        "--objective-type",
+        type=str,
+        default="Maximize",
+        choices=["Maximize", "Minimize"],
+        help="Type of objective, can be 'Maximize' or 'Minimize'",
+    )
+    hpo_group.add_argument(
+        "--strategy",
+        type=str,
+        default="Bayesian",
+        choices=["Bayesian", "Random", "Hyperband", "Grid"],
+        help="Optimization strategy. Default: 'Bayesian'.",
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    arg_parser = get_hpo_parser()
+    args, unknownargs = arg_parser.parse_known_args()
+    print(f"HPO launch known args: '{args}'")
+    print(f"HPO launch unknown args:{type(unknownargs)=} '{unknownargs=}'")
+
+    run_hyperparameter_tuning_job(args, args.image_url, unknownargs)
diff --git a/sagemaker/run/train_entry.py b/sagemaker/run/train_entry.py
index 0d08948370..91afb5a8f9 100644
--- a/sagemaker/run/train_entry.py
+++ b/sagemaker/run/train_entry.py
@@ -37,10 +37,11 @@ def get_train_parser():
                         required=True)
     parser.add_argument("--train-yaml-s3", type=str,
                         help="S3 location of training yaml file. "
-                             "Do not store it with partitioned graph",
-                        required=True)
-    parser.add_argument("--model-artifact-s3", type=str,
-                        help="S3 location to store the model artifacts.")
+                             "Do not store it with partitioned graph")
+    parser.add_argument("--model-artifact-s3", type=str, default=None,
+                        help="S3 location to store the model artifacts. If None, we rely on SageMaker "
+                             "to upload model artifacts, so the launching Estimator needs to have 'output_path' set. "
+                             "Default: None")
     parser.add_argument("--model-checkpoint-to-load", type=str, default=None,
                         help="S3 path to a model checkpoint from a previous training task "
                              "that is going to be resumed.")