
Commit b2579cb

@olupton comments work

1 parent 465264b commit b2579cb

File tree

8 files changed (+238 −12 lines)


.github/actions/submit-delete-k8s-job/action.yml

+1 −1

@@ -34,4 +34,4 @@ runs:
      kubectl logs --all-containers=true --all-pods=true --follow job/${{ inputs.job-name }}

    post: |
-     kubectl delete job ${{ inputs.job-name }}
+     kubectl delete -f "${{ inputs.job-config-file }}"

.github/eks-workflow-files/axlearn/axlearn-job.yml

+6 −2

@@ -37,7 +37,7 @@ spec:
      - name: upload
        image: amazon/aws-cli
        env:
-         - name: TEST_DATE
+         - name: RUN_ID
            value: PLACEHOLDER
        command:
          - sh
@@ -47,7 +47,11 @@ spec:
            sleep 5
          done
          # Upload to S3 bucket
-         aws s3 cp /opt/output/summary.txt s3://jax-toolbox-eks-output/axlearn/${TEST_DATE}/summary.txt
+         aws s3 cp /opt/output/summary.txt s3://jax-toolbox-eks-output/axlearn/${RUN_ID}/summary.txt
+         # Archive the results of all the tests
+         tar -czf test_logs.tar.gz /opt/output
+         # Upload the logs archive to the S3 bucket
+         aws s3 cp test_logs.tar.gz s3://jax-toolbox-eks-output/axlearn/${RUN_ID}/test_logs.tar.gz
        volumeMounts:
          - name: output
            mountPath: /opt/output

.github/workflows/_ci.yaml

+9 −8

@@ -688,11 +688,12 @@ jobs:
      - name: Download logs from S3
        id: log-s3
        run: |
-         mkdir -p /tmp/axlearn-output
-         aws s3 cp s3://jax-toolbox-eks-output/axlearn/${{ github.run_id }}/summary.txt /tmp/axlearn-output/
+         mkdir -p axlearn-output
+         aws s3 cp s3://jax-toolbox-eks-output/axlearn/${{ github.run_id }}/summary.txt axlearn-output/
+         aws s3 cp s3://jax-toolbox-eks-output/axlearn/${{ github.run_id }}/test_logs.tar.gz axlearn-output/

-         passed_tests=$(grep -c ": PASSED" /tmp/axlearn-output/summary.txt || true)
-         failed_tests=$(grep -c ": FAILED" /tmp/axlearn-output/summary.txt || true)
+         passed_tests=$(grep -c ": PASSED" axlearn-output/summary.txt || true)
+         failed_tests=$(grep -c ": FAILED" axlearn-output/summary.txt || true)
          total_tests=$((failed_tests + passed_tests))

          echo "Passed tests: $passed_tests"
@@ -733,7 +734,7 @@ jobs:
          message="Passed $passed_tests out of $total_tests." \
          color=$badge_color \
          to_json schemaVersion label message color \
-           > "badge-axlearn-test"
+           > badge-axlearn-test.json

      - name: Upload artifacts
        if: ${{ !cancelled() }}
@@ -742,8 +743,8 @@ jobs:
          name: "artifact-axlearn-test"
          path: |
            sitrep.json
-           "badge-axlearn-test"
-           summary.txt
+           badge-axlearn-test.json
+           axlearn-output/*

      # the fuji test will run for 20 minutes only, as per 2025-02-24
      # is not possible to set the `max_steps` value
@@ -779,5 +780,5 @@ jobs:
        uses: ./.github/actions/submit-delete-k8s-job
        with:
          job-config-file: ".github/eks-workflow-files/axlearn/axlearn-fuji-model.yml"
-         job-name: ${{ env.JOB_NAME }}
+         job-name: ${{ env.JOB_NAME }}

.github/workflows/_test_nccl.yaml

+1 −1

@@ -124,4 +124,4 @@ jobs:
      # Clean up in case of errors as well as success
      - name: Delete Kubernetes job
        if: always()
-       run: kubectl delete -f .github/eks-workflow-files/mpi-nccl-test.yml
+       run: kubectl delete -f .github/eks-workflow-files/mpi-nccl-test.yml

README.md

+25

@@ -15,6 +15,7 @@ We support and test the following JAX frameworks and model architectures. More d
 | [t5x](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
 | [big vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
 | levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` |
+| axlearn | Fuji | pretraining | `ghcr.io/nvidia/jax:axlearn` |

 # Build Pipeline Status
 <table>
@@ -248,6 +249,30 @@ We support and test the following JAX frameworks and model architectures. More d
         </a>
       </td>
     </tr>
+    <tr>
+      <td>
+        <a href="https://github.com/NVIDIA/JAX-Toolbox/blob/main/.github/container/Dockerfile.axlearn">
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=AXLearn%3D%7Bcore%2CAXLearn%7D">
+        </a>
+      </td>
+      <td>
+        <code>ghcr.io/nvidia/jax:axlearn</code>
+      </td>
+      <td>
+        <a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-axlearn-md">
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-axlearn-build-amd64.json&logo=docker&label=amd64">
+        </a>
+        <br>
+        <a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-axlearn-md">
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-axlearn-build-arm64.json&logo=docker&label=arm64">
+        </a>
+      </td>
+      <td>
+        <a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae#file-badge-axlearn-test-json">
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-axlearn-test.json&logo=nvidia&label=A100%20distributed">
+        </a>
+      </td>
+    </tr>
   </tbody>
 </table>

+59
@@ -0,0 +1,59 @@
# AXLearn

[AXLearn](https://github.com/apple/axlearn) is a deep learning design framework, built on top of JAX and XLA, that supports the development of large-scale models.

## Hardware and Software Specifications

Functionality has been validated on an AWS p5.48xlarge EKS cluster (8x H100 80 GB); please refer to the [Configs](#configs) section below for some initial configs and performance numbers. We will continue to populate it with more models and configs. We provide both single-node and multi-node pre-training support. If running on GPUs with less than 80 GB of memory, some of the default configurations may run out of memory; if you run out of memory and have more GPUs available, increase your GPU count and decrease your batch size per GPU.

## Containers
We provide a fully built and ready-to-use multi-arch, bleeding-edge container: `ghcr.io/nvidia/jax:axlearn`. We also provide nightly dated images with the naming pattern `ghcr.io/nvidia/jax:axlearn-YYYY-MM-DD`, but we encourage you to use the latest ones for the best performance.
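For example, the bleeding-edge image can be pulled directly, or a dated tag can be substituted for a pinned, reproducible environment (the date below is only a placeholder):

```bash
# Latest multi-arch AXLearn container
docker pull ghcr.io/nvidia/jax:axlearn
# ...or a dated nightly (placeholder date)
docker pull ghcr.io/nvidia/jax:axlearn-2025-02-24
```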

*Note*: All paths mentioned in subsequent sections are relative to the top-level directory of the AXLearn repository. When working interactively with containers, make sure you navigate to `/opt/axlearn` before running any commands.

## Launching a container
Use the following command to launch a container:
```
docker run -ti --gpus=all --net=host --ipc=host -v <WORKSPACE_PATH>:/opt/axlearn/workspace -w /opt/axlearn <CONTAINER> /bin/bash
```
where `WORKSPACE_PATH` is the path to the directory where you would like to store any persistent files and `CONTAINER` is the name of the AXLearn container. You can additionally mount dataset and vocab paths with the `-v` flag.
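As a concrete sketch, assuming a local workspace at `~/axlearn-workspace` (an arbitrary, illustrative path), the bleeding-edge container can be started with:

```bash
docker run -ti --gpus=all --net=host --ipc=host \
  -v ~/axlearn-workspace:/opt/axlearn/workspace \
  -w /opt/axlearn \
  ghcr.io/nvidia/jax:axlearn /bin/bash
```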

## Running a Fuji model
### Quick Runs

#### EKS Single node: `fuji-3B-v3-flash-single-host`
Fuji models are defined with 1B, 3B, 7B or 70B parameters. In this example, we deploy the training of a Fuji-3B model that uses flash attention and runs on a single host. [Here](scripts/eks-fuji.yaml) we provide an example deployment file. The core of the deployment is:
```bash
python3 -m axlearn.common.launch_trainer_main \
    --module=text.gpt.c4_trainer \
    --config=${CONFIG} \
    --trainer_dir=${TRAINER_DIR} \
    --data_dir=gs://axlearn-public/tensorflow_datasets \
    --jax_backend=gpu
```
where `CONFIG="fuji-3B-v3-flash-single-host"`. The input dataset is the public TensorFlow [C4 dataset](https://www.tensorflow.org/datasets/catalog/c4).
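To submit this example on EKS, something along the following lines should work, assuming the manifest linked above keeps the job name `axlearn-fuji` used in the example job definition:

```bash
# Submit the single-host Fuji-3B job and follow its logs
kubectl apply -f scripts/eks-fuji.yaml
kubectl logs --all-containers=true --follow job/axlearn-fuji

# Clean up once the job has finished
kubectl delete -f scripts/eks-fuji.yaml
```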

#### Running a multinode job for `fuji-XB-v2-flash`

For running a multinode job we provide a [custom example](scripts/multinode.py). The code accesses AXLearn directly and lets you specify a custom dataset, the number of GPUs to use, the global batch size, and the `max_sequence_length`; a launch sketch follows below.
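A minimal way to run the example from inside the container is sketched here. The script path is illustrative (it depends on where you copied the file into your mounted workspace), and the script as written hard-codes `NUM_PROCESSES`, `DISTRIBUTED_COORDINATOR` and `PROCESS_ID`, so those values need to be adapted per node for a real multi-node launch:

```bash
cd /opt/axlearn
# Illustrative location of the example script inside the mounted workspace
python3 workspace/multinode.py
```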

## XLA Flags
The [GPU Performance document](../../../docs/GPU_performance.md) provides a detailed description of the XLA flags that can be set to optimize performance. These are the recommended XLA flags to get good performance for AXLearn.

```
XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
           --xla_gpu_enable_triton_gemm=false
           --xla_gpu_enable_command_buffer=
           --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
           --xla_gpu_all_gather_combine_threshold_bytes=1073741824
           --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824
           --xla_gpu_enable_pipelined_all_gather=true
           --xla_gpu_enable_pipelined_reduce_scatter=true
           --xla_gpu_enable_pipelined_all_reduce=true
           --xla_gpu_enable_while_loop_double_buffering=true
           --xla_gpu_enable_all_gather_combine_by_dim=false
           --xla_gpu_enable_reduce_scatter_combine_by_dim=false
           --xla_disable_hlo_passes=rematerialization"
```
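As a usage sketch, these flags can simply be exported before the single-host launch command shown earlier; the trainer directory below is an arbitrary, illustrative path:

```bash
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true \
  --xla_gpu_enable_triton_gemm=false \
  --xla_gpu_enable_command_buffer= \
  --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 \
  --xla_gpu_all_gather_combine_threshold_bytes=1073741824 \
  --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824 \
  --xla_gpu_enable_pipelined_all_gather=true \
  --xla_gpu_enable_pipelined_reduce_scatter=true \
  --xla_gpu_enable_pipelined_all_reduce=true \
  --xla_gpu_enable_while_loop_double_buffering=true \
  --xla_gpu_enable_all_gather_combine_by_dim=false \
  --xla_gpu_enable_reduce_scatter_combine_by_dim=false \
  --xla_disable_hlo_passes=rematerialization"

# Illustrative output location; any writable directory works
TRAINER_DIR=/opt/axlearn/logs/fuji-3B-v3-flash-single-host/trainer-dir
mkdir -p "${TRAINER_DIR}"

python3 -m axlearn.common.launch_trainer_main \
  --module=text.gpt.c4_trainer \
  --config=fuji-3B-v3-flash-single-host \
  --trainer_dir="${TRAINER_DIR}" \
  --data_dir=gs://axlearn-public/tensorflow_datasets \
  --jax_backend=gpu
```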
@@ -0,0 +1,66 @@
apiVersion: batch/v1
kind: Job
metadata:
  name: axlearn-fuji
  # Specify any labels for running on a dedicated queue
spec:
  completions: 1
  parallelism: 1
  template:
    spec:
      restartPolicy: Never
      containers:
        - name: axlearn-fuji-model
          image: ghcr.io/nvidia/jax:axlearn
          command:
            - bash
            - -xo
            - pipefail
            - -c
            - |
              BASEDIR="/opt/axlearn"
              CONFIG="fuji-3B-v3-flash-single-host"
              HLO_DUMP=0
              POSTFIX=""

              AR_THRESHOLD=1073741824
              AG_THRESHOLD=8589934592
              RS_THRESHOLD=8589934592
              BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
                --xla_gpu_enable_highest_priority_async_stream=true
                --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
                --xla_gpu_all_gather_combine_threshold_bytes=1073741824
                --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824
                --xla_gpu_enable_pipelined_all_gather=true
                --xla_gpu_enable_pipelined_reduce_scatter=true
                --xla_gpu_enable_pipelined_all_reduce=true
                --xla_gpu_enable_while_loop_double_buffering=true
                --xla_gpu_enable_triton_gemm=false
                --xla_gpu_enable_all_gather_combine_by_dim=false
                --xla_gpu_enable_reduce_scatter_combine_by_dim=false
                --xla_disable_hlo_passes=rematerialization}

              export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"
              export TF_GPU_ALLOCATOR=cuda_malloc_async

              LOG_DIR=${BASEDIR}/logs
              TRAINER_DIR=${LOG_DIR}/${CONFIG}${POSTFIX}-eks/trainer-dir
              mkdir -p ${TRAINER_DIR}

              python3 -m axlearn.common.launch_trainer_main \
                --module=text.gpt.c4_trainer \
                --config=${CONFIG} \
                --trainer_dir=${TRAINER_DIR} \
                --data_dir=gs://axlearn-public/tensorflow_datasets \
                --jax_backend=gpu
          resources:
            limits:
              nvidia.com/gpu: 8
          volumeMounts:
            - name: output
              mountPath: /opt/output
      # specify any image secret if needed
      volumes:
        - name: output
          emptyDir: {}
@@ -0,0 +1,71 @@
import os

from absl import app, flags
from axlearn.common.launch_trainer import run_trainer
from axlearn.common.config import config_for_function
from axlearn.experiments.text.gpt import c4_trainer
from axlearn.common.trainer import SpmdTrainer

FLAGS = flags.FLAGS
FLAGS.set_default("module", "text.gpt.c4_trainer")
FLAGS.set_default("config", "fuji-7B-v2-flash")  # Set the model
FLAGS.set_default("trainer_dir", "/opt/host/axlearn-checkpoints")  # Set the trainer directory


def main(_):
    axlearn_path = "/opt/axlearn"
    os.environ["PYTHONPATH"] = f"{axlearn_path}:{os.environ.get('PYTHONPATH', '')}"

    n_gpus = 16  # This can also be an env variable
    # Base XLA flags
    base_flags = [
        "--xla_gpu_enable_latency_hiding_scheduler=true",
        "--xla_gpu_enable_command_buffer=",
        "--xla_gpu_enable_highest_priority_async_stream=true",
        "--xla_gpu_all_reduce_combine_threshold_bytes=1073741824",
        "--xla_gpu_all_gather_combine_threshold_bytes=1073741824",
        "--xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824",
        "--xla_gpu_enable_pipelined_all_gather=true",
        "--xla_gpu_enable_pipelined_reduce_scatter=true",
        "--xla_gpu_enable_pipelined_all_reduce=true",
        "--xla_gpu_enable_while_loop_double_buffering=true",
        "--xla_gpu_enable_triton_gemm=false",
        "--xla_gpu_enable_all_gather_combine_by_dim=false",
        "--xla_gpu_enable_reduce_scatter_combine_by_dim=false",
        "--xla_disable_hlo_passes=rematerialization",
    ]
    # Get existing flags from the environment with a proper fallback.
    existing_xla_flags = os.environ.get("XLA_FLAGS", "").split()
    # XLA flags
    os.environ.update({
        "XLA_FLAGS": " ".join([
            *base_flags,
            *existing_xla_flags,
        ])})

    os.environ.update({
        "DATA_DIR": "gs://axlearn-public/tensorflow_datasets",  # Set up your input dataset
        "NUM_PROCESSES": f"{n_gpus}",
        "DISTRIBUTED_COORDINATOR": "127.0.0.1:8080",
        "PROCESS_ID": "0",
    })

    # Raw config
    config_fn = c4_trainer.named_trainer_configs()[FLAGS.config]
    trainer_config: SpmdTrainer.Config = config_for_function(config_fn).fn()

    trainer_config.max_step = 100  # Set the max number of steps to run
    trainer_config.dir = "/opt/host/axlearn-checkpoints"  # Use 'dir' instead of 'model_dir'
    trainer_config.input.input_dispatcher.global_logical_batch_size = 8  # Tune the batch size for training
    # trainer_config.input.source.max_sequence_length = 2048  # Tune the max sequence length if running out of memory
    trainer_config.checkpointer.save_policy.n = 500  # Save every 500 steps
    trainer_config.checkpointer.keep_every_n_steps = 500  # Keep checkpoints
    trainer_config.summary_writer.write_every_n_steps = 100  # Log every 100 steps

    run_trainer(
        trainer_config=trainer_config,
    )


if __name__ == "__main__":
    app.run(main)
