Commit 5a8c8a9

Md Khan committed: fixed MaxText logging bug
Parent: 1ec8734

2 files changed: +8 -85 lines

.github/container/maxtext-mha.patch (+7 -84)

```diff
@@ -34,92 +34,15 @@ index 1d822ad..b89ea3c 100644
 
   raw_keys["dtype"] = jax.numpy.dtype(raw_keys["dtype"])
 diff --git a/MaxText/train.py b/MaxText/train.py
-index f3c2fb1..d69c363 100644
+index f3c2fb1..e454eec 100644
 --- a/MaxText/train.py
 +++ b/MaxText/train.py
-@@ -107,34 +107,8 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
-   })
-   metrics['scalar'].update({'learning/current_learning_rate': lr })
- 
--_buffered_step = None
--_buffered_metrics = None
--def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
--  """Entry point for all metrics writing in Train's Main.
--     TODO: would be better as a Class in the future (that initialized all state!)
--
--     To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
--     onto the last metrics and step and only publish when we receive a new metrics and step.
--     The logic is that this ensures that Jax is able to queues train_steps and we
--     don't block when turning "lazy" Jax arrays into real Python numbers.
--  """
--  global _buffered_step, _buffered_metrics
--
--  if _buffered_metrics is not None:
--    if _buffered_step is None:
--      raise ValueError(f"When writing metrics, {_buffered_step=} was none")
--    write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)
--
--    if config.metrics_file:
--      max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
-
--    if config.gcs_metrics and jax.process_index() == 0:
--      running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
--
--  _buffered_step = step
--  _buffered_metrics = metrics
--
--def write_metrics_to_tensorboard(writer, metrics, step, config):
-+def write_metrics(writer, metrics, step, config):
-   """ Writes metrics to tensorboard"""
-   with jax.spmd_mode('allow_all'):
-     if jax.process_index() == 0:
-@@ -329,6 +303,8 @@ def train_loop(config, state=None):
-                            static_argnums=static_argnums,
-                            donate_argnums=donate_argnums)
- 
-+  last_step_completion = datetime.datetime.now()
-+
-   local_metrics_file = open(config.metrics_file, 'a', encoding="utf8") if config.metrics_file else None
-   running_gcs_metrics = [] if config.gcs_metrics else None
- 
-@@ -338,22 +314,22 @@ def train_loop(config, state=None):
-     raise ValueError("Profiling requested but initial profiling step set past training final step")
-   last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1)
- 
--  example_batch = None
--  last_step_completion = datetime.datetime.now()
--
-+  nextrng = jax.random.fold_in(init_rng, start_step)
-+  example_batch = load_next_batch(data_iterator, None, config)
-   for step in np.arange(start_step, config.steps):
-     if step == first_profiling_step:
-       max_utils.activate_profiler(config)
- 
--    example_batch = load_next_batch(data_iterator, example_batch, config)
--    nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
-     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-       state, metrics = p_train_step(
-           state, example_batch, nextrng
-       )
- 
-+    example_batch = load_next_batch(data_iterator, example_batch, config)
-+    nextrng = jax.random.fold_in(init_rng, step+1)
-     new_time = datetime.datetime.now()
-     record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step))
-+    write_metrics(writer, metrics, step, config)
-     last_step_completion = new_time
- 
-     if checkpoint_manager is not None:
-@@ -364,7 +340,11 @@ def train_loop(config, state=None):
-       checkpoint_manager.wait_until_finished()
-       sys.exit()
+@@ -369,6 +369,8 @@ def train_loop(config, state=None):
+     if step == last_profiling_step:
+       max_utils.deactivate_profiler(config)
 
--      write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config)
-+      if config.metrics_file:
-+        max_utils.write_metrics_locally(metrics, step, config, local_metrics_file)
++      write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config)
 +
--      if config.gcs_metrics and jax.process_index() == 0:
--        running_gcs_metrics = max_utils.write_metrics_for_gcs(metrics, step, config, running_gcs_metrics)
+   max_utils.close_summary_writer(writer)
+   return state
 
-     if step == last_profiling_step:
-       max_utils.deactivate_profiler(config)
```
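
The docstring in the removed hunk documents the technique this patch had bolted on: metrics are "double buffered", i.e. the still-lazy JAX metric arrays from the previous step are only converted to Python numbers after the next step has been dispatched, so the host never stalls the device queue on a blocking transfer. A minimal standalone sketch of that pattern (MetricsBuffer and publish_fn are illustrative names, not MaxText code):

```python
class MetricsBuffer:
  """Sketch of the "double buffer" metrics trick from the removed hunk."""

  def __init__(self, publish_fn):
    self._publish_fn = publish_fn  # e.g. writes scalars to TensorBoard
    self._step = None
    self._metrics = None

  def write(self, metrics, step):
    # Publish the *previous* step's metrics: the device finished them long
    # ago, so materializing the lazy arrays here no longer blocks the queue.
    if self._metrics is not None:
      self._publish_fn(self._metrics, self._step)
    # Hold the current (possibly lazy) metrics until the next call.
    self._step = step
    self._metrics = metrics

  def flush(self):
    # Publish whatever is still buffered once training ends.
    if self._metrics is not None:
      self._publish_fn(self._metrics, self._step)
```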

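The removed loop restructuring applied the same overlap idea to input handling: the next batch and the next step's RNG were computed only after p_train_step had been dispatched, while the device was still busy. A toy sketch of that ordering, assuming a dummy train_step in place of MaxText's p_train_step:

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(x, rng):
  # Dummy step: perturb the state with per-step noise.
  return x + jax.random.normal(rng, x.shape)

init_rng = jax.random.PRNGKey(0)
x = jnp.zeros((4,))
start_step, steps = 0, 10

# Derive the first step's key up front, as the removed hunk did.
nextrng = jax.random.fold_in(init_rng, start_step)
for step in range(start_step, steps):
  x = train_step(x, nextrng)  # async dispatch; returns immediately
  # While the device works, fold_in derives step+1's key on the host.
  # fold_in(key, n) is deterministic, giving reproducible per-step RNG.
  nextrng = jax.random.fold_in(init_rng, step + 1)
x.block_until_ready()
```
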
.github/workflows/_test_maxtext.yaml (+1 -1)

```diff
@@ -6,7 +6,7 @@ on:
       MAXTEXT_IMAGE:
         type: string
         description: MaxText image from ghcr.io/nvidia
-        default: ghcr.io/nvidia/upstream-maxtext:nightly-2024-01-31
+        default: ghcr.io/nvidia/upstream-maxtext:latest
         required: false
       EXTRA_TEST_ARGS:
         type: string
```