Skip to content

Commit

Permalink
Refactoring in ADVI (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso authored Dec 19, 2023
1 parent 9f19785 commit 052cb7d
Showing 1 changed file with 59 additions and 41 deletions.
100 changes: 59 additions & 41 deletions fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
Array,
OptaxOptimizer,
Params,
Path,
Status,
)
from fortuna.utils.builtins import get_dynamic_scale_instance_from_model_dtype
Expand Down Expand Up @@ -209,6 +210,61 @@ def _get_base_and_architecture(
)
return base, architecture

def load_state(self, checkpoint_path: Path) -> None:
"""
Load the state of the posterior distribution from a checkpoint path. The checkpoint must be
compatible with the current probabilistic model.
Parameters
----------
checkpoint_path: Path
Path to checkpoint file or directory to restore.
"""
try:
self.restore_checkpoint(checkpoint_path)
except ValueError:
raise ValueError(
f"No checkpoint was found in `checkpoint_path={checkpoint_path}`."
)
self.state = PosteriorStateRepository(checkpoint_dir=checkpoint_path)

state = self.state.get()
if state._encoded_which_params is None:
n_params = len(ravel_pytree(state.params)[0]) // 2
which_params = None
else:
which_params = decode_encoded_tuple_of_lists_of_strings_to_array(
state._encoded_which_params
)
n_params = len(
ravel_pytree(
nested_unpair(
d=state.params.unfreeze(),
key_paths=which_params,
labels=("mean", "log_std"),
)[1]
)[0]
)
_base, _architecture = self._get_base_and_architecture(n_params)
_unravel = self._get_unravel(
FrozenDict(
nested_unpair(
d=state.params.unfreeze(),
key_paths=which_params,
labels=("mean", "log_std"),
)[0]
if which_params
else {
k: dict(params=v["params"]["mean"]) for k, v in state.params.items()
}
),
which_params=which_params,
)[1]

self._base = _base
self._architecture = _architecture
self._unravel = _unravel

def sample(
self,
rng: Optional[PRNGKeyArray] = None,
Expand Down Expand Up @@ -238,47 +294,9 @@ def sample(
rng = self.rng.get()
state = self.state.get()

if self._base is None or self._unravel is None:
if state._encoded_which_params is None:
n_params = len(ravel_pytree(state.params)[0]) // 2
which_params = None
else:
which_params = decode_encoded_tuple_of_lists_of_strings_to_array(
state._encoded_which_params
)
n_params = len(
ravel_pytree(
nested_unpair(
d=state.params.unfreeze(),
key_paths=which_params,
labels=("mean", "log_std"),
)[1]
)[0]
)
_base, _architecture = self._get_base_and_architecture(n_params)
_unravel = self._get_unravel(
FrozenDict(
nested_unpair(
d=state.params.unfreeze(),
key_paths=which_params,
labels=("mean", "log_std"),
)[0]
if which_params
else {
k: dict(params=v["params"]["mean"])
for k, v in state.params.items()
}
),
which_params=which_params,
)[1]

self._base = _base
self._architecture = _architecture
self._unravel = _unravel
else:
_base = self._base
_architecture = self._architecture
_unravel = self._unravel
_base = self._base
_architecture = self._architecture
_unravel = self._unravel

if state._encoded_which_params is None:
means = _unravel(
Expand Down

0 comments on commit 052cb7d

Please sign in to comment.