Skip to content

Commit ae5c254

Browse files
Fix bug in output sampler on GPU (#170)
1 parent f2c9b3c commit ae5c254

File tree

1 file changed

+2
-1
lines changed
  • fortuna/prob_model/predictive

1 file changed

+2
-1
lines changed

fortuna/prob_model/predictive/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,8 @@ def _sample(key, _inputs):
563563
)
564564
if distribute:
565565
outputs = jnp.stack(
566-
list(map(lambda key: _sample(shard_prng_key(key), inputs), keys))
566+
list(map(lambda key: _sample(shard_prng_key(key), inputs), keys)),
567+
axis=1,
567568
)
568569
outputs = self._unshard_ensemble_arrays(outputs)
569570
else:

0 commit comments

Comments
 (0)