Commit 85303f3

add sharding mixin
1 parent 3c83290 commit 85303f3

39 files changed: +539 −713 lines

fortuna/calib_model/base.py

+2 −2

@@ -137,11 +137,11 @@ def _calibrate(
             rng=self.rng.get(),
             state=state,
             loss_fun=loss,
-            training_dataloader=calib_data_loader,
+            training_data_loader=calib_data_loader,
             training_dataset_size=n_calib_data,
             n_epochs=config.optimizer.n_epochs,
             metrics=config.monitor.metrics,
-            validation_dataloader=val_data_loader,
+            validation_data_loader=val_data_loader,
             validation_dataset_size=n_val_data,
             verbose=config.monitor.verbose,
             callbacks=config.callbacks,

fortuna/calib_model/calib_mixin.py

+28 −27

@@ -1,7 +1,7 @@
 import os
 from typing import Optional
 
-from flax.training import checkpoints
+# from flax.training import checkpoints
 
 from fortuna.calib_model.state import CalibState
 from fortuna.training.mixins.checkpointing import WithCheckpointingMixin
@@ -12,29 +12,30 @@
 
 
 class WithCalibCheckpointingMixin(WithCheckpointingMixin):
-    def restore_checkpoint(
-        self,
-        restore_checkpoint_dir: Path,
-        optimizer: Optional[OptaxOptimizer] = None,
-        prefix: str = "",
-        **kwargs,
-    ) -> CalibState:
-        if not os.path.isdir(restore_checkpoint_dir) and not os.path.isfile(
-            restore_checkpoint_dir
-        ):
-            raise ValueError(
-                f"`restore_checkpoint_dir={restore_checkpoint_dir}` was not found."
-            )
-        d = checkpoints.restore_checkpoint(
-            ckpt_dir=str(restore_checkpoint_dir),
-            target=None,
-            step=None,
-            prefix=prefix,
-            parallel=True,
-        )
-        if d is None:
-            raise ValueError(
-                f"No checkpoint was found in `restore_checkpoint_dir={restore_checkpoint_dir}`."
-            )
-
-        return CalibState.init_from_dict(d, optimizer, **kwargs)
+    pass
+    # def restore_checkpoint(
+    #     self,
+    #     restore_checkpoint_dir: Path,
+    #     optimizer: Optional[OptaxOptimizer] = None,
+    #     prefix: str = "",
+    #     **kwargs,
+    # ) -> CalibState:
+    #     if not os.path.isdir(restore_checkpoint_dir) and not os.path.isfile(
+    #         restore_checkpoint_dir
+    #     ):
+    #         raise ValueError(
+    #             f"`restore_checkpoint_dir={restore_checkpoint_dir}` was not found."
+    #         )
+    #     d = checkpoints.restore_checkpoint(
+    #         ckpt_dir=str(restore_checkpoint_dir),
+    #         target=None,
+    #         step=None,
+    #         prefix=prefix,
+    #         parallel=True,
+    #     )
+    #     if d is None:
+    #         raise ValueError(
+    #             f"No checkpoint was found in `restore_checkpoint_dir={restore_checkpoint_dir}`."
+    #         )
+    #
+    #     return CalibState.init_from_dict(d, optimizer, **kwargs)

fortuna/data/dataset/huggingface_datasets.py

+2 −2

@@ -112,12 +112,12 @@ def get_data_loader(
         drop_last: bool
             if True, the last batch (which potentially is smaller then the default batch size) is dropped.
         verbose: bool
-            Whether to show a progress bar while iterating over the dataloader or not.
+            Whether to show a progress bar while iterating over the data_loader or not.
 
         Returns
         -------
         HuggingFaceDataLoader
-            The dataloader
+            The data_loader
         """
         iterable = IterableData.from_callable(
             lambda *args, **kwargs: self._get_data_loader(

fortuna/data/loader/base.py

+34 −2

@@ -9,6 +9,7 @@
     Tuple,
     Type,
     TypeVar,
+    Union
 )
 
 from flax import jax_utils
@@ -24,6 +25,10 @@
     Status,
     Targets,
 )
+from fortuna.utils.prefetch import prefetch_to_mesh
+from fortuna.partitioner.partition_manager.base import PartitionManager
+from jax import device_put
+from jax.sharding import NamedSharding, PartitionSpec
 
 T = TypeVar("T")
 
@@ -185,7 +190,7 @@ def from_tensorflow_data_loader(cls: Type[T], tf_data_loader) -> T:
         T
             A concrete instance of a subclass of :class:`~fortuna.data.loader.BaseDataLoader`.
         """
-        return cls(iterable=IterableData.from_tf_dataloader(tf_data_loader))
+        return cls(iterable=IterableData.from_tf_data_loader(tf_data_loader))
 
     @classmethod
     def from_torch_data_loader(cls: Type[T], torch_data_loader) -> T:
@@ -203,7 +208,7 @@ def from_torch_data_loader(cls: Type[T], torch_data_loader) -> T:
         T
             A concrete instance of a subclass of :class:`~fortuna.data.loader.BaseDataLoader`.
         """
-        return cls(iterable=IterableData.from_torch_dataloader(torch_data_loader))
+        return cls(iterable=IterableData.from_torch_data_loader(torch_data_loader))
 
     @classmethod
     def from_inputs_loaders(
@@ -545,3 +550,30 @@ def __iter__(self, *args, **kwargs):
         loader = map(lambda batch: tree_map(self._reshape_inputs, batch), self._loader)
         loader = jax_utils.prefetch_to_device(loader, 2)
         yield from loader
+
+
+class ShardedPrefetchedLoader:
+    def __init__(
+        self,
+        loader,
+        partition_manager: Optional[PartitionManager] = None,
+        shard: bool = True,
+        partition_spec: Optional[PartitionSpec] = None
+    ):
+        self._loader = loader
+        self.partition_manager = partition_manager
+        self.shard = shard
+        self.partition_spec = partition_spec
+        if partition_manager is None and shard:
+            raise ValueError("`partition_manager` cannot be None when `shard` is set to True.")
+
+    def _shard(self, data: Union[Batch, InputData, Targets]):
+        return device_put(data, NamedSharding(self.partition_manager.partitioner.mesh, self.partition_spec))
+
+    def __iter__(self, *args, **kwargs):
+        if self.shard:
+            loader = map(lambda data: tree_map(self._shard, data), self._loader)
+            loader = prefetch_to_mesh(loader, 2, self.partition_manager.partitioner.mesh, self.partition_spec)
+        else:
+            loader = jax_utils.prefetch_to_device(self._loader, 2)
+        yield from loader

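Below is a minimal plain-JAX sketch of the shard-then-prefetch idea that the new ShardedPrefetchedLoader implements, kept independent of Fortuna's PartitionManager. The mesh, axis name, and batch shapes are illustrative assumptions, and the prefetching step is omitted.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.tree_util import tree_map

# A 1D mesh over all available devices, named after the "dp" axis used above.
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
spec = PartitionSpec("dp")

def shard_batch(batch):
    # Same idea as ShardedPrefetchedLoader._shard: place every leaf of the
    # batch pytree on the mesh according to the partition spec.
    return tree_map(lambda x: jax.device_put(x, NamedSharding(mesh, spec)), batch)

# A toy loader of dict batches; the leading axis must be divisible by the
# number of devices for this spec.
batches = ({"x": np.ones((8, 4)), "y": np.zeros((8,))} for _ in range(3))
for batch in map(shard_batch, batches):
    print(tree_map(lambda a: a.sharding, batch))
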
fortuna/data/loader/huggingface_loaders.py

+1 −1

@@ -35,7 +35,7 @@ def __init__(
         Parameters
         ----------
         iterable : Union[Iterable[Dict[str, Array]], Iterable[Tuple[Dict[str, Array],Array]]]
-            A data loader obtained via :func:`~HuggingFaceClassificationDataset.get_dataloader`.
+            A data loader obtained via :func:`~HuggingFaceClassificationDataset.get_data_loader`.
         num_unique_labels: int
             Number of unique target labels in the task (classification only)
         num_inputs: Optional[int]

fortuna/data/loader/utils.py

+4 −4

@@ -44,9 +44,9 @@ def _inner():
         return cls(_inner)
 
     @classmethod
-    def from_tf_dataloader(cls, tf_dataloader) -> IterableData:
+    def from_tf_data_loader(cls, tf_data_loader) -> IterableData:
         def _inner():
-            for batch_inputs, batch_targets in tf_dataloader:
+            for batch_inputs, batch_targets in tf_data_loader:
                 if not isinstance(batch_inputs, dict):
                     batch_inputs = batch_inputs.numpy()
                 else:
@@ -57,9 +57,9 @@ def _inner():
         return cls(_inner)
 
     @classmethod
-    def from_torch_dataloader(cls, torch_dataloader) -> IterableData:
+    def from_torch_data_loader(cls, torch_data_loader) -> IterableData:
         def _inner():
-            for batch_inputs, batch_targets in torch_dataloader:
+            for batch_inputs, batch_targets in torch_data_loader:
                 if not isinstance(batch_inputs, dict):
                     batch_inputs = batch_inputs.numpy()
                 else:

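For reference, a small usage sketch of the renamed constructors, assuming torch is installed; DataLoader here is Fortuna's loader class (its from_torch_data_loader classmethod appears in the fortuna/data/loader/base.py hunk above), and the toy tensors are illustrative.

import torch
from torch.utils.data import DataLoader as TorchDataLoader, TensorDataset

from fortuna.data import DataLoader

# A toy torch data loader yielding (inputs, targets) batches.
torch_loader = TorchDataLoader(
    TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,))),
    batch_size=8,
)

# After this commit the classmethods follow the *_data_loader naming,
# e.g. from_torch_data_loader instead of from_torch_dataloader.
data_loader = DataLoader.from_torch_data_loader(torch_loader)
for inputs, targets in data_loader:
    print(inputs.shape, targets.shape)
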
fortuna/likelihood/base.py

+3 −2

@@ -215,9 +215,10 @@ def _batched_log_joint_prob(
             mutable=mutable,
             rng=rng,
         )
-        if "mutable" in return_aux:
+        if mutable is not None:
             outputs, aux = outs
-            mutable = aux["mutable"]
+            if mutable in return_aux:
+                mutable = aux["mutable"]
         else:
             outputs = outs

fortuna/model/model_manager/base.py

+1 −1

@@ -9,7 +9,7 @@
 
 from flax import linen as nn
 from flax.core import FrozenDict
-from flax.training.checkpoints import PyTree
+from optax._src.base import PyTree
 from jax._src.prng import PRNGKeyArray
 import jax.numpy as jnp

fortuna/model/model_manager/classification.py

+1 −1

@@ -10,7 +10,7 @@
 
 from flax.core import FrozenDict
 import flax.linen as nn
-from flax.training.checkpoints import PyTree
+from optax._src.base import PyTree
 import jax
 from jax import random
 from jax._src.prng import PRNGKeyArray

fortuna/model/model_manager/regression.py

+2 −1

@@ -7,7 +7,7 @@
 
 from flax.core import FrozenDict
 import flax.linen as nn
-from flax.training.checkpoints import PyTree
+from optax._src.base import PyTree
 import jax
 from jax import random
 from jax._src.prng import PRNGKeyArray
@@ -65,6 +65,7 @@ def apply(
         lik_log_var_rngs = None
 
         if mutable is not None:
+            mutable = mutable.unfreeze()
             mutable["model"] = mutable.get("model")
             mutable["lik_log_var"] = mutable.get("lik_log_var")

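The unfreeze() call added above matters because flax.core.FrozenDict does not allow item assignment; a small illustration of the difference:

from flax.core import FrozenDict

fd = FrozenDict({"model": {"w": 1.0}})
# fd["lik_log_var"] = None          # would fail: FrozenDict is immutable
d = fd.unfreeze()                    # plain, mutable dict copy
d["model"] = d.get("model")          # mirrors the two assignments in the hunk above
d["lik_log_var"] = d.get("lik_log_var")
print(d["lik_log_var"])              # None
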
fortuna/model/model_manager/transformers/classification.py

+1 −1

@@ -8,7 +8,7 @@
 
 from flax import linen as nn
 from flax.core import FrozenDict
-from flax.training.checkpoints import PyTree
+from optax._src.base import PyTree
 import jax
 from jax import (
     numpy as jnp,

fortuna/output_calib_model/base.py

+2 −4

@@ -10,9 +10,7 @@
 
 from fortuna.output_calib_model.config.base import Config
 from fortuna.output_calib_model.loss import Loss
-from fortuna.output_calib_model.output_calib_mixin import (
-    WithOutputCalibCheckpointingMixin,
-)
+from fortuna.training.mixins.checkpointing import WithCheckpointingMixin
 from fortuna.output_calib_model.output_calib_model_calibrator import (
     JittedOutputCalibModelCalibrator,
     MultiDeviceOutputCalibModelCalibrator,
@@ -34,7 +32,7 @@
 from fortuna.utils.random import RandomNumberGenerator
 
 
-class OutputCalibModel(WithOutputCalibCheckpointingMixin, abc.ABC):
+class OutputCalibModel(WithCheckpointingMixin, abc.ABC):
     """
     Abstract calibration model class.
     """

fortuna/output_calib_model/output_calib_mixin.py

-40
This file was deleted.

fortuna/output_calibrator/output_calib_manager/base.py

+1 −1

@@ -6,7 +6,7 @@
 
 from flax.core import FrozenDict
 import flax.linen as nn
-from flax.training.checkpoints import PyTree
+from optax._src.base import PyTree
 from jax import random
 from jax._src.prng import PRNGKeyArray
 import jax.numpy as jnp

fortuna/partitioner/base.py

+1 −1

@@ -19,7 +19,7 @@ def __init__(
         n_devices: Optional[int] = None,
     ):
         if axis_dims is None:
-            axis_dims = {"dp": 1, "fsdp": 1, "mp": 1}
+            axis_dims = {"dp": 1, "fsdp": 1, "mp": -1}
         if rules is None:
             rules = {}
         self.specs = {

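The new default leaves the model-parallel axis at -1. A common convention, and presumably what Partitioner does with it (the resolution logic is not part of this diff), is to treat -1 as "use all remaining devices"; the helper below is hypothetical, purely to show the arithmetic.

import numpy as np

def resolve_axis_dims(axis_dims: dict, n_devices: int) -> dict:
    # Hypothetical helper, not Fortuna API: replace a single -1 entry with
    # the number of devices left over after the fixed axes are accounted for.
    fixed = int(np.prod([d for d in axis_dims.values() if d != -1]))
    return {k: (n_devices // fixed if d == -1 else d) for k, d in axis_dims.items()}

# With 8 devices the new default {"dp": 1, "fsdp": 1, "mp": -1} puts every
# device on the model-parallel axis: mesh shape (1, 1, 8).
print(resolve_axis_dims({"dp": 1, "fsdp": 1, "mp": -1}, n_devices=8))
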
fortuna/prob_model/base.py

+12 −12

@@ -14,8 +14,7 @@
 from fortuna.prob_model.calib_config.base import CalibConfig
 from fortuna.prob_model.fit_config.base import FitConfig
 from fortuna.prob_model.prob_model_calibrator import (
-    JittedProbModelOutputCalibrator,
-    MultiDeviceProbModelOutputCalibrator,
+    ShardedProbModelOutputCalibrator,
     ProbModelOutputCalibrator,
 )
 from fortuna.typing import (
@@ -137,7 +136,7 @@ def _calibrate(
                 "Pre-compute ensemble of outputs on the calibration data loader."
             )
 
-        distribute = jax.local_devices()[0].platform != "cpu"
+        shard = not calib_config.processor.disable_jit
 
         (
             calib_ensemble_outputs_loader,
@@ -146,7 +145,7 @@
             inputs_loader=calib_data_loader.to_inputs_loader(),
             n_output_samples=calib_config.processor.n_posterior_samples,
             return_size=True,
-            distribute=distribute,
+            shard=shard,
         )
         if calib_config.monitor.verbose:
             logging.info(
@@ -157,19 +156,20 @@
                 inputs_loader=val_data_loader.to_inputs_loader(),
                 n_output_samples=calib_config.processor.n_posterior_samples,
                 return_size=True,
-                distribute=distribute,
+                shard=shard,
             )
             if val_data_loader is not None
             else (None, None)
         )
 
-        trainer_cls = select_trainer_given_devices(
-            devices=calib_config.processor.devices,
-            base_trainer_cls=ProbModelOutputCalibrator,
-            jitted_trainer_cls=JittedProbModelOutputCalibrator,
-            multi_device_trainer_cls=MultiDeviceProbModelOutputCalibrator,
-            disable_jit=calib_config.processor.disable_jit,
-        )
+        # trainer_cls = select_trainer_given_devices(
+        #     devices=calib_config.processor.devices,
+        #     base_trainer_cls=ProbModelOutputCalibrator,
+        #     jitted_trainer_cls=JittedProbModelOutputCalibrator,
+        #     multi_device_trainer_cls=MultiDeviceProbModelOutputCalibrator,
+        #     disable_jit=calib_config.processor.disable_jit,
+        # )
+        trainer_cls = ShardedProbModelOutputCalibrator
 
         calibrator = trainer_cls(
             calib_outputs_loader=calib_ensemble_outputs_loader,

0 commit comments
