@@ -13,14 +13,50 @@ spec:
 spec:
   restartPolicy: Never
   containers:
-  - name: axlearn-fuji
+  - name: axlearn-fuji-model
     image: PLACEHOLDER
     command:
     - bash
     - -xo
     - pipefail
     - -c
- - "\nBASEDIR=\"/opt/axlearn\"\nCONFIG=\"fuji-3B-v3-flash-single-host\"\nBASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true\n --xla_gpu_enable_highest_priority_async_stream=true\n --xla_gpu_all_reduce_combine_threshold_bytes=1073741824\n --xla_gpu_all_gather_combine_threshold_bytes=1073741824\n --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824\n --xla_gpu_enable_pipelined_all_gather=true\n --xla_gpu_enable_pipelined_reduce_scatter=true\n --xla_gpu_enable_pipelined_all_reduce=true\n --xla_gpu_enable_while_loop_double_buffering=true\n --xla_gpu_enable_triton_gemm=false\n --xla_gpu_enable_all_gather_combine_by_dim=false\n --xla_gpu_enable_reduce_scatter_combine_by_dim=false\n --xla_disable_hlo_passes=rematerialization}\n\nexport XLA_FLAGS=\"$BASE_XLA_FLAGS ${XLA_FLAGS:-}\" \n\nLOG_DIR=${BASEDIR}/logs\nTRAINER_DIR=${LOG_DIR}/${CONFIG}${POSTFIX}-eks/trainer-dir\nmkdir -p ${TRAINER_DIR}\n\npython3 -m axlearn.common.launch_trainer_main \\\n --module=text.gpt.c4_trainer \\\n --config=${CONFIG} \\\n --trainer_dir=${TRAINER_DIR} \\\n --data_dir=gs://axlearn-public/tensorflow_datasets \\\n --jax_backend=gpu \n"
23
+    - |
+      BASEDIR="/opt/axlearn"
+      CONFIG="fuji-3B-v3-flash-single-host"
+      HLO_DUMP=0
+      POSTFIX=""
+
+      AR_THRESHOLD=1073741824
+      AG_THRESHOLD=8589934592
+      RS_THRESHOLD=8589934592
+      BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
+          --xla_gpu_enable_highest_priority_async_stream=true
+          --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
+          --xla_gpu_all_gather_combine_threshold_bytes=1073741824
+          --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824
+          --xla_gpu_enable_pipelined_all_gather=true
+          --xla_gpu_enable_pipelined_reduce_scatter=true
+          --xla_gpu_enable_pipelined_all_reduce=true
+          --xla_gpu_enable_while_loop_double_buffering=true
+          --xla_gpu_enable_triton_gemm=false
+          --xla_gpu_enable_all_gather_combine_by_dim=false
+          --xla_gpu_enable_reduce_scatter_combine_by_dim=false
+          --xla_disable_hlo_passes=rematerialization}
+
+      export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"
+      export TF_GPU_ALLOCATOR=cuda_malloc_async
+
+      LOG_DIR=${BASEDIR}/logs
+      TRAINER_DIR=${LOG_DIR}/${CONFIG}${POSTFIX}-eks/trainer-dir
+      mkdir -p ${TRAINER_DIR}
+
+
+      python3 -m axlearn.common.launch_trainer_main \
+        --module=text.gpt.c4_trainer \
+        --config=${CONFIG} \
+        --trainer_dir=${TRAINER_DIR} \
+        --data_dir=gs://axlearn-public/tensorflow_datasets \
+        --jax_backend=gpu
     resources:
       limits:
         nvidia.com/gpu: 8