14
14
from fortuna .prob_model .calib_config .base import CalibConfig
15
15
from fortuna .prob_model .fit_config .base import FitConfig
16
16
from fortuna .prob_model .prob_model_calibrator import (
17
- JittedProbModelOutputCalibrator ,
18
- MultiDeviceProbModelOutputCalibrator ,
17
+ ShardedProbModelOutputCalibrator ,
19
18
ProbModelOutputCalibrator ,
20
19
)
21
20
from fortuna .typing import (
@@ -137,7 +136,7 @@ def _calibrate(
137
136
"Pre-compute ensemble of outputs on the calibration data loader."
138
137
)
139
138
140
- distribute = jax . local_devices ()[ 0 ]. platform != "cpu"
139
+ shard = not calib_config . processor . disable_jit
141
140
142
141
(
143
142
calib_ensemble_outputs_loader ,
@@ -146,7 +145,7 @@ def _calibrate(
146
145
inputs_loader = calib_data_loader .to_inputs_loader (),
147
146
n_output_samples = calib_config .processor .n_posterior_samples ,
148
147
return_size = True ,
149
- distribute = distribute ,
148
+ shard = shard ,
150
149
)
151
150
if calib_config .monitor .verbose :
152
151
logging .info (
@@ -157,19 +156,20 @@ def _calibrate(
157
156
inputs_loader = val_data_loader .to_inputs_loader (),
158
157
n_output_samples = calib_config .processor .n_posterior_samples ,
159
158
return_size = True ,
160
- distribute = distribute ,
159
+ shard = shard ,
161
160
)
162
161
if val_data_loader is not None
163
162
else (None , None )
164
163
)
165
164
166
- trainer_cls = select_trainer_given_devices (
167
- devices = calib_config .processor .devices ,
168
- base_trainer_cls = ProbModelOutputCalibrator ,
169
- jitted_trainer_cls = JittedProbModelOutputCalibrator ,
170
- multi_device_trainer_cls = MultiDeviceProbModelOutputCalibrator ,
171
- disable_jit = calib_config .processor .disable_jit ,
172
- )
165
+ # trainer_cls = select_trainer_given_devices(
166
+ # devices=calib_config.processor.devices,
167
+ # base_trainer_cls=ProbModelOutputCalibrator,
168
+ # jitted_trainer_cls=JittedProbModelOutputCalibrator,
169
+ # multi_device_trainer_cls=MultiDeviceProbModelOutputCalibrator,
170
+ # disable_jit=calib_config.processor.disable_jit,
171
+ # )
172
+ trainer_cls = ShardedProbModelOutputCalibrator
173
173
174
174
calibrator = trainer_cls (
175
175
calib_outputs_loader = calib_ensemble_outputs_loader ,
0 commit comments