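"""Launches an AXLearn SpmdTrainer run of the fuji-7B-v2-flash config on the public C4
dataset on a single host with 16 GPUs, setting XLA flags and distributed-run environment
variables in-process before handing off to run_trainer.
"""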
import os

from absl import app, flags

from axlearn.common.config import config_for_function
from axlearn.common.launch_trainer import run_trainer
from axlearn.common.trainer import SpmdTrainer
from axlearn.experiments.text.gpt import c4_trainer

FLAGS = flags.FLAGS
FLAGS.set_default("module", "text.gpt.c4_trainer")
FLAGS.set_default("config", "fuji-7B-v2-flash")  # Select the model config.
FLAGS.set_default("trainer_dir", "/opt/host/axlearn-checkpoints")  # Set the trainer directory.

def main(_):
    # Prepend the local axlearn checkout to PYTHONPATH (picked up by any spawned subprocesses).
    axlearn_path = "/opt/axlearn"
    os.environ["PYTHONPATH"] = f"{axlearn_path}:{os.environ.get('PYTHONPATH', '')}"

    n_gpus = 16  # This could also be read from an environment variable.
    # Base XLA flags.
    base_flags = [
        "--xla_gpu_enable_latency_hiding_scheduler=true",
        "--xla_gpu_enable_command_buffer=",
        "--xla_gpu_enable_highest_priority_async_stream=true",
        "--xla_gpu_all_reduce_combine_threshold_bytes=1073741824",
        "--xla_gpu_all_gather_combine_threshold_bytes=1073741824",
        "--xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824",
        "--xla_gpu_enable_pipelined_all_gather=true",
        "--xla_gpu_enable_pipelined_reduce_scatter=true",
        "--xla_gpu_enable_pipelined_all_reduce=true",
        "--xla_gpu_enable_while_loop_double_buffering=true",
        "--xla_gpu_enable_triton_gemm=false",
        "--xla_gpu_enable_all_gather_combine_by_dim=false",
        "--xla_gpu_enable_reduce_scatter_combine_by_dim=false",
        "--xla_disable_hlo_passes=rematerialization",
    ]
    # Get existing flags from the environment with a proper fallback.
    existing_xla_flags = os.environ.get("XLA_FLAGS", "").split()
    # Combine the base flags with any pre-existing XLA flags.
    os.environ["XLA_FLAGS"] = " ".join([*base_flags, *existing_xla_flags])

    os.environ.update({
        "DATA_DIR": "gs://axlearn-public/tensorflow_datasets",  # Set up your input dataset.
        "NUM_PROCESSES": f"{n_gpus}",
        "DISTRIBUTED_COORDINATOR": "127.0.0.1:8080",
        "PROCESS_ID": "0",
    })


    # Build the raw trainer config for the selected model.
    config_fn = c4_trainer.named_trainer_configs()[FLAGS.config]
    trainer_config: SpmdTrainer.Config = config_for_function(config_fn).fn()

    trainer_config.max_step = 100  # Set the max number of steps to run.
    trainer_config.dir = "/opt/host/axlearn-checkpoints"  # Use 'dir' instead of 'model_dir'.
    trainer_config.input.input_dispatcher.global_logical_batch_size = 8  # Tune the batch size for training.
    # trainer_config.input.source.max_sequence_length = 2048  # Tune the max sequence length if you run out of memory.
    trainer_config.checkpointer.save_policy.n = 500  # Save a checkpoint every 500 steps.
    trainer_config.checkpointer.keep_every_n_steps = 500  # Keep every 500th checkpoint.
    trainer_config.summary_writer.write_every_n_steps = 100  # Log summaries every 100 steps.

    run_trainer(trainer_config=trainer_config)


if __name__ == "__main__":
    app.run(main)
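
# Example launch (the file name is illustrative; adjust to wherever this script is saved):
#   python launch_fuji_7b_c4.py
# The module/config/trainer_dir flags default to the values set above and can still be
# overridden on the command line, e.g. --trainer_dir=<path>.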