|
22 | 22 | "\n",
|
23 | 23 | "The notebook is accompanied by the following files:\n",
|
24 | 24 | "- `train.py`: The entry point script that'll be passed to the SageMaker PyTorch estimator later in this notebook when launching the training job.\n",
|
25 |
| - "-\n", |
26 | 25 | "- `arguments.py`: This file has functions for argument parsing (i.e. hyperparameters).\n",
|
27 | 26 | "- `checkpoints.py`: This file has functions for saving and loading checkpoints.\n",
|
28 | 27 | "- `data_utils`: This file has functions for handling S3 URLs.\n",
|
|
75 | 74 | "metadata": {},
|
76 | 75 | "outputs": [],
|
77 | 76 | "source": [
|
78 |
| - "%pip install --upgrade \"sagemaker>=2.2\"\n", |
| 77 | + "%pip install --upgrade \"sagemaker>=2.212\"\n", |
79 | 78 | "%pip install sagemaker-experiments"
|
80 | 79 | ]
|
81 | 80 | },
|
|
624 | 623 | "\n",
|
625 | 624 | "tensor_parallel_degree = 2 # An integer in [1, world_size]\n",
|
626 | 625 | "hybrid_shard_degree = (\n",
|
627 |
| - " 4 # # An integer in [0, world_size // tensor_parallel_degree] and its default value is 0.\n", |
| 626 | + " 4 # An integer in [0, world_size // tensor_parallel_degree] and its default value is 0.\n", |
628 | 627 | ")\n",
|
629 | 628 | "offload_activations = True # Enables SM activation offloading implementation.\n",
|
630 | 629 | "activation_loading_horizon = (\n",
|
|
662 | 661 | " \"sharding_strategy\": \"hybrid_shard\",\n",
|
663 | 662 | " \"train_batch_size\": 2,\n",
|
664 | 663 | " \"use_smp_flash_attn\": 1,\n",
|
| 664 | + " \"use_smp_implementation\": 1,\n", |
665 | 665 | " \"val_batch_size\": 4,\n",
|
666 | 666 | " \"validation_freq\": save_steps,\n",
|
667 | 667 | " \"vocab_size\": 50257,\n",
|
|
874 | 874 | " },\n",
|
875 | 875 | " },\n",
|
876 | 876 | " py_version=\"py310\",\n",
|
877 |
| - " framework_version=\"2.0.1\",\n", |
| 877 | + " framework_version=\"2.2.0\",\n", |
878 | 878 | " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
|
879 | 879 | " output_path=s3_output_bucket,\n",
|
880 | 880 | " max_run=86400,\n",
|
|
956 | 956 | " },\n",
|
957 | 957 | " },\n",
|
958 | 958 | " py_version=\"py310\",\n",
|
959 |
| - " framework_version=\"2.0.1\",\n", |
| 959 | + " framework_version=\"2.2.0\",\n", |
960 | 960 | " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
|
961 | 961 | " output_path=s3_output_bucket,\n",
|
962 | 962 | " max_run=86400,\n",
|
|
0 commit comments