diff --git a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py
index 8f06738e..5cd83647 100755
--- a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py
+++ b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py
@@ -48,6 +48,7 @@
     Array,
     OptaxOptimizer,
     Params,
+    Path,
     Status,
 )
 from fortuna.utils.builtins import get_dynamic_scale_instance_from_model_dtype
@@ -209,6 +210,61 @@ def _get_base_and_architecture(
         )
         return base, architecture
 
+    def load_state(self, checkpoint_path: Path) -> None:
+        """
+        Load the state of the posterior distribution from a checkpoint path. The checkpoint must be
+        compatible with the current probabilistic model.
+
+        Parameters
+        ----------
+        checkpoint_path: Path
+            Path to checkpoint file or directory to restore.
+        """
+        try:
+            self.restore_checkpoint(checkpoint_path)
+        except ValueError:
+            raise ValueError(
+                f"No checkpoint was found in `checkpoint_path={checkpoint_path}`."
+            )
+        self.state = PosteriorStateRepository(checkpoint_dir=checkpoint_path)
+
+        state = self.state.get()
+        if state._encoded_which_params is None:
+            n_params = len(ravel_pytree(state.params)[0]) // 2
+            which_params = None
+        else:
+            which_params = decode_encoded_tuple_of_lists_of_strings_to_array(
+                state._encoded_which_params
+            )
+            n_params = len(
+                ravel_pytree(
+                    nested_unpair(
+                        d=state.params.unfreeze(),
+                        key_paths=which_params,
+                        labels=("mean", "log_std"),
+                    )[1]
+                )[0]
+            )
+        _base, _architecture = self._get_base_and_architecture(n_params)
+        _unravel = self._get_unravel(
+            FrozenDict(
+                nested_unpair(
+                    d=state.params.unfreeze(),
+                    key_paths=which_params,
+                    labels=("mean", "log_std"),
+                )[0]
+                if which_params
+                else {
+                    k: dict(params=v["params"]["mean"]) for k, v in state.params.items()
+                }
+            ),
+            which_params=which_params,
+        )[1]
+
+        self._base = _base
+        self._architecture = _architecture
+        self._unravel = _unravel
+
     def sample(
         self,
         rng: Optional[PRNGKeyArray] = None,
@@ -238,47 +294,9 @@ def sample(
             rng = self.rng.get()
         state = self.state.get()
 
-        if self._base is None or self._unravel is None:
-            if state._encoded_which_params is None:
-                n_params = len(ravel_pytree(state.params)[0]) // 2
-                which_params = None
-            else:
-                which_params = decode_encoded_tuple_of_lists_of_strings_to_array(
-                    state._encoded_which_params
-                )
-                n_params = len(
-                    ravel_pytree(
-                        nested_unpair(
-                            d=state.params.unfreeze(),
-                            key_paths=which_params,
-                            labels=("mean", "log_std"),
-                        )[1]
-                    )[0]
-                )
-            _base, _architecture = self._get_base_and_architecture(n_params)
-            _unravel = self._get_unravel(
-                FrozenDict(
-                    nested_unpair(
-                        d=state.params.unfreeze(),
-                        key_paths=which_params,
-                        labels=("mean", "log_std"),
-                    )[0]
-                    if which_params
-                    else {
-                        k: dict(params=v["params"]["mean"])
-                        for k, v in state.params.items()
-                    }
-                ),
-                which_params=which_params,
-            )[1]
-
-            self._base = _base
-            self._architecture = _architecture
-            self._unravel = _unravel
-        else:
-            _base = self._base
-            _architecture = self._architecture
-            _unravel = self._unravel
+        _base = self._base
+        _architecture = self._architecture
+        _unravel = self._unravel
 
         if state._encoded_which_params is None:
             means = _unravel(