@@ -34,92 +34,15 @@ index 1d822ad..b89ea3c 100644
34
34
35
35
raw_keys["dtype"] = jax.numpy.dtype(raw_keys["dtype"])
36
36
diff --git a/MaxText/train.py b/MaxText/train.py
37
- index f3c2fb1..d69c363 100644
37
+ index f3c2fb1..e454eec 100644
38
38
--- a/MaxText/train.py
39
39
+++ b/MaxText/train.py
40
- @@ -107,34 +107,8 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
41
- })
42
- metrics['scalar'].update({'learning/current_learning_rate': lr })
43
-
44
- - _buffered_step = None
45
- - _buffered_metrics = None
46
- - def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
47
- - """Entry point for all metrics writing in Train's Main.
48
- - TODO: would be better as a Class in the future (that initialized all state!)
49
- -
50
- - To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
51
- - onto the last metrics and step and only publish when we receive a new metrics and step.
52
- - The logic is that this ensures that Jax is able to queues train_steps and we
53
- - don't block when turning "lazy" Jax arrays into real Python numbers.
54
- - """
55
- - global _buffered_step, _buffered_metrics
56
- -
57
- - if _buffered_metrics is not None:
58
- - if _buffered_step is None:
59
- - raise ValueError(f"When writing metrics, {_buffered_step=} was none")
60
- - write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)
61
- -
62
- - if config.metrics_file:
63
- - max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
64
-
65
- - if config.gcs_metrics and jax.process_index() == 0:
66
- - running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
67
- -
68
- - _buffered_step = step
69
- - _buffered_metrics = metrics
70
- -
71
- - def write_metrics_to_tensorboard(writer, metrics, step, config):
72
- + def write_metrics(writer, metrics, step, config):
73
- """ Writes metrics to tensorboard"""
74
- with jax.spmd_mode('allow_all'):
75
- if jax.process_index() == 0:
76
- @@ -329,6 +303,8 @@ def train_loop(config, state=None):
77
- static_argnums=static_argnums,
78
- donate_argnums=donate_argnums)
79
-
80
- + last_step_completion = datetime.datetime.now()
81
- +
82
- local_metrics_file = open(config.metrics_file, 'a', encoding="utf8") if config.metrics_file else None
83
- running_gcs_metrics = [] if config.gcs_metrics else None
84
-
85
- @@ -338,22 +314,22 @@ def train_loop(config, state=None):
86
- raise ValueError("Profiling requested but initial profiling step set past training final step")
87
- last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1)
88
-
89
- - example_batch = None
90
- - last_step_completion = datetime.datetime.now()
91
- -
92
- + nextrng = jax.random.fold_in(init_rng, start_step)
93
- + example_batch = load_next_batch(data_iterator, None, config)
94
- for step in np.arange(start_step, config.steps):
95
- if step == first_profiling_step:
96
- max_utils.activate_profiler(config)
97
-
98
- - example_batch = load_next_batch(data_iterator, example_batch, config)
99
- - nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
100
- with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
101
- state, metrics = p_train_step(
102
- state, example_batch, nextrng
103
- )
104
-
105
- + example_batch = load_next_batch(data_iterator, example_batch, config)
106
- + nextrng = jax.random.fold_in(init_rng, step+1)
107
- new_time = datetime.datetime.now()
108
- record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step))
109
- + write_metrics(writer, metrics, step, config)
110
- last_step_completion = new_time
111
-
112
- if checkpoint_manager is not None:
113
- @@ -364,7 +340,11 @@ def train_loop(config, state=None):
114
- checkpoint_manager.wait_until_finished()
115
- sys.exit()
40
+ @@ -369,6 +369,8 @@ def train_loop(config, state=None):
41
+ if step == last_profiling_step:
42
+ max_utils.deactivate_profiler(config)
116
43
117
- - write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config)
118
- + if config.metrics_file:
119
- + max_utils.write_metrics_locally(metrics, step, config, local_metrics_file)
44
+ + write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config)
120
45
+
121
- + if config.gcs_metrics and jax.process_index() == 0:
122
- + running_gcs_metrics = max_utils.write_metrics_for_gcs(metrics, step, config, running_gcs_metrics)
46
+ max_utils.close_summary_writer(writer)
47
+ return state
123
48
124
- if step == last_profiling_step:
125
- max_utils.deactivate_profiler(config)
0 commit comments