Commit e3a9e4e

fix the fuji eks model

1 parent 43f75a6 commit e3a9e4e

1 file changed: +38 -2 lines changed

.github/eks-workflow-files/axlearn/axlearn-fuji-model.yml

@@ -13,14 +13,50 @@ spec:
     spec:
       restartPolicy: Never
       containers:
-        - name: axlearn-fuji
+        - name: axlearn-fuji-model
           image: PLACEHOLDER
           command:
             - bash
             - -xo
             - pipefail
             - -c
-            - "\nBASEDIR=\"/opt/axlearn\"\nCONFIG=\"fuji-3B-v3-flash-single-host\"\nBASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true\n --xla_gpu_enable_highest_priority_async_stream=true\n --xla_gpu_all_reduce_combine_threshold_bytes=1073741824\n --xla_gpu_all_gather_combine_threshold_bytes=1073741824\n --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824\n --xla_gpu_enable_pipelined_all_gather=true\n --xla_gpu_enable_pipelined_reduce_scatter=true\n --xla_gpu_enable_pipelined_all_reduce=true\n --xla_gpu_enable_while_loop_double_buffering=true\n --xla_gpu_enable_triton_gemm=false\n --xla_gpu_enable_all_gather_combine_by_dim=false\n --xla_gpu_enable_reduce_scatter_combine_by_dim=false\n --xla_disable_hlo_passes=rematerialization}\n\nexport XLA_FLAGS=\"$BASE_XLA_FLAGS ${XLA_FLAGS:-}\" \n\nLOG_DIR=${BASEDIR}/logs\nTRAINER_DIR=${LOG_DIR}/${CONFIG}${POSTFIX}-eks/trainer-dir\nmkdir -p ${TRAINER_DIR}\n\npython3 -m axlearn.common.launch_trainer_main \\\n --module=text.gpt.c4_trainer \\\n --config=${CONFIG} \\\n --trainer_dir=${TRAINER_DIR} \\\n --data_dir=gs://axlearn-public/tensorflow_datasets \\\n --jax_backend=gpu \n"
+            - |
+              BASEDIR="/opt/axlearn"
+              CONFIG="fuji-3B-v3-flash-single-host"
+              HLO_DUMP=0
+              POSTFIX=""
+
+              AR_THRESHOLD=1073741824
+              AG_THRESHOLD=8589934592
+              RS_THRESHOLD=8589934592
+              BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
+                  --xla_gpu_enable_highest_priority_async_stream=true
+                  --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
+                  --xla_gpu_all_gather_combine_threshold_bytes=1073741824
+                  --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824
+                  --xla_gpu_enable_pipelined_all_gather=true
+                  --xla_gpu_enable_pipelined_reduce_scatter=true
+                  --xla_gpu_enable_pipelined_all_reduce=true
+                  --xla_gpu_enable_while_loop_double_buffering=true
+                  --xla_gpu_enable_triton_gemm=false
+                  --xla_gpu_enable_all_gather_combine_by_dim=false
+                  --xla_gpu_enable_reduce_scatter_combine_by_dim=false
+                  --xla_disable_hlo_passes=rematerialization}
+
+              export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"
+              export TF_GPU_ALLOCATOR=cuda_malloc_async
+
+              LOG_DIR=${BASEDIR}/logs
+              TRAINER_DIR=${LOG_DIR}/${CONFIG}${POSTFIX}-eks/trainer-dir
+              mkdir -p ${TRAINER_DIR}
+
+
+              python3 -m axlearn.common.launch_trainer_main \
+                --module=text.gpt.c4_trainer \
+                --config=${CONFIG} \
+                --trainer_dir=${TRAINER_DIR} \
+                --data_dir=gs://axlearn-public/tensorflow_datasets \
+                --jax_backend=gpu
           resources:
             limits:
               nvidia.com/gpu: 8
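
Besides swapping the single escaped string for a `|` literal block scalar (which preserves the script's newlines and keeps it readable in the manifest), the rewritten script leaves its tuning knobs overridable from the pod environment via Bash default expansion: ${BASE_XLA_FLAGS:-...} expands to the caller's value when it is set and non-empty, and to the inline multi-line default otherwise, with any caller-supplied XLA_FLAGS appended after the defaults. A minimal standalone sketch of that pattern (the single flag here is illustrative, not the job's full set):

#!/usr/bin/env bash
set -euo pipefail

# ${VAR:-default} expands to $VAR when it is set and non-empty,
# otherwise to the literal text after ":-".
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true}

# Append any caller-supplied XLA_FLAGS after the defaults; under `set -u`,
# ${XLA_FLAGS:-} expands to an empty string instead of raising an error.
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"

echo "XLA_FLAGS=$XLA_FLAGS"

Run bare, this prints the default flag; run with BASE_XLA_FLAGS already exported, the caller's value replaces the inline default wholesale rather than merging with it.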
