From ab13c1d2908acb3112ca95fab93a6a38d965186d Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Tue, 16 May 2023 13:36:36 +0200 Subject: [PATCH] add SG-MCMC methods (#46) * add SGHMC method * add SG-MCMC diagnostics * add polynomial step schedule * implement SGHMC momentum resampling * refactor SGMCMC posterior into a top-level class * add Cyclical SGLD method * implement chain thinning * update documentation * code formatting fixes * refactor SGMCMC methods * update documentation * rebase to the new API & update docs * address code review feedback * fix the Cyclical SGLD sampler * fix unit tests * lint the code * make black show diffs in the workflow * rebase to the latest API --- .github/workflows/lint.yml | 2 +- docs/source/methods.rst | 8 + .../prob_model/posterior/posterior.rst | 1 + .../prob_model/posterior/sgmcmc.rst | 82 +++++ examples/index.rst | 2 + examples/mnist_classification_sghmc.pct.py | 204 +++++++++++ examples/sgmcmc_diagnostics.pct.py | 112 ++++++ fortuna/data/loader/base.py | 33 ++ .../model_manager/name_to_model_manager.py | 4 + fortuna/prob_model/__init__.py | 6 + .../deep_ensemble/deep_ensemble_posterior.py | 27 +- .../posterior/name_to_posterior_state.py | 6 + .../posterior/posterior_approximations.py | 10 + ...py => posterior_multi_state_repository.py} | 23 +- .../prob_model/posterior/sgmcmc/__init__.py | 0 fortuna/prob_model/posterior/sgmcmc/base.py | 33 ++ .../sgmcmc/cyclical_sgld/__init__.py | 1 + .../cyclical_sgld_approximator.py | 51 +++ .../cyclical_sgld/cyclical_sgld_callback.py | 81 +++++ .../cyclical_sgld/cyclical_sgld_integrator.py | 107 ++++++ .../cyclical_sgld/cyclical_sgld_posterior.py | 226 ++++++++++++ .../cyclical_sgld/cyclical_sgld_state.py | 46 +++ .../posterior/sgmcmc/sghmc/__init__.py | 1 + .../sgmcmc/sghmc/sghmc_approximator.py | 60 +++ .../posterior/sgmcmc/sghmc/sghmc_callback.py | 74 ++++ .../sgmcmc/sghmc/sghmc_integrator.py | 93 +++++ .../posterior/sgmcmc/sghmc/sghmc_posterior.py | 221 ++++++++++++ .../posterior/sgmcmc/sghmc/sghmc_state.py | 46 +++ .../posterior/sgmcmc/sgmcmc_diagnostic.py | 133 +++++++ .../posterior/sgmcmc/sgmcmc_posterior.py | 96 +++++ .../posterior/sgmcmc/sgmcmc_preconditioner.py | 135 +++++++ .../sgmcmc/sgmcmc_sampling_callback.py | 61 ++++ .../posterior/sgmcmc/sgmcmc_step_schedule.py | 148 ++++++++ tests/fortuna/prob_model/test_diagnostic.py | 58 +++ .../fortuna/prob_model/test_preconditioner.py | 39 ++ .../fortuna/prob_model/test_step_schedule.py | 54 +++ tests/fortuna/prob_model/test_train.py | 341 +++++++++--------- 37 files changed, 2419 insertions(+), 206 deletions(-) create mode 100644 docs/source/references/prob_model/posterior/sgmcmc.rst create mode 100644 examples/mnist_classification_sghmc.pct.py create mode 100644 examples/sgmcmc_diagnostics.pct.py rename fortuna/prob_model/posterior/{deep_ensemble/deep_ensemble_repositories.py => posterior_multi_state_repository.py} (85%) create mode 100644 fortuna/prob_model/posterior/sgmcmc/__init__.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/base.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/__init__.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_approximator.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_callback.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_integrator.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_state.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sghmc/__init__.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_approximator.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_callback.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_integrator.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sgmcmc_diagnostic.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sgmcmc_preconditioner.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sgmcmc_sampling_callback.py create mode 100644 fortuna/prob_model/posterior/sgmcmc/sgmcmc_step_schedule.py create mode 100755 tests/fortuna/prob_model/test_diagnostic.py create mode 100755 tests/fortuna/prob_model/test_preconditioner.py create mode 100755 tests/fortuna/prob_model/test_step_schedule.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 8cbddf14..029c3cc1 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -24,4 +24,4 @@ jobs: pip install black - name: Run Black - run: black --check --verbose fortuna + run: black --check --diff --verbose fortuna diff --git a/docs/source/methods.rst b/docs/source/methods.rst index 25db0465..a576a10a 100644 --- a/docs/source/methods.rst +++ b/docs/source/methods.rst @@ -29,6 +29,14 @@ Posterior approximation methods taken by averaging checkpoints over the stochastic optimization trajectory. The covariance is also estimated empirically along the trajectory, and it is made of a diagonal component and a low-rank non-diagonal one. +- **Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)** `[Chen et al., 2014] `_ + SGHMC approximates the posterior as a steady-state distribution of a Monte Carlo Markov chain with Hamiltonian dynamics. + After the initial "burn-in" phase, each step of the chain generates samples from the posterior. + +- **Cyclical Stochastic Gradient Langevin Dynamics (Cyclical SGLD)** `[Zhang et al., 2020] `_ + Cyclical SGLD adapts the cyclical cosine step size schedule, and alternates between *exploration* and *sampling* stages to better + explore the multimodal posteriors for deep neural networks. + Parametric calibration methods ------------------------------ Fortuna supports parametric calibration by adding an output calibration model on top of the outputs of the model used for diff --git a/docs/source/references/prob_model/posterior/posterior.rst b/docs/source/references/prob_model/posterior/posterior.rst index 1aa9009b..7225c25a 100644 --- a/docs/source/references/prob_model/posterior/posterior.rst +++ b/docs/source/references/prob_model/posterior/posterior.rst @@ -12,6 +12,7 @@ calibration parameters. We support several posterior approximations: laplace swag sngp + sgmcmc .. _posterior: diff --git a/docs/source/references/prob_model/posterior/sgmcmc.rst b/docs/source/references/prob_model/posterior/sgmcmc.rst new file mode 100644 index 00000000..2ab3ce31 --- /dev/null +++ b/docs/source/references/prob_model/posterior/sgmcmc.rst @@ -0,0 +1,82 @@ +Stochastic Gradient Markov Chain Monte Carlo (SG-MCMC) +------------------------------------------------------ +SG-MCMC procedures approximate the posterior as a steady-state distribution of +a Monte Carlo Markov chain, that utilizes noisy estimates of the gradient +computed on minibatches of data. + +Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) +=================================================== + +SGHMC `[Chen T. et al., 2014] `_ +is a popular MCMC algorithm that uses stochastic gradient estimates to scale +to large datasets. + +.. autoclass:: fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator.SGHMCPosteriorApproximator + +.. autoclass:: fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_posterior.SGHMCPosterior + :show-inheritance: + :no-inherited-members: + :exclude-members: state + :members: fit, sample, load_state, save_state + +.. autoclass:: fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_state.SGHMCState + :show-inheritance: + :no-inherited-members: + :inherited-members: init, init_from_dict + :members: convert_from_map_state + :exclude-members: params, mutable, calib_params, calib_mutable, replace, apply_gradients, encoded_name, create + :no-undoc-members: + :no-special-members: + +Cyclical Stochastic Gradient Langevin Dynamics (CyclicalSGLD) +============================================================= + +Cyclical SGLD method `[Zhang R. et al., 2019] `_ is a simple and automatic +procedure that adapts the cyclical cosine stepsize schedule, and alternates between +*exploration* and *sampling* stages to better explore the multimodal posteriors for deep neural networks. + +.. autoclass:: fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_approximator.CyclicalSGLDPosteriorApproximator + +.. autoclass:: fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_posterior.CyclicalSGLDPosterior + :show-inheritance: + :no-inherited-members: + :exclude-members: state + :members: fit, sample, load_state, save_state + +.. autoclass:: fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_state.CyclicalSGLDState + :show-inheritance: + :no-inherited-members: + :inherited-members: init, init_from_dict + :members: convert_from_map_state + :exclude-members: params, mutable, calib_params, calib_mutable, replace, apply_gradients, encoded_name, create + :no-undoc-members: + :no-special-members: + + +Step schedules +============== + +Fortuna supports various step schedulers for SG-MCMC +algorithms. :class:`~fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule.StepSchedule` +is a function that takes step count as an input and returns `float` step +size as an output. + +.. automodule:: fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule + + +Preconditioners +=============== + +Fortuna provides implementations of preconditioners to improve samplers efficacy. + +.. automodule:: fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner + :exclude-members: Preconditioner, PreconditionerState, RMSPropPreconditionerState, IdentityPreconditionerState + + +Diagnostics +=========== + +The library includes toolings necessary for diagnostics of the convergence of +SG-MCMC sampling procedures. + +.. automodule:: fortuna.prob_model.posterior.sgmcmc.sgmcmc_diagnostic diff --git a/examples/index.rst b/examples/index.rst index b322d9a6..f62ac4bf 100644 --- a/examples/index.rst +++ b/examples/index.rst @@ -19,3 +19,5 @@ In this section we show some examples of how to use Fortuna in classification an subnet_calibration two_moons_classification_sngp scaling_up_bayesian_inference + mnist_classification_sghmc + sgmcmc_diagnostics diff --git a/examples/mnist_classification_sghmc.pct.py b/examples/mnist_classification_sghmc.pct.py new file mode 100644 index 00000000..84abf342 --- /dev/null +++ b/examples/mnist_classification_sghmc.pct.py @@ -0,0 +1,204 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.14.5 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% [markdown] +# # MNIST Classification with Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) + +# %% [markdown] +# In this notebook we demonstrate how to use Fortuna to obtain predictions uncertainty estimates from a simple neural network model trained for MNIST classification task, using the [SGHMC](http://proceedings.mlr.press/v32/cheni14.pdf) method. + +# %% [markdown] +# ### Download MNIST data from TensorFlow +# Let us first download the MNIST data from [TensorFlow Datasets](https://www.tensorflow.org/datasets). Other sources would be equivalently fine. + +# %% +import tensorflow as tf +import tensorflow_datasets as tfds + + +def download(split_range, shuffle=False): + ds = tfds.load( + name="MNIST", + split=f"train[{split_range}]", + as_supervised=True, + shuffle_files=True, + ).map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y)) + if shuffle: + ds = ds.shuffle(10, reshuffle_each_iteration=True) + return ds.batch(128).prefetch(1) + + +train_data_loader, val_data_loader, test_data_loader = ( + download(":80%", shuffle=True), + download("80%:90%"), + download("90%:"), +) + +# %% [markdown] +# ### Convert data to a compatible data loader +# Fortuna helps you converting data and data loaders into a data loader that Fortuna can digest. + +# %% +from fortuna.data import DataLoader + +train_data_loader = DataLoader.from_tensorflow_data_loader(train_data_loader) +val_data_loader = DataLoader.from_tensorflow_data_loader(val_data_loader) +test_data_loader = DataLoader.from_tensorflow_data_loader(test_data_loader) + +# %% [markdown] +# ### Build a probabilistic classifier +# Let us build a probabilistic classifier. This is an interface object containing several attributes that you can configure, i.e. `model`, `prior`, and `posterior_approximator`. In this example, we use a multilayer perceptron and an SGHMC posterior approximator. SGHMC (and SGMCMC methods, broadly) allows configuring a step size schedule function. For simplicity, we create a constant step schedule. + +# %% +import flax.linen as nn + +from fortuna.prob_model import ProbClassifier, SGHMCPosteriorApproximator +from fortuna.model import MLP + +output_dim = 10 +prob_model = ProbClassifier( + model=MLP(output_dim=output_dim, activations=(nn.tanh, nn.tanh)), + posterior_approximator=SGHMCPosteriorApproximator(burnin_length=300, + step_schedule=4e-6), +) + + +# %% [markdown] +# ### Train the probabilistic model: posterior fitting and calibration +# We can now train the probabilistic model. This includes fitting the posterior distribution and calibrating the probabilistic model. We set the Markov chain burn-in phase to 20 epochs, followed by obtaining samples from the approximated posterior. + +# %% +from fortuna.prob_model import FitConfig, FitMonitor, FitOptimizer +from fortuna.metric.classification import accuracy + +status = prob_model.train( + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + calib_data_loader=val_data_loader, + fit_config=FitConfig( + monitor=FitMonitor(metrics=(accuracy,)), + optimizer=FitOptimizer(n_epochs=30), + ), +) + + +# %% [markdown] +# ### Estimate predictive statistics +# We can now compute some predictive statistics by invoking the `predictive` attribute of the probabilistic classifier, and the method of interest. Most predictive statistics, e.g. mean or mode, require a loader of input data points. You can easily get this from the data loader calling its method `to_inputs_loader`. + +# %% pycharm={"name": "#%%\n"} +test_log_probs = prob_model.predictive.log_prob(data_loader=test_data_loader) +test_inputs_loader = test_data_loader.to_inputs_loader() +test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader) +test_modes = prob_model.predictive.mode( + inputs_loader=test_inputs_loader, means=test_means +) + +# %% [markdown] +# ### Compute metrics +# In classification, the predictive mode is a prediction for labels, while the predictive mean is a prediction for the probability of each label. As such, we can use these to compute several metrics, e.g. the accuracy, the Brier score, the expected calibration error (ECE), etc. + +# %% +from fortuna.metric.classification import ( + accuracy, + expected_calibration_error, + brier_score, +) + +test_targets = test_data_loader.to_array_targets() +acc = accuracy(preds=test_modes, targets=test_targets) +brier = brier_score(probs=test_means, targets=test_targets) +ece = expected_calibration_error( + preds=test_modes, + probs=test_means, + targets=test_targets, + plot=True, + plot_options=dict(figsize=(10, 2)), +) +print(f"Test accuracy: {acc}") +print(f"Brier score: {brier}") +print(f"ECE: {ece}") + +# %% [markdown] +# ### Conformal prediction sets +# Fortuna allows to produce conformal prediction sets, that are sets of likely labels up to some coverage probability threshold. These can be computed starting from probability estimates obtained with or without Fortuna. + +# %% +from fortuna.conformal import AdaptivePredictionConformalClassifier + +val_means = prob_model.predictive.mean(inputs_loader=val_data_loader.to_inputs_loader()) +conformal_sets = AdaptivePredictionConformalClassifier().conformal_set( + val_probs=val_means, + test_probs=test_means, + val_targets=val_data_loader.to_array_targets(), +) + +# %% [markdown] +# We can check that, on average, conformal sets for misclassified inputs are larger than for well classified ones. + +# %% +import numpy as np + +avg_size = np.mean([len(s) for s in np.array(conformal_sets, dtype="object")]) +avg_size_wellclassified = np.mean( + [ + len(s) + for s in np.array(conformal_sets, dtype="object")[test_modes == test_targets] + ] +) +avg_size_misclassified = np.mean( + [ + len(s) + for s in np.array(conformal_sets, dtype="object")[test_modes != test_targets] + ] +) +print(f"Average conformal set size: {avg_size}") +print( + f"Average conformal set size over well classified input: {avg_size_wellclassified}" +) +print(f"Average conformal set size over misclassified input: {avg_size_misclassified}") + +# %% [markdown] +# Furthermore, we visualize some of the examples with the largest and the smallest conformal sets. Intutively, they correspond to the inputs where the model is the most uncertain or the most certain about its predictions. + +# %% +from matplotlib import pyplot as plt + +N_EXAMPLES = 10 +images = test_data_loader.to_array_inputs() + +def visualize_examples(indices, n_examples=N_EXAMPLES): + n_rows = min(len(indices), n_examples) + _, axs = plt.subplots(1, n_rows, figsize=(10, 2)) + axs = axs.flatten() + for i, ax in enumerate(axs): + ax.imshow(images[indices[i]], cmap='gray') + ax.axis("off") + plt.show() + +# %% +indices = np.argsort( + [ + len(s) + for s in np.array(conformal_sets, dtype="object") + ] +) + +# %% +print("Examples with the smallest conformal sets:") +visualize_examples(indices[:N_EXAMPLES]) + +# %% +print("Examples with the largest conformal sets:") +visualize_examples(np.flip(indices[-N_EXAMPLES:])) diff --git a/examples/sgmcmc_diagnostics.pct.py b/examples/sgmcmc_diagnostics.pct.py new file mode 100644 index 00000000..0925e5b0 --- /dev/null +++ b/examples/sgmcmc_diagnostics.pct.py @@ -0,0 +1,112 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.14.5 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Stochastic Gradient Markov chain Monte Carlo (SG-MCMC) disagnostics + +# %% [markdown] +# Markov chain Monte Carlo (MCMC) methods are powerful tools for approximating the posterior distribution. Stochastic procedures, such as Stochastic Gradient Hamiltonian Monte Carlo, enable rapid sampling at the cost of more biased inference. However, it has been shown that standard MCMC diagnostics fail to detect these biases. Kernel Stein discrepancy approach (KSD) with the recently proposed inverse multiquadric (IMQ) kernel [[Gorham and Mackey, 2017](https://proceedings.mlr.press/v70/gorham17a/gorham17a.pdf)] aims for comparing biased, exact, and deterministic sample sequences, that is also particularly suitable for parallelized computation. +# +# In this notebook, we show how to assess the quality of SG-MCMC samples. + +# %% [markdown] +# We create a toy example with a 2-D multivariate normal distribution. The distribution is parametrized a zero mean and a covariance matrix $\Sigma = P^{T} D P$, where $D$ is a diagonal scale matrix, and $P$ is a rotation matrix for some angle $r$. + +# %% +from jax import vmap, value_and_grad +import jax.numpy as jnp +import jax.scipy.stats as stats +import numpy as np + +import matplotlib.pyplot as plt +from matplotlib.patches import Ellipse + +mu = jnp.zeros([2,]) +r = np.pi / 4 +D = jnp.array([2., 1.]) +P = jnp.array([[jnp.cos(r), jnp.sin(r)], [-jnp.sin(r), jnp.cos(r)]]) +sigma = P.T @ jnp.diag(D) @ P + +# %% [markdown] +# We create a ground truth dataset, and also two more dataset with underdispersed $\mathcal{N}(0, \sqrt[5]{\Sigma})$ and overdispersed $\mathcal{N}(0, \Sigma^{3})$ samples. + +# %% +N = 1_000 +disp = [1/5, 1, 3] +rng = np.random.default_rng(0) +samples = np.array([rng.multivariate_normal(mu, sigma ** d, size=N) for d in disp]) + +# %% [markdown] +# The dataset of samples from the target distribution (in the middle) clearly aligns with confidence ellipses. + +# %% +titles = ["$\sqrt[5]{\Sigma}$", "$\Sigma$", "$\Sigma^{3}$"] +_, axs = plt.subplots(1, len(samples), sharey=True, figsize=(12, 4)) +for i, ax in enumerate(axs.flatten()): + ax.axis('equal') + ax.grid() + ax.scatter(samples[i, :, 0], samples[i, :, 1], alpha=0.3) + for std in range(1, 4): + conf_ell = Ellipse( + xy=mu, width=D[0] * std, height=D[1] * std, angle=np.rad2deg(r), + edgecolor='black', linestyle='--', facecolor='none' + ) + ax.add_artist(conf_ell) + ax.set_title(titles[i]) + +plt.show() + +# %% [markdown] +# Kernel Stein discrepancy with inverse multiquadric kernel is computed over an array of samples and corresponding gradients. Note that it has quadratic time complexity that would make it challenging to scale to large sequences. + +# %% +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_diagnostic import kernel_stein_discrepancy_imq + +logpdf = lambda params: stats.multivariate_normal.logpdf(params, mu, sigma) +_, grads = vmap(vmap(value_and_grad(logpdf), 0, 0), 1, 1)(samples) + +ksd = vmap(kernel_stein_discrepancy_imq, 0, 0)(samples, grads) +log_ksd = jnp.log10(ksd) + +# %% [markdown] +# As expected, the lowest value of (log-)KSD is obtained in the dataset that is sampled from the ground truth distribution. + +# %% +fig, ax = plt.subplots(1, 1, figsize=(6, 3)) +ax.grid() +ax.plot(disp, log_ksd) +ax.set_ylabel("log KSD") +ax.set_xlabel("$\Sigma$") +plt.show() + +# %% [markdown] +# ### Estimating effective sample size +# +# Effective Sample Size (ESS) is a metric that quantifies autocorrelation in a sequence. Intuitively, ESS is the size of an i.i.d. sample with the same variance as the input sample. Typical usage includes computing the standard error for the MCMC estimator: + + +# %% +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_diagnostic import effective_sample_size + +ess = effective_sample_size(samples[0]) +variance = jnp.var(samples[0], axis=0) +standard_error = jnp.sqrt(variance / ess) +standard_error + +# %% [markdown] +# Note that a sequence of strongly autocorrelated samples leads to a very low ESS: + +# %% +print("ESS for no auto-correlation:", effective_sample_size(rng.normal(size=200))) +print("ESS for strong auto-correlation:", effective_sample_size(jnp.arange(200) + rng.normal(size=200))) diff --git a/fortuna/data/loader/base.py b/fortuna/data/loader/base.py index 646596cf..2f7aed43 100644 --- a/fortuna/data/loader/base.py +++ b/fortuna/data/loader/base.py @@ -67,6 +67,17 @@ def num_unique_labels(self) -> Optional[int]: """ return self._num_unique_labels + def __len__(self) -> int: + """ + The number of batches in the data loader. + + Returns + ------- + int + Number of batches. + """ + return sum(1 for _ in self) + @property @abc.abstractmethod def input_shape(self) -> Shape: @@ -292,6 +303,17 @@ def fun(): return fun() + def __len__(self) -> int: + """ + The number of batches in the inputs loader. + + Returns + ------- + int + Number of batches. + """ + return sum(1 for _ in self) + @classmethod def from_data_loader(cls: Type[T], data_loader: BaseDataLoaderABC) -> T: """ @@ -405,6 +427,17 @@ def size(self) -> int: c += targets.shape[0] return c + def __len__(self) -> int: + """ + The number of batches in the targets loader. + + Returns + ------- + int + Number of batches. + """ + return sum(1 for _ in self) + @classmethod def from_data_loader(cls: Type[T], data_loader: BaseDataLoaderABC) -> T: """ diff --git a/fortuna/model/model_manager/name_to_model_manager.py b/fortuna/model/model_manager/name_to_model_manager.py index 3da74571..2c248e62 100644 --- a/fortuna/model/model_manager/name_to_model_manager.py +++ b/fortuna/model/model_manager/name_to_model_manager.py @@ -10,6 +10,8 @@ from fortuna.prob_model.posterior.normalizing_flow.advi import ADVI_NAME from fortuna.prob_model.posterior.sngp import SNGP_NAME from fortuna.prob_model.posterior.swag import SWAG_NAME +from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld import CYCLICAL_SGLD_NAME class ClassificationModelManagers(enum.Enum): @@ -21,3 +23,5 @@ class ClassificationModelManagers(enum.Enum): vars()[LAPLACE_NAME] = ClassificationModelManager vars()[SWAG_NAME] = ClassificationModelManager vars()[SNGP_NAME] = SNGPClassificationModelManager + vars()[SGHMC_NAME] = ClassificationModelManager + vars()[CYCLICAL_SGLD_NAME] = ClassificationModelManager diff --git a/fortuna/prob_model/__init__.py b/fortuna/prob_model/__init__.py index 8a645009..7a872d70 100644 --- a/fortuna/prob_model/__init__.py +++ b/fortuna/prob_model/__init__.py @@ -25,4 +25,10 @@ from fortuna.prob_model.posterior.swag.swag_approximator import ( SWAGPosteriorApproximator, ) +from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator import ( + SGHMCPosteriorApproximator, +) +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_approximator import ( + CyclicalSGLDPosteriorApproximator, +) from fortuna.prob_model.regression import ProbRegressor diff --git a/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py b/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py index 7c5259f7..92619ae4 100755 --- a/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py +++ b/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py @@ -26,8 +26,8 @@ from fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_approximator import ( DeepEnsemblePosteriorApproximator, ) -from fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_repositories import ( - DeepEnsemblePosteriorStateRepository, +from fortuna.prob_model.posterior.posterior_multi_state_repository import ( + PosteriorMultiStateRepository, ) from fortuna.prob_model.posterior.map.map_posterior import MAPState from fortuna.prob_model.posterior.map.map_trainer import ( @@ -160,17 +160,17 @@ def _fit(i): gradient_accumulation_steps=fit_config.hyperparameters.gradient_accumulation_steps, ) - if isinstance(self.state, DeepEnsemblePosteriorStateRepository): + if isinstance(self.state, PosteriorMultiStateRepository): for i in range(self.posterior_approximator.ensemble_size): self.state.state[i].checkpoint_dir = ( - os.path.join(fit_config.checkpointer.save_checkpoint_dir, str(i)) + pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / str(i) if fit_config.checkpointer.save_checkpoint_dir is not None and fit_config.checkpointer.dump_state else None ) else: - self.state = DeepEnsemblePosteriorStateRepository( - ensemble_size=self.posterior_approximator.ensemble_size, + self.state = PosteriorMultiStateRepository( + size=self.posterior_approximator.ensemble_size, checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir if fit_config.checkpointer.dump_state is True else None, @@ -206,13 +206,13 @@ def sample(self, rng: Optional[PRNGKeyArray] = None, **kwargs) -> JointState: def load_state(self, checkpoint_dir: Path) -> None: try: - self.restore_checkpoint(os.path.join(checkpoint_dir, "0")) + self.restore_checkpoint(pathlib.Path(checkpoint_dir) / "0") except ValueError: raise ValueError( f"No checkpoint was found in `checkpoint_dir={checkpoint_dir}`." ) - self.state = DeepEnsemblePosteriorStateRepository( - ensemble_size=self.posterior_approximator.ensemble_size, + self.state = PosteriorMultiStateRepository( + size=self.posterior_approximator.ensemble_size, checkpoint_dir=checkpoint_dir, ) @@ -267,12 +267,11 @@ def _restore_state_from_somewhere( allowed_states: Optional[Tuple[Type[MAPState], ...]] = None, ) -> MAPState: if fit_config.checkpointer.restore_checkpoint_path is not None: + restore_checkpoint_path = pathlib.Path( + fit_config.checkpointer.restore_checkpoint_path + ) / str(i) state = self.restore_checkpoint( - restore_checkpoint_path=str( - fit_config.checkpointer.restore_checkpoint_path - ) - + "/" - + str(i), + restore_checkpoint_path=restore_checkpoint_path, optimizer=fit_config.optimizer.method, ) elif fit_config.checkpointer.start_from_current_state is not None: diff --git a/fortuna/prob_model/posterior/name_to_posterior_state.py b/fortuna/prob_model/posterior/name_to_posterior_state.py index 6623b957..c22e1b77 100644 --- a/fortuna/prob_model/posterior/name_to_posterior_state.py +++ b/fortuna/prob_model/posterior/name_to_posterior_state.py @@ -6,6 +6,10 @@ from fortuna.prob_model.posterior.normalizing_flow.advi.advi_state import ADVIState from fortuna.prob_model.posterior.state import PosteriorState from fortuna.prob_model.posterior.swag.swag_state import SWAGState +from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_state import SGHMCState +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_state import ( + CyclicalSGLDState, +) class NameToPosteriorState(enum.Enum): @@ -15,3 +19,5 @@ class NameToPosteriorState(enum.Enum): vars()[ADVIState.__name__] = ADVIState vars()[LaplaceState.__name__] = LaplaceState vars()[SWAGState.__name__] = SWAGState + vars()[SGHMCState.__name__] = SGHMCState + vars()[CyclicalSGLDState.__name__] = CyclicalSGLDState diff --git a/fortuna/prob_model/posterior/posterior_approximations.py b/fortuna/prob_model/posterior/posterior_approximations.py index 80fbe56d..69ba92fb 100644 --- a/fortuna/prob_model/posterior/posterior_approximations.py +++ b/fortuna/prob_model/posterior/posterior_approximations.py @@ -16,6 +16,14 @@ from fortuna.prob_model.posterior.sngp.sngp_posterior import SNGPPosterior from fortuna.prob_model.posterior.swag import SWAG_NAME from fortuna.prob_model.posterior.swag.swag_posterior import SWAGPosterior +from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME +from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_posterior import ( + SGHMCPosterior, +) +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld import CYCLICAL_SGLD_NAME +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_posterior import ( + CyclicalSGLDPosterior, +) class PosteriorApproximations(enum.Enum): @@ -27,3 +35,5 @@ class PosteriorApproximations(enum.Enum): vars()[LAPLACE_NAME] = LaplacePosterior vars()[SWAG_NAME] = SWAGPosterior vars()[SNGP_NAME] = SNGPPosterior + vars()[SGHMC_NAME] = SGHMCPosterior + vars()[CYCLICAL_SGLD_NAME] = CyclicalSGLDPosterior diff --git a/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_repositories.py b/fortuna/prob_model/posterior/posterior_multi_state_repository.py similarity index 85% rename from fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_repositories.py rename to fortuna/prob_model/posterior/posterior_multi_state_repository.py index dbbcdbcc..1cca1e86 100644 --- a/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_repositories.py +++ b/fortuna/prob_model/posterior/posterior_multi_state_repository.py @@ -6,9 +6,6 @@ Union, ) -from fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_state import ( - DeepEnsembleState, -) from fortuna.prob_model.posterior.posterior_state_repository import ( PosteriorStateRepository, ) @@ -19,16 +16,16 @@ ) -class DeepEnsemblePosteriorStateRepository: - def __init__(self, ensemble_size: int, checkpoint_dir: Optional[Path] = None): - self.ensemble_size = ensemble_size +class PosteriorMultiStateRepository: + def __init__(self, size: int, checkpoint_dir: Optional[Path] = None): + self.size = size self.state = [ PosteriorStateRepository( checkpoint_dir=os.path.join(checkpoint_dir, str(i)) if checkpoint_dir else None ) - for i in range(ensemble_size) + for i in range(size) ] def get( @@ -50,7 +47,7 @@ def _get(_i): if i is not None: return _get(i) state = [] - for i in range(self.ensemble_size): + for i in range(self.size): state.append(_get(i)) return state @@ -70,7 +67,7 @@ def _put(_i): if i is not None: _put(i) else: - for i in range(self.ensemble_size): + for i in range(self.size): state.append(_put(i)) def pull( @@ -80,7 +77,7 @@ def pull( optimizer: Optional[OptaxOptimizer] = None, prefix: str = "checkpoint_", **kwargs, - ) -> Union[DeepEnsembleState, PosteriorState]: + ) -> PosteriorState: def _pull(_i): return self.state[_i].pull( checkpoint_path=checkpoint_path, @@ -92,7 +89,7 @@ def _pull(_i): if i is not None: return _pull(i) state = [] - for i in range(self.ensemble_size): + for i in range(self.size): state.append(_pull(i)) return state @@ -119,7 +116,7 @@ def _update(_i): if i is not None: _update(i) else: - for i in range(self.ensemble_size): + for i in range(self.size): _update(i) def extract( @@ -138,7 +135,7 @@ def _extract(_i): if i is not None: return _extract(i) dicts = [] - for i in range(self.ensemble_size): + for i in range(self.size): dicts.append(_extract(i)) return dicts diff --git a/fortuna/prob_model/posterior/sgmcmc/__init__.py b/fortuna/prob_model/posterior/sgmcmc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/prob_model/posterior/sgmcmc/base.py b/fortuna/prob_model/posterior/sgmcmc/base.py new file mode 100644 index 00000000..96bf90c7 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/base.py @@ -0,0 +1,33 @@ +from fortuna.prob_model.posterior.base import PosteriorApproximator +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import ( + Preconditioner, + identity_preconditioner, +) + + +class SGMCMCPosteriorApproximator(PosteriorApproximator): + def __init__( + self, + n_samples: int = 10, + n_thinning: int = 1, + preconditioner: Preconditioner = identity_preconditioner(), + ) -> None: + """ + SGMCMC posterior approximator. It is responsible to define how the posterior distribution is approximated. + + Parameters + ---------- + n_samples: int + The desired number of the posterior samples. + n_thinning: int + If `n_thinning` > 1, keep only each `n_thinning` sample during the sampling phase. + preconditioner: Preconditioner + A `Preconditioner` instance that preconditions the approximator with information about the posterior distribution, if available. + + """ + self.n_samples = n_samples + self.n_thinning = n_thinning + self.preconditioner = preconditioner + + def __str__(self) -> str: + raise NotImplementedError diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/__init__.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/__init__.py new file mode 100644 index 00000000..37360c22 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/__init__.py @@ -0,0 +1 @@ +CYCLICAL_SGLD_NAME = "cyclical_sgld" diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_approximator.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_approximator.py new file mode 100644 index 00000000..ef42a32e --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_approximator.py @@ -0,0 +1,51 @@ +from fortuna.prob_model.posterior.sgmcmc.base import ( + SGMCMCPosteriorApproximator, +) +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import ( + Preconditioner, + identity_preconditioner, +) +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld import ( + CYCLICAL_SGLD_NAME, +) + + +class CyclicalSGLDPosteriorApproximator(SGMCMCPosteriorApproximator): + def __init__( + self, + n_samples: int = 10, + n_thinning: int = 1, + cycle_length: int = 1000, + init_step_size: float = 1e-5, + exploration_ratio: float = 0.25, + preconditioner: Preconditioner = identity_preconditioner(), + ) -> None: + """ + Cyclical SGLD posterior approximator. It is responsible to define how the posterior distribution is approximated. + + Parameters + ---------- + n_samples: int + The desired number of the posterior samples. + n_thinning: int + If `n_thinning` > 1, keep only each `n_thinning` sample during the sampling phase. + cycle_length: int + The length of each exploration/sampling cycle, in steps. + init_step_size: float + The initial step size. + exploration_ratio: float + The fraction of steps to allocate to the mode exploration phase. + preconditioner: Preconditioner + A `Preconditioner` instance that preconditions the approximator with information about the posterior distribution, if available. + """ + super().__init__( + n_samples=n_samples, + n_thinning=n_thinning, + preconditioner=preconditioner, + ) + self.cycle_length = cycle_length + self.init_step_size = init_step_size + self.exploration_ratio = exploration_ratio + + def __str__(self) -> str: + return CYCLICAL_SGLD_NAME diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_callback.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_callback.py new file mode 100644 index 00000000..e511559c --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_callback.py @@ -0,0 +1,81 @@ +from typing import Optional +import pathlib + +from fortuna.training.train_state import TrainState +from fortuna.training.callback import Callback +from fortuna.training.train_state_repository import TrainStateRepository +from fortuna.training.trainer import TrainerABC +from fortuna.typing import Path + +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_sampling_callback import ( + SGMCMCSamplingCallback, +) + + +class CyclicalSGLDSamplingCallback(SGMCMCSamplingCallback): + def __init__( + self, + n_epochs: int, + n_training_steps: int, + n_samples: int, + n_thinning: int, + cycle_length: int, + exploration_ratio: float, + trainer: TrainerABC, + state_repository: TrainStateRepository, + keep_top_n_checkpoints: int, + save_checkpoint_dir: Optional[Path] = None, + ): + """ + Cyclical Stochastic Gradient Langevin Dynamics (SGLD) callback that collects samples + in different cycles. See `Zhang R. et al., 2020 `_ + for more details. + + Parameters + ---------- + n_epochs: int + The number of epochs. + n_training_steps: int + The number of steps per epoch. + n_samples: int + The desired number of the posterior samples. + n_thinning: int + Keep only each `n_thinning` sample during the sampling phase. + cycle_length: int + The length of each exploration/sampling cycle, in steps. + init_step_size: float + The initial step size. + exploration_ratio: float + The fraction of steps to allocate to the mode exploration phase. + trainer: TrainerABC + An instance of the trainer class. + state_repository: TrainStateRepository + An instance of the state repository. + keep_top_n_checkpoints: int + Number of past checkpoint files to keep. + save_checkpoint_dir: Optional[Path] + The optional path to save checkpoints. + """ + super().__init__( + trainer=trainer, + state_repository=state_repository, + keep_top_n_checkpoints=keep_top_n_checkpoints, + save_checkpoint_dir=save_checkpoint_dir, + ) + + self._do_sample = ( + lambda current_step, samples_count: samples_count < n_samples + and ((current_step % cycle_length) / cycle_length) >= exploration_ratio + and (current_step % cycle_length) % n_thinning == 0 + ) + + total_samples = sum( + self._do_sample(step, 0) + for step in range(1, n_epochs * n_training_steps + 1) + ) + if total_samples < n_samples: + raise ValueError( + f"The number of desired samples `n_samples` is {n_samples}. However, only " + f"{total_samples} samples will be collected. Consider adjusting the cycle length, " + "number of epochs, exploration ratio, or the thinning parameter." + ) diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_integrator.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_integrator.py new file mode 100644 index 00000000..6374aa0e --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_integrator.py @@ -0,0 +1,107 @@ +import jax + +import optax +from optax import GradientTransformation + +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import ( + Preconditioner, +) +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import ( + cyclical_cosine_schedule_with_const_burnin, +) +from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_integrator import ( + sghmc_integrator, +) +from jax._src.prng import PRNGKeyArray +from typing import NamedTuple + + +class OptaxCyclicalSGLDState(NamedTuple): + """Optax state for the Cyclical SGLD integrator.""" + + sgd_state: NamedTuple + sgld_state: NamedTuple + + +def cyclical_sgld_integrator( + rng_key: PRNGKeyArray, + init_step_size: float, + cycle_length: int, + exploration_ratio: float, + preconditioner: Preconditioner, +) -> GradientTransformation: + """Optax implementation of the Cyclical SGLD integrator. + + Parameters + ---------- + rng_key: PRNGKeyArray + An initial random number generator. + init_step_size: float + The initial step size. + cycle_length: int + The length of each exploration/sampling cycle, in steps. + exploration_ratio: float + The fraction of steps to allocate to the mode exploration phase. + preconditioner: Preconditioner + See :class:`Preconditioner` for reference. + """ + step_schedule = cyclical_cosine_schedule_with_const_burnin( + init_step_size=init_step_size, + burnin_steps=0, + cycle_length=cycle_length, + ) + # SGHMC with no momentum is equivalent to SGLD + sgld = sghmc_integrator( + momentum_decay=0.0, + momentum_resample_steps=None, + rng_key=rng_key, + step_schedule=step_schedule, + preconditioner=preconditioner, + ) + sgd = optax.sgd(learning_rate=1.0) + + def init_fn(params): + return OptaxCyclicalSGLDState( + sgd_state=sgd.init(params), + sgld_state=sgld.init(params), + ) + + def update_fn(gradient, state, *_): + def sgd_step(): + step_size = step_schedule(state.sgld_state.count) + preconditioner_state = preconditioner.update_preconditioner( + gradient, state.sgld_state.preconditioner_state + ) + new_sgld_state = state.sgld_state._replace( + count=state.sgld_state.count + 1, + preconditioner_state=preconditioner_state, + ) + rescaled_gradient = jax.tree_map( + lambda g: -1.0 * step_size * g, + gradient, + ) + updates, new_sgd_state = sgd.update(rescaled_gradient, state.sgd_state) + updates = preconditioner.multiply_by_m_inv(updates, preconditioner_state) + new_state = OptaxCyclicalSGLDState( + sgd_state=new_sgd_state, + sgld_state=new_sgld_state, + ) + return updates, new_state + + def sgld_step(): + updates, new_sgld_state = sgld.update(gradient, state.sgld_state) + new_state = OptaxCyclicalSGLDState( + sgd_state=state.sgd_state, + sgld_state=new_sgld_state, + ) + return updates, new_state + + updates, state = jax.lax.cond( + ((state.sgld_state.count % cycle_length) / cycle_length) + >= exploration_ratio, + sgld_step, + sgd_step, + ) + return updates, state + + return GradientTransformation(init_fn, update_fn) diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py new file mode 100644 index 00000000..2bda57ae --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py @@ -0,0 +1,226 @@ +import logging +from typing import Optional +import pathlib + +from flax.core import FrozenDict +from fortuna.utils.freeze import get_trainable_paths +from fortuna.utils.nested_dicts import nested_set, nested_get +from fortuna.data.loader import DataLoader +from fortuna.prob_model.fit_config.base import FitConfig +from fortuna.prob_model.joint.base import Joint +from fortuna.prob_model.posterior.map.map_state import MAPState +from fortuna.prob_model.posterior.map.map_trainer import ( + MAPTrainer, + JittedMAPTrainer, + MultiDeviceMAPTrainer, +) +from fortuna.prob_model.posterior.run_preliminary_map import ( + run_preliminary_map, +) +from fortuna.prob_model.posterior.posterior_multi_state_repository import ( + PosteriorMultiStateRepository, +) +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior import ( + SGMCMCPosterior, +) +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld import ( + CYCLICAL_SGLD_NAME, +) +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_approximator import ( + CyclicalSGLDPosteriorApproximator, +) +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_integrator import ( + cyclical_sgld_integrator, +) +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_callback import ( + CyclicalSGLDSamplingCallback, +) +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_state import ( + CyclicalSGLDState, +) +from fortuna.typing import Status +from fortuna.utils.device import select_trainer_given_devices + +logger = logging.getLogger(__name__) + + +class CyclicalSGLDPosterior(SGMCMCPosterior): + def __init__( + self, + joint: Joint, + posterior_approximator: CyclicalSGLDPosteriorApproximator, + ): + """ + Cyclical Stochastic Gradient Langevin Dynamics (SGLD) approximate posterior class. + + Parameters + ---------- + joint: Joint + A Joint distribution object. + posterior_approximator: CyclicalSGLDPosteriorApproximator + A cyclical SGLD posterior approximator. + """ + super().__init__(joint=joint, posterior_approximator=posterior_approximator) + + def __str__(self): + return CYCLICAL_SGLD_NAME + + def fit( + self, + train_data_loader: DataLoader, + val_data_loader: Optional[DataLoader] = None, + fit_config: FitConfig = FitConfig(), + map_fit_config: Optional[FitConfig] = None, + **kwargs, + ) -> Status: + super()._checks_on_fit_start(fit_config, map_fit_config) + + status = {} + + map_state = None + if map_fit_config is not None and fit_config.optimizer.freeze_fun is None: + logging.warning( + "It appears that you are trying to configure `map_fit_config`. " + "However, a preliminary run with MAP is supported only if " + "`fit_config.optimizer.freeze_fun` is given. " + "Since the latter was not given, `map_fit_config` will be ignored." + ) + elif not super()._is_state_available_somewhere( + fit_config + ) and super()._should_run_preliminary_map(fit_config, map_fit_config): + map_state, status["map"] = run_preliminary_map( + joint=self.joint, + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + map_fit_config=map_fit_config, + rng=self.rng, + **kwargs, + ) + + if fit_config.optimizer.method is not None: + logging.info(f"`FitOptimizer` method in CyclicalSGLD is ignored.") + + fit_config.optimizer.method = cyclical_sgld_integrator( + rng_key=self.rng.get(), + init_step_size=self.posterior_approximator.init_step_size, + cycle_length=self.posterior_approximator.cycle_length, + exploration_ratio=self.posterior_approximator.exploration_ratio, + preconditioner=self.posterior_approximator.preconditioner, + ) + + trainer_cls = select_trainer_given_devices( + devices=fit_config.processor.devices, + base_trainer_cls=MAPTrainer, + jitted_trainer_cls=JittedMAPTrainer, + multi_device_trainer_cls=MultiDeviceMAPTrainer, + disable_jit=fit_config.processor.disable_jit, + ) + + save_checkpoint_dir = ( + pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / "c" + if fit_config.checkpointer.save_checkpoint_dir + else None + ) + trainer = trainer_cls( + predict_fn=self.joint.likelihood.prob_output_layer.predict, + save_checkpoint_dir=save_checkpoint_dir, + save_every_n_steps=fit_config.checkpointer.save_every_n_steps, + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + disable_training_metrics_computation=fit_config.monitor.disable_training_metrics_computation, + eval_every_n_epochs=fit_config.monitor.eval_every_n_epochs, + early_stopping_monitor=fit_config.monitor.early_stopping_monitor, + early_stopping_min_delta=fit_config.monitor.early_stopping_min_delta, + early_stopping_patience=fit_config.monitor.early_stopping_patience, + ) + + if super()._is_state_available_somewhere(fit_config): + state = self._restore_state_from_somewhere(fit_config=fit_config) + else: + state = self._init_map_state(map_state, train_data_loader, fit_config) + + state = super()._freeze_optimizer_in_state(state, fit_config) + + self.state = PosteriorMultiStateRepository( + size=self.posterior_approximator.n_samples, + checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir + if fit_config.checkpointer.dump_state is True + else None, + ) + + cyclical_sampling_callback = CyclicalSGLDSamplingCallback( + n_epochs=fit_config.optimizer.n_epochs, + n_training_steps=len(train_data_loader), + n_samples=self.posterior_approximator.n_samples, + n_thinning=self.posterior_approximator.n_thinning, + cycle_length=self.posterior_approximator.cycle_length, + exploration_ratio=self.posterior_approximator.exploration_ratio, + trainer=trainer, + state_repository=self.state, + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir, + ) + + state = CyclicalSGLDState.convert_from_map_state( + map_state=state, + optimizer=fit_config.optimizer.method, + ) + + state = super()._freeze_optimizer_in_state(state, fit_config) + + logging.info(f"Run CyclicalSGLD.") + state, status = trainer.train( + rng=self.rng.get(), + state=state, + loss_fun=self.joint._batched_log_joint_prob, + training_dataloader=train_data_loader, + training_dataset_size=train_data_loader.size, + n_epochs=fit_config.optimizer.n_epochs, + metrics=fit_config.monitor.metrics, + validation_dataloader=val_data_loader, + validation_dataset_size=val_data_loader.size + if val_data_loader is not None + else None, + verbose=fit_config.monitor.verbose, + callbacks=[cyclical_sampling_callback], + ) + logging.info("Fit completed.") + + return status + + def _init_map_state( + self, + state: Optional[MAPState], + data_loader: DataLoader, + fit_config: FitConfig, + ) -> MAPState: + if state is None or fit_config.optimizer.freeze_fun is None: + state = super()._init_joint_state(data_loader) + + return MAPState.init( + params=state.params, + mutable=state.mutable, + optimizer=fit_config.optimizer.method, + calib_params=state.calib_params, + calib_mutable=state.calib_mutable, + ) + else: + random_state = super()._init_joint_state(data_loader) + trainable_paths = get_trainable_paths( + state.params, fit_config.optimizer.freeze_fun + ) + state = state.replace( + params=FrozenDict( + nested_set( + d=state.params.unfreeze(), + key_paths=trainable_paths, + objs=tuple( + [ + nested_get(d=random_state.params, keys=path) + for path in trainable_paths + ] + ), + ) + ) + ) + + return state diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_state.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_state.py new file mode 100644 index 00000000..0bd42d99 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_state.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import jax.numpy as jnp + +from fortuna.prob_model.posterior.state import PosteriorState +from fortuna.utils.strings import convert_string_to_jnp_array +from fortuna.prob_model.posterior.map.map_state import MAPState +from fortuna.typing import OptaxOptimizer + + +class CyclicalSGLDState(PosteriorState): + """ + Attributes + ---------- + encoded_name: jnp.ndarray + CyclicalSGLDState state name encoded as an array. + """ + + encoded_name: jnp.ndarray = convert_string_to_jnp_array("CyclicalSGLDState") + + @classmethod + def convert_from_map_state( + cls, map_state: MAPState, optimizer: OptaxOptimizer + ) -> CyclicalSGLDState: + """ + Convert a MAP state into an CyclicalSGLDState state. + + Parameters + ---------- + map_state: MAPState + A MAP posterior state. + optimizer: OptaxOptimizer + An Optax optimizer. + + Returns + ------- + SGHMCState + An SGHMC state. + """ + return CyclicalSGLDState.init( + params=map_state.params, + mutable=map_state.mutable, + optimizer=optimizer, + calib_params=map_state.calib_params, + calib_mutable=map_state.calib_mutable, + ) diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/__init__.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/__init__.py new file mode 100644 index 00000000..0a536fca --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/__init__.py @@ -0,0 +1 @@ +SGHMC_NAME = "sghmc" diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_approximator.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_approximator.py new file mode 100644 index 00000000..d221d0e2 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_approximator.py @@ -0,0 +1,60 @@ +from typing import Union + +from fortuna.prob_model.posterior.sgmcmc.base import ( + SGMCMCPosteriorApproximator, +) +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import ( + Preconditioner, + identity_preconditioner, +) +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import ( + StepSchedule, + constant_schedule, +) +from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME + + +class SGHMCPosteriorApproximator(SGMCMCPosteriorApproximator): + def __init__( + self, + n_samples: int = 10, + n_thinning: int = 1, + burnin_length: int = 1000, + momentum_decay: float = 0.01, + step_schedule: Union[StepSchedule, float] = 1e-5, + preconditioner: Preconditioner = identity_preconditioner(), + ) -> None: + """ + SGHMC posterior approximator. It is responsible to define how the posterior distribution is approximated. + + Parameters + ---------- + n_samples: int + The desired number of the posterior samples. + n_thinning: int + If `n_thinning` > 1, keep only each `n_thinning` sample during the sampling phase. + burnin_length: int + Length of the initial burn-in phase, in steps. + momentum_decay: float + The "friction" term that counters the noise of stochastic gradient estimates. Setting this argument to zero recovers the overamped Langevin dynamics. + step_schedule: Union[StepSchedule, float] + Either a constant `float` step size or a schedule function. + preconditioner: Preconditioner + A `Preconditioner` instance that preconditions the approximator with information about the posterior distribution, if available. + + """ + super().__init__( + n_samples=n_samples, + n_thinning=n_thinning, + preconditioner=preconditioner, + ) + if isinstance(step_schedule, float): + step_schedule = constant_schedule(step_schedule) + elif not callable(step_schedule): + raise ValueError(f"`step_schedule` must be a a callable function.") + self.burnin_length = burnin_length + self.momentum_decay = momentum_decay + self.step_schedule = step_schedule + + def __str__(self) -> str: + return SGHMC_NAME diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_callback.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_callback.py new file mode 100644 index 00000000..88c1d645 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_callback.py @@ -0,0 +1,74 @@ +from typing import Optional +import pathlib + +from fortuna.training.train_state import TrainState +from fortuna.training.callback import Callback +from fortuna.training.train_state_repository import TrainStateRepository +from fortuna.training.trainer import TrainerABC +from fortuna.typing import Path +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_sampling_callback import ( + SGMCMCSamplingCallback, +) + + +class SGHMCSamplingCallback(SGMCMCSamplingCallback): + def __init__( + self, + n_epochs: int, + n_training_steps: int, + n_samples: int, + n_thinning: int, + burnin_length: int, + trainer: TrainerABC, + state_repository: TrainStateRepository, + keep_top_n_checkpoints: int, + save_checkpoint_dir: Optional[Path] = None, + ): + """ + Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) callback that collects samples + after the initial burn-in phase. + + Parameters + ---------- + n_epochs: int + The number of epochs. + n_training_steps: int + The number of steps per epoch. + n_samples: int + The desired number of the posterior samples. + n_thinning: int + Keep only each `n_thinning` sample during the sampling phase. + burnin_length: int + Length of the initial burn-in phase, in steps. + trainer: TrainerABC + An instance of the trainer class. + state_repository: TrainStateRepository + An instance of the state repository. + keep_top_n_checkpoints: int + Number of past checkpoint files to keep. + save_checkpoint_dir: Optional[Path] + The optional path to save checkpoints. + """ + super().__init__( + trainer=trainer, + state_repository=state_repository, + keep_top_n_checkpoints=keep_top_n_checkpoints, + save_checkpoint_dir=save_checkpoint_dir, + ) + + self._do_sample = ( + lambda current_step, samples_count: samples_count < n_samples + and current_step > burnin_length + and (current_step - burnin_length) % n_thinning == 0 + ) + + total_samples = sum( + self._do_sample(step, 0) + for step in range(1, n_epochs * n_training_steps + 1) + ) + if total_samples < n_samples: + raise ValueError( + f"The number of desired samples `n_samples` is {n_samples}. However, only " + f"{total_samples} samples will be collected. Consider adjusting the burnin " + "length, number of epochs, or the thinning parameter." + ) diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_integrator.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_integrator.py new file mode 100644 index 00000000..a534405a --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_integrator.py @@ -0,0 +1,93 @@ +import jax +import jax.numpy as jnp + +from fortuna.typing import Array +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import ( + PreconditionerState, + Preconditioner, +) +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import ( + StepSchedule, +) +from fortuna.utils.random import generate_random_normal_like_tree +from jax._src.prng import PRNGKeyArray +from optax._src.base import PyTree +from optax import GradientTransformation +from typing import NamedTuple, Optional + + +class OptaxSGHMCState(NamedTuple): + """Optax state for the SGHMC integrator.""" + + count: Array + rng_key: PRNGKeyArray + momentum: PyTree + preconditioner_state: PreconditionerState + + +def sghmc_integrator( + momentum_decay: float, + momentum_resample_steps: Optional[int], + rng_key: PRNGKeyArray, + step_schedule: StepSchedule, + preconditioner: Preconditioner, +) -> GradientTransformation: + """Optax implementation of the SGHMC integrator. + + Parameters + ---------- + momentum_decay: float + The momentum decay parameter. + rng_key: PRNGKeyArray + An initial random number generator. + step_schedule: StepSchedule + A function that takes training step as input and returns the step size. + preconditioner: Preconditioner + See :class:`Preconditioner` for reference. + """ + # Implementation was partially adapted from https://github.com/google-research/google-research/blob/master/bnn_hmc/core/sgmcmc.py#L56 + + def init_fn(params): + return OptaxSGHMCState( + count=jnp.zeros([], jnp.int32), + rng_key=rng_key, + momentum=jax.tree_util.tree_map(jnp.zeros_like, params), + preconditioner_state=preconditioner.init(params), + ) + + def update_fn(gradient, state, *_): + step_size = step_schedule(state.count) + + preconditioner_state = preconditioner.update_preconditioner( + gradient, state.preconditioner_state + ) + + key, new_key = jax.random.split(state.rng_key) + noise = generate_random_normal_like_tree(key, gradient) + noise = preconditioner.multiply_by_m_sqrt(noise, preconditioner_state) + + momentum = jax.lax.cond( + momentum_resample_steps is not None + and state.count % momentum_resample_steps == 0, + lambda: jax.tree_util.tree_map(jnp.zeros_like, gradient), + lambda: state.momentum, + ) + + momentum = jax.tree_map( + lambda m, g, n: momentum_decay * m + + g * jnp.sqrt(step_size) + + n * jnp.sqrt(2 * (1 - momentum_decay)), + momentum, + gradient, + noise, + ) + updates = preconditioner.multiply_by_m_inv(momentum, preconditioner_state) + updates = jax.tree_map(lambda m: m * jnp.sqrt(step_size), updates) + return updates, OptaxSGHMCState( + count=state.count + 1, + rng_key=new_key, + momentum=momentum, + preconditioner_state=preconditioner_state, + ) + + return GradientTransformation(init_fn, update_fn) diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py new file mode 100644 index 00000000..3fcf82c1 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py @@ -0,0 +1,221 @@ +import logging +from typing import Optional +from itertools import cycle +import pathlib +from flax.core import FrozenDict + +from fortuna.utils.freeze import get_trainable_paths +from fortuna.utils.nested_dicts import nested_set, nested_get +from fortuna.data.loader import DataLoader +from fortuna.prob_model.fit_config.base import FitConfig +from fortuna.prob_model.joint.base import Joint +from fortuna.prob_model.posterior.map.map_trainer import ( + MAPTrainer, + JittedMAPTrainer, + MultiDeviceMAPTrainer, +) +from fortuna.prob_model.posterior.run_preliminary_map import ( + run_preliminary_map, +) +from fortuna.prob_model.posterior.map.map_posterior import MAPPosterior +from fortuna.prob_model.posterior.map.map_state import MAPState +from fortuna.prob_model.posterior.posterior_multi_state_repository import ( + PosteriorMultiStateRepository, +) +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior import ( + SGMCMCPosterior, +) +from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME +from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator import ( + SGHMCPosteriorApproximator, +) +from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_callback import ( + SGHMCSamplingCallback, +) +from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_integrator import ( + sghmc_integrator, +) +from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_state import SGHMCState +from fortuna.typing import Status +from fortuna.utils.device import select_trainer_given_devices + +logger = logging.getLogger(__name__) + + +class SGHMCPosterior(SGMCMCPosterior): + def __init__( + self, + joint: Joint, + posterior_approximator: SGHMCPosteriorApproximator, + ): + """ + Stochastic Gradient Hamiltonian Monte Carlo approximate posterior class. + + Parameters + ---------- + joint: Joint + A Joint distribution object. + posterior_approximator: SGHMCPosteriorApproximator + A SGHMC posterior approximator. + """ + super().__init__(joint=joint, posterior_approximator=posterior_approximator) + + def __str__(self): + return SGHMC_NAME + + def fit( + self, + train_data_loader: DataLoader, + val_data_loader: Optional[DataLoader] = None, + fit_config: FitConfig = FitConfig(), + map_fit_config: Optional[FitConfig] = None, + **kwargs, + ) -> Status: + super()._checks_on_fit_start(fit_config, map_fit_config) + + status = {} + + map_state = None + if map_fit_config is not None and fit_config.optimizer.freeze_fun is None: + logging.warning( + "It appears that you are trying to configure `map_fit_config`. " + "However, a preliminary run with MAP is supported only if " + "`fit_config.optimizer.freeze_fun` is given. " + "Since the latter was not given, `map_fit_config` will be ignored." + ) + elif not super()._is_state_available_somewhere( + fit_config + ) and super()._should_run_preliminary_map(fit_config, map_fit_config): + map_state, status["map"] = run_preliminary_map( + joint=self.joint, + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + map_fit_config=map_fit_config, + rng=self.rng, + **kwargs, + ) + + if fit_config.optimizer.method is not None: + logging.info(f"`FitOptimizer` method in SGHMC is ignored.") + + fit_config.optimizer.method = sghmc_integrator( + momentum_decay=self.posterior_approximator.momentum_decay, + momentum_resample_steps=None, + rng_key=self.rng.get(), + step_schedule=self.posterior_approximator.step_schedule, + preconditioner=self.posterior_approximator.preconditioner, + ) + + trainer_cls = select_trainer_given_devices( + devices=fit_config.processor.devices, + base_trainer_cls=MAPTrainer, + jitted_trainer_cls=JittedMAPTrainer, + multi_device_trainer_cls=MultiDeviceMAPTrainer, + disable_jit=fit_config.processor.disable_jit, + ) + + save_checkpoint_dir = ( + pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / "c" + if fit_config.checkpointer.save_checkpoint_dir + else None + ) + trainer = trainer_cls( + predict_fn=self.joint.likelihood.prob_output_layer.predict, + save_checkpoint_dir=save_checkpoint_dir, + save_every_n_steps=fit_config.checkpointer.save_every_n_steps, + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + disable_training_metrics_computation=fit_config.monitor.disable_training_metrics_computation, + eval_every_n_epochs=fit_config.monitor.eval_every_n_epochs, + early_stopping_monitor=fit_config.monitor.early_stopping_monitor, + early_stopping_min_delta=fit_config.monitor.early_stopping_min_delta, + early_stopping_patience=fit_config.monitor.early_stopping_patience, + ) + + if super()._is_state_available_somewhere(fit_config): + state = self._restore_state_from_somewhere(fit_config=fit_config) + else: + state = self._init_map_state(map_state, train_data_loader, fit_config) + + state = SGHMCState.convert_from_map_state( + map_state=state, + optimizer=fit_config.optimizer.method, + ) + + state = super()._freeze_optimizer_in_state(state, fit_config) + + self.state = PosteriorMultiStateRepository( + size=self.posterior_approximator.n_samples, + checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir + if fit_config.checkpointer.dump_state is True + else None, + ) + + sghmc_sampling_callback = SGHMCSamplingCallback( + n_epochs=fit_config.optimizer.n_epochs, + n_training_steps=len(train_data_loader), + n_samples=self.posterior_approximator.n_samples, + n_thinning=self.posterior_approximator.n_thinning, + burnin_length=self.posterior_approximator.burnin_length, + trainer=trainer, + state_repository=self.state, + keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints, + save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir, + ) + + logging.info(f"Run SGHMC.") + state, status["sghmc"] = trainer.train( + rng=self.rng.get(), + state=state, + loss_fun=self.joint._batched_log_joint_prob, + training_dataloader=train_data_loader, + training_dataset_size=train_data_loader.size, + n_epochs=fit_config.optimizer.n_epochs, + metrics=fit_config.monitor.metrics, + validation_dataloader=val_data_loader, + validation_dataset_size=val_data_loader.size + if val_data_loader is not None + else None, + verbose=fit_config.monitor.verbose, + callbacks=[sghmc_sampling_callback], + ) + logging.info("Fit completed.") + + return status + + def _init_map_state( + self, + state: Optional[MAPState], + data_loader: DataLoader, + fit_config: FitConfig, + ) -> MAPState: + if state is None or fit_config.optimizer.freeze_fun is None: + state = super()._init_joint_state(data_loader) + + return MAPState.init( + params=state.params, + mutable=state.mutable, + optimizer=fit_config.optimizer.method, + calib_params=state.calib_params, + calib_mutable=state.calib_mutable, + ) + else: + random_state = super()._init_joint_state(data_loader) + trainable_paths = get_trainable_paths( + state.params, fit_config.optimizer.freeze_fun + ) + state = state.replace( + params=FrozenDict( + nested_set( + d=state.params.unfreeze(), + key_paths=trainable_paths, + objs=tuple( + [ + nested_get(d=random_state.params, keys=path) + for path in trainable_paths + ] + ), + ) + ) + ) + + return state diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py new file mode 100644 index 00000000..4db93b5a --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import jax.numpy as jnp + +from fortuna.prob_model.posterior.state import PosteriorState +from fortuna.utils.strings import convert_string_to_jnp_array +from fortuna.prob_model.posterior.map.map_state import MAPState +from fortuna.typing import OptaxOptimizer + + +class SGHMCState(PosteriorState): + """ + Attributes + ---------- + encoded_name: jnp.ndarray + SGHMC state name encoded as an array. + """ + + encoded_name: jnp.ndarray = convert_string_to_jnp_array("SGHMCState") + + @classmethod + def convert_from_map_state( + cls, map_state: MAPState, optimizer: OptaxOptimizer + ) -> SGHMCState: + """ + Convert a MAP state into an SGHMC state. + + Parameters + ---------- + map_state: MAPState + A MAP posterior state. + optimizer: OptaxOptimizer + An Optax optimizer. + + Returns + ------- + SGHMCState + An SGHMC state. + """ + return SGHMCState.init( + params=map_state.params, + mutable=map_state.mutable, + optimizer=optimizer, + calib_params=map_state.calib_params, + calib_mutable=map_state.calib_mutable, + ) diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_diagnostic.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_diagnostic.py new file mode 100644 index 00000000..77f06e35 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_diagnostic.py @@ -0,0 +1,133 @@ +from typing import Optional, List +from optax._src.base import PyTree + +import jax.numpy as jnp +from jax import jit, lax, vmap +from jax.flatten_util import ravel_pytree + + +@jit +def kernel_stein_discrepancy_imq( + samples: List[PyTree], + grads: List[PyTree], + c: float = 1.0, + beta: float = -0.5, +) -> float: + """Kernel Stein Discrepancy with the Inverse Multiquadric (IMQ) kernel. + + See `Gorham J. and Mackey L., 2017 `_ for more details. + + Parameters + ---------- + samples: List[PyTree] + The list of `PyTree`, each representing an MCMC sample. + grads: List[PyTree] + The list of the corresponding density gradients. + c: float + :math:`c > 0` kernel bias hyperparameter. + beta: float + :math:`beta < 0` kernel exponent hyperparameter. + + Returns + ------- + ksd_img: float + The kernel Stein discrepancy value. + """ + if not c > 0: + raise ValueError("`c` should be > 0.") + if not beta < 0: + raise ValueError("`beta` should be < 0.") + + samples = ravel_pytree(samples)[0].reshape(len(samples), -1) + grads = ravel_pytree(grads)[0].reshape(len(grads), -1) + + def _k_0(param1, param2, grad1, grad2, c, beta): + dim = param1.shape[0] + diff = param1 - param2 + base = c**2 + jnp.dot(diff, diff) + kern = jnp.dot(grad1, grad2) * base**beta + kern += -2 * beta * jnp.dot(grad1, diff) * base ** (beta - 1) + kern += 2 * beta * jnp.dot(grad2, diff) * base ** (beta - 1) + kern += -2 * dim * beta * (base ** (beta - 1)) + kern += -4 * beta * (beta - 1) * base ** (beta - 2) * jnp.sum(jnp.square(diff)) + return kern + + _batched_k_0 = vmap(_k_0, in_axes=(None, 0, None, 0, None, None)) + + def _ksd(accum, x): + sample1, grad1 = x + accum += jnp.sum(_batched_k_0(sample1, samples, grad1, grads, c, beta)) + return accum, None + + ksd_sum, _ = lax.scan(_ksd, 0.0, (samples, grads)) + return jnp.sqrt(ksd_sum) / samples.shape[0] + + +def effective_sample_size( + samples: List[PyTree], filter_threshold: Optional[float] = 0.0 +) -> PyTree: + """Estimate the effective sample size of a sequence. + + For a sequence of length :math:`N`, the effective sample size is defined as + + :math:`ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]` + + where :math:`R_k` is the auto-correlation sequence, + :math:`R_k := Cov{X_1, X_{1+k}} / Var{X_1}` + + Parameters + ---------- + samples: List[PyTree] + The list of `PyTree`, each representing an MCMC sample. + filter_threshold: Optional[float] + The cut-off value to truncate the sequence at the first index where + the estimated auto-correlation is less than the threshold. + + Returns + ------- + ESS: PyTree + Parameter-wise estimates of the effective sample size. + + """ + unravel_fn = ravel_pytree(samples[0])[1] + samples = ravel_pytree(samples)[0].reshape(len(samples), -1) + + def _autocorr(x, axis=-1, center=True): + """Compute auto-correlation along one axis.""" + + dtype = x.dtype + shift = (-1 - axis) if axis < 0 else (len(x.shape) - 1 - axis) + x = jnp.transpose(x, jnp.roll(jnp.arange(len(x.shape)), shift)) + if center: + x -= x.mean(axis=-1, keepdims=True) + + # Zero pad to the next power of 2 greater than 2 * x_len + x_len = x.shape[-1] + pad_len = int(2.0 ** jnp.ceil(jnp.log2(x_len * 2)) - x_len) + x = jnp.pad(x, (0, pad_len))[:-pad_len] + + # Autocorrelation is IFFT of power-spectral density + fft = jnp.fft.fft(x.astype(jnp.complex64)) + prod = jnp.fft.ifft(fft * jnp.conj(fft)) + prod = jnp.real(prod[..., :x_len]).astype(dtype) + + # Divide to obtain an unbiased estimate of the expectation + denominator = x_len - jnp.arange(0.0, x_len) + res = prod / denominator + return jnp.transpose(res, jnp.roll(jnp.arange(len(res.shape)), -shift)) + + auto_cov = _autocorr(samples, axis=0) + auto_corr = auto_cov / auto_cov[:1] + + n = len(samples) + nk_factor = (n - jnp.arange(0.0, n)) / n + weighted_auto_corr = nk_factor[..., None] * auto_corr + + if filter_threshold is not None: + mask = (auto_corr < filter_threshold).astype(auto_corr.dtype) + mask = jnp.cumsum(mask, axis=0) + mask = jnp.maximum(1.0 - mask, 0.0) + weighted_auto_corr *= mask + + ess = n / (-1 + 2 * weighted_auto_corr.sum(axis=0)) + return unravel_fn(ess) diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py new file mode 100644 index 00000000..bdd7e530 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py @@ -0,0 +1,96 @@ +from typing import Optional, Tuple, Type +import pathlib + +from jax._src.prng import PRNGKeyArray +from jax import pure_callback, random + +from fortuna.prob_model.posterior.base import Posterior +from fortuna.prob_model.posterior.state import PosteriorState +from fortuna.prob_model.fit_config.base import FitConfig +from fortuna.prob_model.joint.state import JointState +from fortuna.prob_model.posterior.map.map_state import MAPState +from fortuna.prob_model.posterior.posterior_multi_state_repository import ( + PosteriorMultiStateRepository, +) +from fortuna.typing import Path + + +class SGMCMCPosterior(Posterior): + """Base SGMCMC posterior approximators class.""" + + def sample( + self, + rng: Optional[PRNGKeyArray] = None, + **kwargs, + ) -> JointState: + """ + Sample from the posterior distribution. + + Parameters + ---------- + rng : Optional[PRNGKeyArray] + A random number generator. If not passed, this will be taken from the attributes of this class. + + Returns + ------- + JointState + A sample from the posterior distribution. + """ + if rng is None: + rng = self.rng.get() + state = pure_callback( + lambda j: self.state.get(i=j), + self.state.get(i=0), + random.choice(rng, self.posterior_approximator.n_samples), + ) + return JointState( + params=state.params, + mutable=state.mutable, + calib_params=state.calib_params, + calib_mutable=state.calib_mutable, + ) + + def load_state(self, checkpoint_dir: Path) -> None: + try: + self.restore_checkpoint(pathlib.Path(checkpoint_dir) / "0") + except ValueError: + raise ValueError( + f"No checkpoint was found in `checkpoint_dir={checkpoint_dir}`." + ) + self.state = PosteriorMultiStateRepository( + size=self.posterior_approximator.n_samples, + checkpoint_dir=checkpoint_dir, + ) + + def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None: + for i in range(self.posterior_approximator.n_samples): + self.state.put(state=self.state.get(i), i=i, keep=keep_top_n_checkpoints) + + def _restore_state_from_somewhere( + self, + fit_config: FitConfig, + allowed_states: Optional[Tuple[Type[MAPState], ...]] = None, + ) -> MAPState: + if fit_config.checkpointer.restore_checkpoint_path is not None: + restore_checkpoint_path = ( + pathlib.Path(fit_config.checkpointer.restore_checkpoint_path) / "c" + ) + state = self.restore_checkpoint( + restore_checkpoint_path=restore_checkpoint_path, + optimizer=fit_config.optimizer.method, + ) + elif fit_config.checkpointer.start_from_current_state is not None: + state = self.state.get( + i=self.state.size - 1, + optimizer=fit_config.optimizer.method, + ) + + if allowed_states is not None and not isinstance(state, allowed_states): + raise ValueError( + f"The type of the restored checkpoint must be within {allowed_states}. " + f"However, {fit_config.checkpointer.restore_checkpoint_path} pointed to a state " + f"with type {type(state)}." + ) + + self._check_state(state) + return state diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_preconditioner.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_preconditioner.py new file mode 100644 index 00000000..a937b5d7 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_preconditioner.py @@ -0,0 +1,135 @@ +import jax +import jax.numpy as jnp + +from optax._src.base import PyTree +from optax import Params +from typing import NamedTuple, Callable + + +PreconditionerState = NamedTuple + + +class Preconditioner(NamedTuple): + """A sampler preconditioner class. + + Attributes + ---------- + init: Callable + The state initialization function. + update_preconditioner: Callable + The state update function that takes gradients as an input. + multiply_by_m_sqrt: Callable + The function that multiples its input by the square root of mass matrix :math:`\sqrt{M}`. + multiply_by_m_inv: Callable + The function that multiples its input by the mass matrix inverse :math:`M^{-1}`. + multiply_by_m_sqrt_inv: Callable + The function that multiples its input by the square root of mass matrix inverse. + + """ + + init: Callable[[Params], PreconditionerState] + update_preconditioner: Callable[[PyTree, PreconditionerState], PreconditionerState] + multiply_by_m_sqrt: Callable[[PyTree, PreconditionerState], PyTree] + multiply_by_m_inv: Callable[[PyTree, PreconditionerState], PyTree] + multiply_by_m_sqrt_inv: Callable[[PyTree, PreconditionerState], PyTree] + + +class RMSPropPreconditionerState(PreconditionerState): + grad_moment_estimates: Params + + +def rmsprop_preconditioner(running_average_factor: float = 0.99, eps: float = 1.0e-7): + """Create an instance of the adaptive RMSProp preconditioner. + + Parameters + ---------- + running_average_factor: float + The decay factor for the squared gradients moving average. + eps: float + :math:`\epsilon` constant for numerical stability. + + Returns + ------- + preconditioner: Preconditioner + An instance of RMSProp preconditioner. + """ + + def init_fn(params): + return RMSPropPreconditionerState( + grad_moment_estimates=jax.tree_util.tree_map(jnp.zeros_like, params) + ) + + def update_preconditioner_fn(gradient, preconditioner_state): + grad_moment_estimates = jax.tree_util.tree_map( + lambda e, g: e * running_average_factor + + g**2 * (1 - running_average_factor), + preconditioner_state.grad_moment_estimates, + gradient, + ) + return RMSPropPreconditionerState(grad_moment_estimates=grad_moment_estimates) + + def multiply_by_m_inv_fn(vec, preconditioner_state): + return jax.tree_util.tree_map( + lambda e, v: v / (eps + jnp.sqrt(e)), + preconditioner_state.grad_moment_estimates, + vec, + ) + + def multiply_by_m_sqrt_fn(vec, preconditioner_state): + return jax.tree_util.tree_map( + lambda e, v: v * jnp.sqrt(eps + jnp.sqrt(e)), + preconditioner_state.grad_moment_estimates, + vec, + ) + + def multiply_by_m_sqrt_inv_fn(vec, preconditioner_state): + return jax.tree_util.tree_map( + lambda e, v: v / jnp.sqrt(eps + jnp.sqrt(e)), + preconditioner_state.grad_moment_estimates, + vec, + ) + + return Preconditioner( + init=init_fn, + update_preconditioner=update_preconditioner_fn, + multiply_by_m_inv=multiply_by_m_inv_fn, + multiply_by_m_sqrt=multiply_by_m_sqrt_fn, + multiply_by_m_sqrt_inv=multiply_by_m_sqrt_inv_fn, + ) + + +class IdentityPreconditionerState(PreconditionerState): + pass + + +def identity_preconditioner(): + """Create an instance of no-op identity preconditioner. + + Returns + ------- + preconditioner: Preconditioner + An instance of identity preconditioner. + """ + + def init_fn(_): + return IdentityPreconditionerState() + + def update_preconditioner_fn(*args, **kwargs): + return IdentityPreconditionerState() + + def multiply_by_m_inv_fn(vec, _): + return vec + + def multiply_by_m_sqrt_fn(vec, _): + return vec + + def multiply_by_m_sqrt_inv_fn(vec, _): + return vec + + return Preconditioner( + init=init_fn, + update_preconditioner=update_preconditioner_fn, + multiply_by_m_inv=multiply_by_m_inv_fn, + multiply_by_m_sqrt=multiply_by_m_sqrt_fn, + multiply_by_m_sqrt_inv=multiply_by_m_sqrt_inv_fn, + ) diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_sampling_callback.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_sampling_callback.py new file mode 100644 index 00000000..36abf636 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_sampling_callback.py @@ -0,0 +1,61 @@ +from typing import Optional +import pathlib + +from fortuna.training.train_state import TrainState +from fortuna.training.callback import Callback +from fortuna.training.train_state_repository import TrainStateRepository +from fortuna.training.trainer import TrainerABC +from fortuna.typing import Path + + +class SGMCMCSamplingCallback(Callback): + def __init__( + self, + trainer: TrainerABC, + state_repository: TrainStateRepository, + keep_top_n_checkpoints: int, + save_checkpoint_dir: Optional[Path] = None, + ): + """ + Sampling callback that collects samples from the MCMC chain. + + Parameters + ---------- + trainer: TrainerABC + An instance of the trainer class. + state_repository: TrainStateRepository + An instance of the state repository. + keep_top_n_checkpoints: int + Number of past checkpoint files to keep. + save_checkpoint_dir: Optional[Path] + The optional path to save checkpoints. + """ + self._trainer = trainer + self._state_repository = state_repository + self._keep_top_n_checkpoints = keep_top_n_checkpoints + self._save_checkpoint_dir = save_checkpoint_dir + + self._current_step = 0 + self._samples_count = 0 + + def _do_sample(self, current_step, samples_count): + raise NotImplementedError + + def training_step_end(self, state: TrainState) -> TrainState: + self._current_step += 1 + + if self._do_sample(self._current_step, self._samples_count): + if self._save_checkpoint_dir: + self._trainer.save_checkpoint( + state, + pathlib.Path(self._save_checkpoint_dir) / str(self._samples_count), + force_save=True, + ) + self._state_repository.put( + state=state, + i=self._samples_count, + keep=self._keep_top_n_checkpoints, + ) + self._samples_count += 1 + + return state diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_step_schedule.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_step_schedule.py new file mode 100644 index 00000000..ef383787 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_step_schedule.py @@ -0,0 +1,148 @@ +import numpy as np +import jax.numpy as jnp +from fortuna.typing import Array + +from typing import Callable + +StepSchedule = Callable[[Array], Array] + + +def constant_schedule(init_step_size: float) -> StepSchedule: + """Create a constant step schedule. + + Parameters + ---------- + init_step_size: float + The step size. + + Returns + ------- + schedule_fn: StepSchedule + """ + if not init_step_size >= 0: + raise ValueError("`init_step_size` should be >= 0.") + + def schedule(_step: Array): + return init_step_size + + return schedule + + +def cosine_schedule(init_step_size: float, total_steps: int) -> StepSchedule: + """Create a cosine step schedule. + + Parameters + ---------- + init_step_size: float + The initial step size. + total_steps: int + The cycle length, in steps. + + Returns + ------- + schedule_fn: StepSchedule + """ + if not init_step_size >= 0: + raise ValueError("`init_step_size` should be >= 0.") + if not total_steps > 0: + raise ValueError("`total_steps` should be > 0.") + + def schedule(step: Array): + t = step / total_steps + return 0.5 * init_step_size * (1 + jnp.cos(t * np.pi)) + + return schedule + + +def polynomial_schedule( + a: float = 1.0, b: float = 1.0, gamma: float = 0.55 +) -> StepSchedule: + """Create a polynomial step schedule. + + Parameters + ---------- + a: float + Scale of all step sizes. + b: float + The stabilization constant. + gamma: float + The decay rate :math:`\gamma \in (0.5, 1.0]`. + + Returns + ------- + schedule_fn: StepSchedule + """ + + if not 0.5 < gamma <= 1.0: + raise ValueError("`gamma` should be in (0.5, 1.0] range.") + + def schedule(step: Array): + return a * (b + step) ** (-gamma) + + return schedule + + +def constant_schedule_with_cosine_burnin( + init_step_size: float, final_step_size: float, burnin_steps: int +) -> StepSchedule: + """Create a constant schedule with cosine burn-in. + + Parameters + ---------- + init_step_size: float + The initial step size. + final_step_size: float + The desired final step size. + burnin_steps: int + The length of burn-in, in steps. + + Returns + ------- + schedule_fn: StepSchedule + """ + if not init_step_size >= 0: + raise ValueError("`init_step_size` should be >= 0.") + if not final_step_size >= 0: + raise ValueError("`final_step_size` should be >= 0.") + if not burnin_steps >= 0: + raise ValueError("`burnin_steps` should be >= 0.") + + def schedule(step: Array): + t = jnp.minimum(step / burnin_steps, 1.0) + coef = (1 + jnp.cos(t * np.pi)) * 0.5 + return coef * init_step_size + (1 - coef) * final_step_size + + return schedule + + +def cyclical_cosine_schedule_with_const_burnin( + init_step_size: float, burnin_steps: int, cycle_length: int +) -> StepSchedule: + """Create a cyclical cosine schedule with constant burn-in. + + Parameters + ---------- + init_step_size: float + The initial step size. + burnin_steps: int + The length of burn-in, in steps. + cycle_length: int + The length of the cosine cycle, in steps. + + Returns + ------- + schedule_fn: StepSchedule + """ + if not init_step_size >= 0: + raise ValueError("`init_step_size` should be >= 0.") + if not burnin_steps >= 0: + raise ValueError("`burnin_steps` should be >= 0.") + if not cycle_length >= 0: + raise ValueError("`cycle_length` should be >= 0.") + + def schedule(step: Array): + t = jnp.maximum(step - burnin_steps - 1, 0.0) + t = (t % cycle_length) / cycle_length + return 0.5 * init_step_size * (1 + jnp.cos(t * np.pi)) + + return schedule diff --git a/tests/fortuna/prob_model/test_diagnostic.py b/tests/fortuna/prob_model/test_diagnostic.py new file mode 100755 index 00000000..6ba640fa --- /dev/null +++ b/tests/fortuna/prob_model/test_diagnostic.py @@ -0,0 +1,58 @@ +import unittest +from functools import partial + +import numpy as np +import jax.numpy as jnp +from jax import value_and_grad, vmap +from jax.flatten_util import ravel_pytree + +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_diagnostic import ( + kernel_stein_discrepancy_imq, + effective_sample_size, +) + +DATA_SIZE = 1000 + + +class TestDiagnostic(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.rng = np.random.default_rng(0) + + self.mu = np.array([0., 0.]) + self.sigma = np.array([[1.5, 0.5], [0.5, 1.5]]) + + def _mvn_log_density(params, mu=self.mu, sigma=self.sigma): + diff = params - mu + log_density = -jnp.log(2 * jnp.pi) * mu.size / 2 + log_density -= jnp.log(jnp.linalg.det(sigma)) / 2 + log_density -= diff.T @ jnp.linalg.inv(sigma) @ diff / 2 + return log_density + + self.mvn_log_density_grad = vmap(value_and_grad(_mvn_log_density)) + + def unflatten(self, x, keys=("x", "y")): + assert len(x.shape) == 2 and x.shape[-1] == len(keys) + return [{k:v for k, v in zip(keys, val)} for val in x] + + def test_ksd_imq(self): + samp1_flat = self.rng.multivariate_normal(self.mu, self.sigma, size=DATA_SIZE) + samp2_flat = self.rng.multivariate_normal(self.mu, self.sigma ** 3, size=DATA_SIZE) + _, grad1_flat = self.mvn_log_density_grad(samp1_flat) + _, grad2_flat = self.mvn_log_density_grad(samp2_flat) + assert kernel_stein_discrepancy_imq(samp1_flat, grad1_flat) < \ + kernel_stein_discrepancy_imq(samp2_flat, grad2_flat) + samp1_tree = self.unflatten(samp1_flat) + grad1_tree = self.unflatten(grad1_flat) + assert jnp.allclose(kernel_stein_discrepancy_imq(samp1_flat, grad1_flat), + kernel_stein_discrepancy_imq(samp1_tree, grad1_tree)) + + def test_ess(self): + samp1_flat = self.rng.multivariate_normal(self.mu, self.sigma, size=DATA_SIZE) + ess = effective_sample_size(samp1_flat) + assert samp1_flat.shape[-1] == ess.shape[0] + assert jnp.alltrue(0 <= ess) and jnp.alltrue(ess <= len(samp1_flat)) + samp1_tree = self.unflatten(samp1_flat) + ess_tree = effective_sample_size(samp1_tree) + vals, _unravel_fn = ravel_pytree(ess_tree) + assert jnp.allclose(vals, ess) diff --git a/tests/fortuna/prob_model/test_preconditioner.py b/tests/fortuna/prob_model/test_preconditioner.py new file mode 100755 index 00000000..50d7fc3b --- /dev/null +++ b/tests/fortuna/prob_model/test_preconditioner.py @@ -0,0 +1,39 @@ +import unittest + +import jax.numpy as jnp + +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import ( + rmsprop_preconditioner, + identity_preconditioner, +) + + +class TestPreconditioner(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.params = {"p1": jnp.zeros([1, 2], jnp.float32), + "p2": jnp.zeros([2, 1], jnp.float32)} + self.grad = {"p1": jnp.ones([1, 2], jnp.float32), + "p2": jnp.ones([2, 1], jnp.float32)} + + def test_rmsprop(self): + preconditioner = rmsprop_preconditioner() + state = preconditioner.init(self.params) + state = preconditioner.update_preconditioner(self.grad, state) + result = preconditioner.multiply_by_m_inv(self.params, state) + assert "p1" in result and "p2" in result + result = preconditioner.multiply_by_m_sqrt(self.params, state) + assert "p1" in result and "p2" in result + result = preconditioner.multiply_by_m_sqrt_inv(self.params, state) + assert "p1" in result and "p2" in result + + def test_identity(self): + preconditioner = identity_preconditioner() + state = preconditioner.init(self.params) + state = preconditioner.update_preconditioner(self.grad, state) + result = preconditioner.multiply_by_m_inv(self.params, state) + assert "p1" in result and "p2" in result + result = preconditioner.multiply_by_m_sqrt(self.params, state) + assert "p1" in result and "p2" in result + result = preconditioner.multiply_by_m_sqrt_inv(self.params, state) + assert "p1" in result and "p2" in result diff --git a/tests/fortuna/prob_model/test_step_schedule.py b/tests/fortuna/prob_model/test_step_schedule.py new file mode 100755 index 00000000..bdd6f2ad --- /dev/null +++ b/tests/fortuna/prob_model/test_step_schedule.py @@ -0,0 +1,54 @@ +import unittest + +import jax.numpy as jnp + +from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import ( + constant_schedule, + cosine_schedule, + polynomial_schedule, + constant_schedule_with_cosine_burnin, + cyclical_cosine_schedule_with_const_burnin, +) + + +class TestStepSchedule(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.count = jnp.zeros([], jnp.int32) + + def test_constant(self): + schedule_fn = constant_schedule(init_step_size=1e-1) + assert jnp.allclose(schedule_fn(self.count), 1e-1) + assert jnp.allclose(schedule_fn(self.count + 1), 1e-1) + + def test_cosine(self): + schedule_fn = cosine_schedule(init_step_size=1e-1, total_steps=10) + assert jnp.allclose(schedule_fn(self.count), 1e-1) + assert not jnp.allclose( + schedule_fn(self.count + 1), schedule_fn(self.count) + ) + assert jnp.allclose(schedule_fn(self.count + 10), 0) + assert jnp.allclose(schedule_fn(self.count + 20), 1e-1) + + def test_polynomial(self): + schedule_fn = polynomial_schedule() + assert schedule_fn(self.count + 1) < schedule_fn(self.count) + + def test_cosine_burnin(self): + schedule_fn = constant_schedule_with_cosine_burnin( + init_step_size=1e-1, final_step_size=1e-2, burnin_steps=10 + ) + assert jnp.allclose(schedule_fn(self.count), 1e-1) + assert not jnp.allclose( + schedule_fn(self.count + 1), schedule_fn(self.count) + ) + assert jnp.allclose(schedule_fn(self.count + 10), 1e-2) + assert jnp.allclose(schedule_fn(self.count + 11), 1e-2) + + def test_const_burnin(self): + schedule_fn = cyclical_cosine_schedule_with_const_burnin( + init_step_size=1e-1, burnin_steps=10, cycle_length=10 + ) + assert jnp.allclose(schedule_fn(self.count), 1e-1) + assert jnp.allclose(schedule_fn(self.count + 1), 1e-1) + assert not jnp.allclose(schedule_fn(self.count + 12), 1e-1) diff --git a/tests/fortuna/prob_model/test_train.py b/tests/fortuna/prob_model/test_train.py index 033c506d..24c3ae92 100755 --- a/tests/fortuna/prob_model/test_train.py +++ b/tests/fortuna/prob_model/test_train.py @@ -1,62 +1,57 @@ import tempfile -import numpy as np -import pytest - from fortuna.data.loader import DataLoader from fortuna.metric.classification import accuracy from fortuna.metric.regression import rmse -from fortuna.prob_model import ( - CalibConfig, - CalibOptimizer, - FitConfig, - FitMonitor, - SNGPPosteriorApproximator, -) +from fortuna.prob_model import (CalibConfig, CalibOptimizer, SNGPPosteriorApproximator) from fortuna.prob_model.classification import ProbClassifier +from fortuna.prob_model import FitConfig, FitMonitor from fortuna.prob_model.fit_config.checkpointer import FitCheckpointer from fortuna.prob_model.fit_config.optimizer import FitOptimizer -from fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_posterior import ( - DeepEnsemblePosteriorApproximator, -) -from fortuna.prob_model.posterior.laplace.laplace_posterior import ( - LaplacePosteriorApproximator, -) -from fortuna.prob_model.posterior.map.map_approximator import MAPPosteriorApproximator -from fortuna.prob_model.posterior.normalizing_flow.advi.advi_posterior import ( - ADVIPosteriorApproximator, -) -from fortuna.prob_model.posterior.swag.swag_posterior import SWAGPosteriorApproximator +from fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_posterior import \ + DeepEnsemblePosteriorApproximator +from fortuna.prob_model.posterior.laplace.laplace_posterior import \ + LaplacePosteriorApproximator +from fortuna.prob_model.posterior.map.map_approximator import \ + MAPPosteriorApproximator +from fortuna.prob_model.posterior.normalizing_flow.advi.advi_posterior import \ + ADVIPosteriorApproximator +from fortuna.prob_model.posterior.swag.swag_posterior import \ + SWAGPosteriorApproximator +from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_posterior import \ + SGHMCPosteriorApproximator +from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_posterior import \ + CyclicalSGLDPosteriorApproximator from fortuna.prob_model.regression import ProbRegressor from tests.make_data import make_array_random_data -from tests.make_model import ( - MyModel, - MyModelWithSpectralNorm, -) +from tests.make_model import MyModel, MyModelWithSpectralNorm +import numpy as np +import pytest -np.random.seed(42) OUTPUT_DIM = 2 +BATCH_SIZE = 8 +INPUT_SHAPE = (3,) +N_DATA = 10 -TASKS = ["regression", "classification"] METHODS = { "map": MAPPosteriorApproximator(), "advi": ADVIPosteriorApproximator(), "laplace": LaplacePosteriorApproximator(), "swag": SWAGPosteriorApproximator(rank=2), "deep_ensemble": DeepEnsemblePosteriorApproximator(ensemble_size=2), - "sngp": SNGPPosteriorApproximator(output_dim=OUTPUT_DIM), + "sngp": SNGPPosteriorApproximator(output_dim=OUTPUT_DIM, + gp_hidden_features=2), + "sghmc": SGHMCPosteriorApproximator(n_samples=3, + n_thinning=1, + burnin_length=1), + "cyclical_sgld": CyclicalSGLDPosteriorApproximator(n_samples=3, + n_thinning=1, + cycle_length=4), } -TASKS_METHODS = [ - (task, method) - for task in TASKS - for method in METHODS - if (task, method) != ("regression", "sngp") -] -TASKS_IDS = [t + "-" + m for t, m in TASKS_METHODS] -def make_data_loader(task, n_data, input_shape, output_dim, batch_size): +def make_data_loader(task, n_data=N_DATA, input_shape=INPUT_SHAPE, output_dim=OUTPUT_DIM, batch_size=BATCH_SIZE): x_train, y_train = make_array_random_data( n_data=n_data, shape_inputs=input_shape, @@ -69,166 +64,156 @@ def make_data_loader(task, n_data, input_shape, output_dim, batch_size): return DataLoader.from_array_data((x_train, y_train), batch_size=batch_size) -@pytest.mark.parametrize("task, method", TASKS_METHODS, ids=TASKS_IDS) -def test_dryrun(task, method): - batch_size = 32 - input_shape = (3,) - n_data = 100 - - train_data_loader = make_data_loader( - task, n_data, input_shape, OUTPUT_DIM, batch_size - ) - val_data_loader = make_data_loader( - task, n_data, input_shape, OUTPUT_DIM, batch_size - ) - calib_data_loader = make_data_loader( - task, n_data, input_shape, OUTPUT_DIM, batch_size - ) - - freeze_fun = lambda p, v: "trainable" if "l2" in p and "model" in p else "frozen" - - fit_config = lambda restore_path, start_current, save_dir, dump_state, save_n_steps, freeze: FitConfig( - optimizer=FitOptimizer(n_epochs=3, freeze_fun=freeze), - monitor=FitMonitor(metrics=(accuracy if task == "classification" else rmse,)), +def fit_config(task, restore_path, start_current, save_dir, dump_state, save_n_steps, freeze): + return FitConfig( + optimizer=FitOptimizer( + n_epochs=3, + freeze_fun=freeze + ), + monitor=FitMonitor( + metrics=(accuracy if task == "classification" else rmse,) + ), checkpointer=FitCheckpointer( start_from_current_state=start_current, restore_checkpoint_path=restore_path, save_checkpoint_dir=save_dir, dump_state=dump_state, - save_every_n_steps=save_n_steps, + save_every_n_steps=save_n_steps + ) + ) + + +calib_config = CalibConfig( + optimizer=CalibOptimizer(n_epochs=3) +) + + +def train(task, model, train_data_loader, val_data_loader, calib_data_loader, + restore_path=None, start_current=False, save_dir=None, + dump_state=False, save_n_steps=None, freeze=None, + map_fit_config=None): + model.train( + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + calib_data_loader=calib_data_loader, + fit_config=fit_config( + task, restore_path, start_current, save_dir, + dump_state, save_n_steps, freeze ), + calib_config=calib_config, + map_fit_config=map_fit_config ) - calib_config = CalibConfig(optimizer=CalibOptimizer(n_epochs=3)) - - def train( - restore_path=None, - start_current=False, - save_dir=None, - dump_state=False, - save_n_steps=None, - freeze=None, - map_fit_config=None, - ): - prob_model.train( - train_data_loader=train_data_loader, - val_data_loader=val_data_loader, - calib_data_loader=calib_data_loader, - fit_config=fit_config( - restore_path, start_current, save_dir, dump_state, save_n_steps, freeze - ), - calib_config=calib_config, - map_fit_config=map_fit_config, + +def sample(method, model, train_data_loader): + if method in ["swag"]: + sample = model.posterior.sample( + inputs_loader=train_data_loader.to_inputs_loader() ) + else: + sample = model.posterior.sample() + - def sample(): - if method in ["swag"]: - sample = prob_model.posterior.sample( - inputs_loader=train_data_loader.to_inputs_loader() - ) - else: - sample = prob_model.posterior.sample() - - def train_and_sample( - restore_path=None, - start_current=False, - save_dir=None, - dump_state=False, - save_n_steps=None, - freeze=None, - map_fit_config=None, - ): - train( - restore_path, - start_current, - save_dir, - dump_state, - save_n_steps, - freeze, - map_fit_config, +def train_and_sample(task, method, model, train_data_loader, val_data_loader, + calib_data_loader, restore_path=None, + start_current=False, save_dir=None, dump_state=False, + save_n_steps=None, freeze=None, map_fit_config=None): + train(task, model, train_data_loader, val_data_loader, calib_data_loader, + restore_path, start_current, save_dir, dump_state, save_n_steps, + freeze, map_fit_config) + sample(method, model, train_data_loader) + + +def define_prob_model(task, method): + if task == "regression": + return ProbRegressor( + model=MyModel(OUTPUT_DIM), + likelihood_log_variance_model=MyModel(OUTPUT_DIM), + posterior_approximator=METHODS[method] ) - sample() - - def define_prob_model(): - if task == "regression": - return ProbRegressor( - model=MyModel(OUTPUT_DIM), - likelihood_log_variance_model=MyModel(OUTPUT_DIM), - posterior_approximator=METHODS[method], - ) - else: - return ProbClassifier( - model=MyModel(OUTPUT_DIM) - if method != "sngp" - else MyModelWithSpectralNorm(OUTPUT_DIM), - posterior_approximator=METHODS[method], - ) - - prob_model = define_prob_model() - train_and_sample( - map_fit_config=fit_config( - restore_path=None, - start_current=None, - save_dir=None, - dump_state=False, - save_n_steps=None, - freeze=None, + else: + return ProbClassifier( + model=MyModel(OUTPUT_DIM) if method != "sngp" else MyModelWithSpectralNorm(OUTPUT_DIM), + posterior_approximator=METHODS[method] ) - ) - train_and_sample(start_current=True) + + +def dryrun_task(task, method): + freeze_fun = lambda p, v: "trainable" if "l2" in p and "model" in p else "frozen" + + train_data_loader = make_data_loader(task) + val_data_loader = make_data_loader(task) + calib_data_loader = make_data_loader(task) + + prob_model = define_prob_model(task, method) + map_fit_config = fit_config(task, restore_path=None, start_current=None, + save_dir=None, dump_state=False, + save_n_steps=None, freeze=None) + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader, map_fit_config=map_fit_config) + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader, start_current=True) + if method not in ["laplace", "swag"]: - train_and_sample() + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader) with tempfile.TemporaryDirectory() as tmp_dir: - train_and_sample( - map_fit_config=fit_config( - restore_path=None, - start_current=None, - save_dir=None, - dump_state=False, - save_n_steps=None, - freeze=None, - ), - save_dir=tmp_dir, - dump_state=True, - ) - train_and_sample(restore_path=tmp_dir) - - prob_model = define_prob_model() + map_fit_config = fit_config(task, restore_path=None, + start_current=None, save_dir=None, + dump_state=False, save_n_steps=None, + freeze=None) + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader, + map_fit_config=map_fit_config, save_dir=tmp_dir, + dump_state=True) + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader, + restore_path=tmp_dir) + + prob_model = define_prob_model(task, method) prob_model.load_state(tmp_dir) - sample() + sample(method, prob_model, train_data_loader) prob_model.predictive.log_prob(train_data_loader) if method not in ["laplace", "swag"]: - train_and_sample(freeze=freeze_fun) - train_and_sample(start_current=True, freeze=freeze_fun) - train_and_sample( - save_dir=tmp_dir, dump_state=True, restore_path=tmp_dir, freeze=freeze_fun - ) - train_and_sample( - save_dir=tmp_dir, dump_state=True, restore_path=tmp_dir, freeze=freeze_fun - ) - train_and_sample( - map_fit_config=fit_config( - restore_path=None, - start_current=None, - save_dir=None, - dump_state=False, - save_n_steps=None, - freeze=None, - ), - save_dir=tmp_dir, - dump_state=True, - freeze=freeze_fun, - ) - - train_and_sample( - start_current=True, - save_dir=tmp_dir + "/tmp", - save_n_steps=1, - freeze=freeze_fun, - ) - prob_model = define_prob_model() + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader, + freeze=freeze_fun) + + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader, + start_current=True, freeze=freeze_fun) + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader, save_dir=tmp_dir, + dump_state=True, restore_path=tmp_dir, + freeze=freeze_fun) + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader, save_dir=tmp_dir, + dump_state=True, restore_path=tmp_dir, + freeze=freeze_fun) + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader, + map_fit_config=fit_config(task, restore_path=None, + start_current=None, save_dir=None, dump_state=False, + save_n_steps=None, freeze=None), save_dir=tmp_dir, + dump_state=True, freeze=freeze_fun) + + train_and_sample(task, method, prob_model, train_data_loader, + val_data_loader, calib_data_loader, + start_current=True, save_dir=tmp_dir + "/tmp", + save_n_steps=1, freeze=freeze_fun) + prob_model = define_prob_model(task, method) prob_model.load_state(tmp_dir + "/tmp") - sample() + sample(method, prob_model, train_data_loader) prob_model.predictive.log_prob(train_data_loader) + + +@pytest.mark.parametrize("method", METHODS.keys()) +def test_dryrun_classification(method): + dryrun_task(task="classification", method=method) + + +@pytest.mark.parametrize("method", [m for m in METHODS.keys() if m != "sngp"]) +def test_dryrun_regression(method): + dryrun_task(task="regression", method=method)