
Commit bf3cadc

Add SMP v2 notebook for accelerating training with FP8 on P5. (#4578)
* Update example notebooks and related scripts for latest PT-2.2-TSM-2.2 release. Add FP8 training support on P5.
* Add example notebook for accelerating Llama-v2 training with FP8 on P5.
* Fix typo in version check.
* Update configurations. Revert jupyter notebook python version in metadata. Set activation_offloading=False for FP8 notebook. Explicitly enable use_smp_implementation in all SMP v2 notebooks.
* Update FP8 notebook docs.
* Set zipped_data=0 for use_fsx=False FP8 notebook.
* Update compute_tflops() script.
* Update minimum sagemaker pysdk version to `2.212`.
1 parent 123dbf4 commit bf3cadc

14 files changed: +1342, -82 lines
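
The flags named in the commit message are training-script hyperparameters that the notebooks pass through the SageMaker estimator. The sketch below is purely illustrative: the key names echo the commit message, but the `fp8` key in particular is an assumed name and the values are not the notebooks' exact settings.

```python
# Hypothetical hyperparameter dictionary; key names echo the commit message above,
# values and the `fp8` key itself are assumptions, not the notebooks' exact settings.
hyperparameters = {
    "fp8": 1,                        # assumed name for the new FP8-on-P5 switch
    "use_smp_implementation": 1,     # explicitly enabled in all SMP v2 notebooks by this commit
    "activation_offloading": False,  # set to False for the FP8 notebook
    "use_fsx": False,                # non-FSx data path of the FP8 notebook
    "zipped_data": 0,                # zipped_data=0 when use_fsx=False
}
```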

training/distributed_training/pytorch/model_parallel_v2/gpt-neox/smp-finetuning-gpt-neox-fsdp-tp.ipynb (+3, -2)
@@ -80,7 +80,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install --upgrade \"sagemaker>=2.2\"\n",
+"%pip install --upgrade \"sagemaker>=2.212\"\n",
 "%pip install sagemaker-experiments"
 ]
 },
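
This `sagemaker>=2.212` floor recurs in every notebook below, and the commit message also mentions fixing a typo in a version check. A minimal sketch of such a guard, assuming the `packaging` library is available; the notebooks' actual check may differ.

```python
# Minimal SDK version guard; not the notebooks' exact code.
import sagemaker
from packaging.version import Version

if Version(sagemaker.__version__) < Version("2.212"):
    raise RuntimeError(
        "sagemaker>=2.212 is required; run: %pip install --upgrade 'sagemaker>=2.212'"
    )
```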
@@ -671,6 +671,7 @@
 " \"auto_wrap_policy\": \"transformer_auto_wrap_policy\",\n",
 " \"model_type\": model_type,\n",
 " \"use_smp_flash_attn\": 1,\n",
+" \"use_smp_implementation\": 1,\n",
 " \"patch_neox_rope\": 0,\n",
 " \"distributed_backend\": \"nccl\",\n",
 "}\n",
@@ -882,7 +883,7 @@
 " },\n",
 " },\n",
 " py_version=\"py310\",\n",
-" framework_version=\"2.0.1\",\n",
+" framework_version=\"2.2.0\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",

training/distributed_training/pytorch/model_parallel_v2/gpt-neox/smp-train-gpt-neox-fsdp-tp.ipynb (+5, -5)
@@ -22,7 +22,6 @@
 "\n",
 "The notebook is accompanied by the following files:\n",
 "- `train.py`: The entry point script that'll be passed to the SageMaker PyTorch estimator later in this notebook when launching the training job.\n",
-"-\n",
 "- `arguments.py`: This file has functions for argument parsing (i.e. hyperparameters).\n",
 "- `checkpoints.py`: This file has functions for saving and loading checkpoints.\n",
 "- `data_utils`: This file has functions for handling S3 URLs.\n",
@@ -75,7 +74,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install --upgrade \"sagemaker>=2.2\"\n",
+"%pip install --upgrade \"sagemaker>=2.212\"\n",
 "%pip install sagemaker-experiments"
 ]
 },
@@ -624,7 +623,7 @@
 "\n",
 "tensor_parallel_degree = 2 # An integer in [1, world_size]\n",
 "hybrid_shard_degree = (\n",
-" 4 # # An integer in [0, world_size // tensor_parallel_degree] and its default value is 0.\n",
+" 4 # An integer in [0, world_size // tensor_parallel_degree] and its default value is 0.\n",
 ")\n",
 "offload_activations = True # Enables SM activation offloading implementation.\n",
 "activation_loading_horizon = (\n",
@@ -662,6 +661,7 @@
 " \"sharding_strategy\": \"hybrid_shard\",\n",
 " \"train_batch_size\": 2,\n",
 " \"use_smp_flash_attn\": 1,\n",
+" \"use_smp_implementation\": 1,\n",
 " \"val_batch_size\": 4,\n",
 " \"validation_freq\": save_steps,\n",
 " \"vocab_size\": 50257,\n",
@@ -874,7 +874,7 @@
 " },\n",
 " },\n",
 " py_version=\"py310\",\n",
-" framework_version=\"2.0.1\",\n",
+" framework_version=\"2.2.0\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",
@@ -956,7 +956,7 @@
 " },\n",
 " },\n",
 " py_version=\"py310\",\n",
-" framework_version=\"2.0.1\",\n",
+" framework_version=\"2.2.0\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",

training/distributed_training/pytorch/model_parallel_v2/llama_v2/smp-finetuning-llama-fsdp-tp.ipynb (+3, -2)
@@ -80,7 +80,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install --upgrade \"sagemaker>=2.2\"\n",
+"%pip install --upgrade \"sagemaker>=2.212\"\n",
 "%pip install sagemaker-experiments"
 ]
 },
@@ -663,6 +663,7 @@
 " \"auto_wrap_policy\": \"transformer_auto_wrap_policy\",\n",
 " \"model_type\": model_type,\n",
 " \"use_smp_flash_attn\": 1,\n",
+" \"use_smp_implementation\": 1,\n",
 " \"distributed_backend\": \"nccl\",\n",
 "}\n",
 "\n",
@@ -867,7 +868,7 @@
 " },\n",
 " },\n",
 " py_version=\"py310\",\n",
-" framework_version=\"2.0.1\",\n",
+" framework_version=\"2.2.0\",\n",
 " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
 " output_path=s3_output_bucket,\n",
 " max_run=86400,\n",
