Skip to content

Commit 83eabd8

Browse files
committed
updated.
1 parent 0eb3dcf commit 83eabd8

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

redco/predictors/utils.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414

1515
import numpy as np
1616
import jax
17-
from jax.sharding import Mesh, PartitionSpec
18-
from jax.experimental.multihost_utils import host_local_array_to_global_array
17+
from jax.experimental.multihost_utils import process_allgather
1918

2019

2120
def add_idxes(examples):
@@ -49,14 +48,9 @@ def pred_step(pred_rng, params, batch, pred_fn, mesh):
4948

5049
def process_batch_preds(batch_preds_with_idxes, mesh):
5150
if mesh is None:
52-
global_mesh = Mesh(
53-
devices=np.array(jax.devices()).reshape(
54-
jax.process_count(), jax.local_device_count()),
55-
axis_names=('host', 'local'))
56-
batch_preds_with_idxes = host_local_array_to_global_array(
57-
batch_preds_with_idxes,
58-
global_mesh=global_mesh,
59-
pspecs=PartitionSpec('host'))
51+
batch_preds_with_idxes = jax.tree.map(
52+
lambda t: t.reshape((-1,) + t.shape[2:]),
53+
process_allgather(batch_preds_with_idxes))
6054

6155
batch_preds_with_idxes = jax.tree.map(np.asarray, batch_preds_with_idxes)
6256
preds = jax.tree.map(

0 commit comments

Comments
 (0)