Skip to content

Commit 89df0ca

Browse files
committed
updated.
1 parent fb9b737 commit 89df0ca

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

redco/deployers/deployer.py

+9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import os
1616
import jax
17+
from flax.training.common_utils import shard_prng_key
1718
import orbax.checkpoint as ocp
1819

1920
from .data_utils import get_host_examples, get_data_batches
@@ -274,6 +275,14 @@ def gen_rng(self):
274275
self._rng, new_rng = jax.random.split(self._rng)
275276
return new_rng
276277

278+
def gen_model_step_rng(self):
279+
rng = self.gen_rng()
280+
if self.mesh is None:
281+
rng = jax.random.split(
282+
rng, num=jax.process_count())[jax.process_index()]
283+
rng = shard_prng_key(rng)
284+
return rng
285+
277286
def log_info(self, info, title=None, step=None):
278287
"""Logs a messages"""
279288
log_info(

redco/predictors/predictor.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,7 @@ def predict(self,
136136
self.setup_running_step(
137137
dummy_batch=batch, params_shape_or_params=params)
138138

139-
rng = self._deployer.gen_rng()
140-
if self.mesh is None:
141-
rng = jax.random.split(
142-
rng, num=jax.process_count())[jax.process_index()]
143-
rng = shard_prng_key(rng)
139+
rng = self._deployer.gen_model_step_rng()
144140
batch_preds_with_idxes = self._deployer.run_model_step(
145141
step_fn=self._p_pred_step, input_args=(rng, params, batch))
146142
batch_preds = process_batch_preds(

redco/trainers/trainer.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from jax.sharding import PartitionSpec as P
2323
from flax.jax_utils import replicate, unreplicate
2424
from flax.training import train_state
25-
from flax.training.common_utils import shard_prng_key
2625
from flax.core.frozen_dict import freeze
2726
from orbax.checkpoint.utils import \
2827
fully_replicated_host_local_array_to_global_array
@@ -227,11 +226,7 @@ def train(self, examples, per_device_batch_size, desc=None):
227226
if self._p_train_step is None:
228227
self.setup_running_step(dummy_batch=batch)
229228

230-
rng = self._deployer.gen_rng()
231-
if self.mesh is None:
232-
rng = jax.random.split(
233-
rng, num=jax.process_count())[jax.process_index()]
234-
rng = shard_prng_key(rng)
229+
rng = self._deployer.gen_model_step_rng()
235230
self._state, metrics = self._deployer.run_model_step(
236231
step_fn=self._p_train_step,
237232
input_args=(rng, self._state, batch))
@@ -265,9 +260,9 @@ def eval_loss(self, examples, per_device_batch_size, desc=None):
265260
if self._p_eval_step is None:
266261
self.setup_running_step(dummy_batch=batch)
267262

263+
rng = self._deployer.gen_model_step_rng()
268264
metrics = self._deployer.run_model_step(
269-
step_fn=self._p_eval_step,
270-
input_args=(jax.random.PRNGKey(0), self._state, batch))
265+
step_fn=self._p_eval_step, input_args=(rng, self._state, batch))
271266
if self.mesh is None:
272267
metrics = unreplicate(metrics)
273268

0 commit comments

Comments
 (0)