Commit 7f186cc

fix readme, and copy of zip file, and xla flags
1 parent fc64bbd commit 7f186cc

6 files changed: +41 -201 lines changed
.github/eks-workflow-files/axlearn/axlearn-fuji-model.yml

-4
@@ -30,17 +30,13 @@ spec:
  AG_THRESHOLD=8589934592
  RS_THRESHOLD=8589934592
  BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
- --xla_gpu_enable_highest_priority_async_stream=true
  --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
  --xla_gpu_all_gather_combine_threshold_bytes=1073741824
  --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824
  --xla_gpu_enable_pipelined_all_gather=true
  --xla_gpu_enable_pipelined_reduce_scatter=true
  --xla_gpu_enable_pipelined_all_reduce=true
  --xla_gpu_enable_while_loop_double_buffering=true
- --xla_gpu_enable_triton_gemm=false
- --xla_gpu_enable_all_gather_combine_by_dim=false
- --xla_gpu_enable_reduce_scatter_combine_by_dim=false
  --xla_disable_hlo_passes=rematerialization}

  export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"
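
The hunk above relies on the shell's `${VAR:-default}` expansion: the built-in flag list is used only when `BASE_XLA_FLAGS` is unset or empty, and any pre-existing `XLA_FLAGS` is appended after it. A minimal sketch of that composition follows. The function name `compose_xla_flags` is hypothetical and the default list is shortened to two flags from the diff for readability; it is not the workflow's actual script.

```shell
#!/usr/bin/env bash
set -eu

# Hypothetical helper (not part of the workflow file). It mirrors the pattern
# in the hunk above: take BASE_XLA_FLAGS if set and non-empty, otherwise fall
# back to a default list, then append whatever XLA_FLAGS already holds.
compose_xla_flags() {
  # Shortened default for illustration; the real file lists more flags.
  local base="${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_disable_hlo_passes=rematerialization}"
  # ${XLA_FLAGS:-} expands to the empty string when XLA_FLAGS is unset,
  # which keeps the script safe under `set -u`.
  printf '%s\n' "$base ${XLA_FLAGS:-}"
}
```

For example, calling `BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false" compose_xla_flags` would replace the whole default list rather than merge with it, so an override has to restate every flag it wants to keep.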

.github/eks-workflow-files/axlearn/axlearn-job.yml

+1-1
@@ -51,7 +51,7 @@ spec:
  # Zip the results of all the tests
  tar -czf test_logs.tar.gz /opt/output
  # Upload logs to S3 bucket
- aws s3 cp /opt/output/summary.txt s3://jax-toolbox-eks-output/axlearn/${RUN_ID}/test_logs.tar.gz
+ aws s3 cp test_logs.tar.gz s3://jax-toolbox-eks-output/axlearn/${RUN_ID}/test_logs.tar.gz
  volumeMounts:
  - name: output
  mountPath: /opt/output
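
The fix above makes the upload step send the archive it just created instead of `summary.txt` under an archive name. One way to guard against that kind of mismatch is to check that the archive is readable before uploading it. The sketch below assumes a hypothetical helper, `archive_logs`, and uses `-C` so the archive stores relative paths; the workflow file itself archives `/opt/output` directly and does not perform this check.

```shell
#!/usr/bin/env bash
set -eu

# Hypothetical helper (not part of the workflow file): archive a results
# directory, then confirm the result is a readable gzip tarball.
archive_logs() {
  local src_dir="$1" archive="$2"
  # -C switches to the parent directory so entries are stored as
  # "<dirname>/..." rather than as an absolute path.
  tar -czf "$archive" -C "$(dirname "$src_dir")" "$(basename "$src_dir")"
  # Listing the archive exits non-zero if the file is not a valid tarball,
  # which aborts the script under `set -e` before any upload step runs.
  tar -tzf "$archive" > /dev/null
}
```

A call such as `archive_logs /opt/output test_logs.tar.gz` could then precede the `aws s3 cp test_logs.tar.gz ...` line shown in the diff.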

docs/frameworks/axlearn/README.md

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# AXLearn
2+
[AXLearn](https://github.com/apple/axlearn) is a deep learning design framework, built on top of JAX and XLA, to support the development of large-scale models.
3+
4+
5+
## Hardware and Software Specifications
6+
7+
The functionality have been validated on AWS p5.48xlarge EKS cluster (8x H100 80G).
8+
9+
10+
## Containers
11+
We provide a multi-architecture container that is regularly updated. Use these containers to avoid dependency and environment issues.
12+
- Latest container: ghcr.io/nvidia/jax:axlearn
13+
- Nightly dated container: ghcr.io/nvidia/jax:axlearn-YYYY-MM-DD
14+
15+
When you start an interactive session:
16+
17+
- Navigate to `/opt/axlearn` inside the container.
18+
- Place your persistent files in a mounted directory (e.g. `/opt/axlearn/workspace`).
19+
20+
## Launching a container
21+
Use the following command to launch a container:
22+
```bash
23+
docker run -ti --gpus=all --net=host --ipc=host -v <WORKSPACE_PATH>:/opt/axlearn/workspace -w /opt/axlearn <CONTAINER> /bin/bash
24+
```
25+
where `WORKSPACE_PATH` is the path to the directory where you would like to store any persistent files and `container` is the name of the maxtext container. You can additionally add dataset and vocab paths with the `-v` flag.
26+
27+
## Example: training `fuji-3B-v3-flash-single-host` on EKS
28+
[Here is the YAML file](../../../.github/eks-workflow-files/axlearn/axlearn-fuji-model.yml) we're using for deploying the training of Fuji-3B model, that uses flash attention, and runs on a single host. The core part of the deployment is:
29+
```bash
30+
python3 -m axlearn.common.launch_trainer_main \
31+
--module=text.gpt.c4_trainer \
32+
--config=${CONFIG} \
33+
--trainer_dir=${TRAINER_DIR} \
34+
--data_dir=gs://axlearn-public/tensorflow_datasets \
35+
--jax_backend=gpu
36+
```
37+
Where `CONFIG="fuji-3B-v3-flash-single-host`. The input dataset is the public tensorflow [C4 dataset](https://www.tensorflow.org/datasets/catalog/c4).
38+
39+
## Testing
40+
[Here is the YAML file](../../../.github/eks-workflow-files/axlearn/axlearn-job.yml) used for testing AXLearn funcitonalities. In particular, this test makes uses of [`test_axlearn.sh` script](../../../.github/container/test-axlearn.sh). The test runs `pytest` against all the tests contains in `/opt/axlearn/axlearn/common` folder.
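
The README's Fuji example leaves `CONFIG` and `TRAINER_DIR` to the caller. A small dry-run wrapper can make it easier to inspect the final command line before launching on a cluster. This is only a sketch: the `run` and `launch_trainer` helpers, the `DRY_RUN` variable, and the default `TRAINER_DIR` path are assumptions for illustration, while the trainer arguments are copied from the README hunk.

```shell
#!/usr/bin/env bash
set -eu

# CONFIG value comes from the README; the TRAINER_DIR default is a
# hypothetical path under the mounted workspace directory.
CONFIG="${CONFIG:-fuji-3B-v3-flash-single-host}"
TRAINER_DIR="${TRAINER_DIR:-/opt/axlearn/workspace/trainer}"

# Hypothetical helper: print the command when DRY_RUN=1 (the default here),
# otherwise execute it unchanged.
run() {
  if [ "${DRY_RUN:-1}" = "1" ]; then
    printf '%s\n' "$*"
  else
    "$@"
  fi
}

# Same arguments as the README example, routed through the dry-run helper.
launch_trainer() {
  run python3 -m axlearn.common.launch_trainer_main \
    --module=text.gpt.c4_trainer \
    --config="${CONFIG}" \
    --trainer_dir="${TRAINER_DIR}" \
    --data_dir=gs://axlearn-public/tensorflow_datasets \
    --jax_backend=gpu
}
```

Calling `launch_trainer` prints the assembled command; `DRY_RUN=0 launch_trainer` would execute it, which only makes sense inside the AXLearn container on a GPU host.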

rosetta/rosetta/projects/axlearn/README.md

-59
This file was deleted.

rosetta/rosetta/projects/axlearn/scripts/eks-fuji.yaml

-66
This file was deleted.

rosetta/rosetta/projects/axlearn/scripts/multinode.py

-71
This file was deleted.
