Skip to content

Commit 3c83290

Browse files
pre-commit
1 parent 902a2a1 commit 3c83290

File tree

29 files changed

+619
-347
lines changed

29 files changed

+619
-347
lines changed

benchmarks/transformers/prob_model_text_classification.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -400,11 +400,17 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
400400

401401
model_editor = None
402402
if args.enable_probit_model_editor:
403-
probit_freeze_fun = lambda p, v: True if "classifier" in p else False if args.probit_last_layer_only else None
403+
probit_freeze_fun = (
404+
lambda p, v: True
405+
if "classifier" in p
406+
else False
407+
if args.probit_last_layer_only
408+
else None
409+
)
404410
model_editor = ProbitModelEditor(
405411
freeze_fun=probit_freeze_fun,
406412
init_log_var=args.probit_init_log_var,
407-
stop_gradient=args.probit_stop_gradient
413+
stop_gradient=args.probit_stop_gradient,
408414
)
409415

410416
### TRAINING

fortuna/calib_model/base.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,7 @@ def load_state(self, checkpoint_dir: Path) -> None:
176176
)
177177
self.predictive.state = CalibStateRepository(checkpoint_dir=checkpoint_dir)
178178

179-
def save_state(
180-
self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1
181-
) -> None:
179+
def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None:
182180
return self.predictive.state.put(
183181
self.predictive.state.get(),
184182
checkpoint_dir=checkpoint_dir,

fortuna/calib_model/calib_model_calibrator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
from optax._src.base import PyTree
1414

1515
from fortuna.calib_model.state import CalibState
16-
from fortuna.training.trainer import TrainerABC
1716
from fortuna.training.mixins.jitted import JittedMixin
1817
from fortuna.training.mixins.multi_device import MultiDeviceMixin
18+
from fortuna.training.trainer import TrainerABC
1919
from fortuna.typing import (
2020
Array,
2121
Batch,

fortuna/likelihood/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Tuple,
88
Union,
99
)
10+
1011
from flax.core import FrozenDict
1112
from jax import (
1213
jit,

0 commit comments

Comments
 (0)