|
| 1 | +# AXLearn |
| 2 | +[AXLearn](https://github.com/apple/axlearn) is a deep learning design framework, built on top of JAX and XLA, to support the development of large-scale models. |
| 3 | + |
| 4 | + |
| 5 | +## Hardware and Software Specifications |
| 6 | + |
| 7 | +The functionality have been validated on AWS p5.48xlarge EKS cluster (8x H100 80G). |
| 8 | + |
| 9 | + |
| 10 | +## Containers |
| 11 | +We provide a multi-architecture container that is regularly updated. Use these containers to avoid dependency and environment issues. |
| 12 | +- Latest container: ghcr.io/nvidia/jax:axlearn |
| 13 | +- Nightly dated container: ghcr.io/nvidia/jax:axlearn-YYYY-MM-DD |
| 14 | + |
| 15 | +When you start an interactive session: |
| 16 | + |
| 17 | +- Navigate to `/opt/axlearn` inside the container. |
| 18 | +- Place your persistent files in a mounted directory (e.g. `/opt/axlearn/workspace`). |
| 19 | + |
| 20 | +## Launching a container |
| 21 | +Use the following command to launch a container: |
| 22 | +```bash |
| 23 | +docker run -ti --gpus=all --net=host --ipc=host -v <WORKSPACE_PATH>:/opt/axlearn/workspace -w /opt/axlearn <CONTAINER> /bin/bash |
| 24 | +``` |
| 25 | +where `WORKSPACE_PATH` is the path to the directory where you would like to store any persistent files and `container` is the name of the maxtext container. You can additionally add dataset and vocab paths with the `-v` flag. |
| 26 | + |
| 27 | +## Example: training `fuji-3B-v3-flash-single-host` on EKS |
| 28 | +[Here is the YAML file](../../../.github/eks-workflow-files/axlearn/axlearn-fuji-model.yml) we're using for deploying the training of Fuji-3B model, that uses flash attention, and runs on a single host. The core part of the deployment is: |
| 29 | +```bash |
| 30 | +python3 -m axlearn.common.launch_trainer_main \ |
| 31 | + --module=text.gpt.c4_trainer \ |
| 32 | + --config=${CONFIG} \ |
| 33 | + --trainer_dir=${TRAINER_DIR} \ |
| 34 | + --data_dir=gs://axlearn-public/tensorflow_datasets \ |
| 35 | + --jax_backend=gpu |
| 36 | +``` |
| 37 | +Where `CONFIG="fuji-3B-v3-flash-single-host`. The input dataset is the public tensorflow [C4 dataset](https://www.tensorflow.org/datasets/catalog/c4). |
| 38 | + |
| 39 | +## Testing |
| 40 | +[Here is the YAML file](../../../.github/eks-workflow-files/axlearn/axlearn-job.yml) used for testing AXLearn funcitonalities. In particular, this test makes uses of [`test_axlearn.sh` script](../../../.github/container/test-axlearn.sh). The test runs `pytest` against all the tests contains in `/opt/axlearn/axlearn/common` folder. |
0 commit comments