add SG-MCMC methods (awslabs#46)
* add SGHMC method

* add SG-MCMC diagnostics

* add polynomial step schedule

* implement SGHMC momentum resampling

* refactor SGMCMC posterior into a top-level class

* add Cyclical SGLD method

* implement chain thinning

* update documentation

* code formatting fixes

* refactor SGMCMC methods

* update documentation

* rebase to the new API & update docs

* address code review feedback

* fix the Cyclical SGLD sampler

* fix unit tests

* lint the code

* make black show diffs in the workflow

* rebase to the latest API
master authored May 16, 2023
1 parent 262316c commit ab13c1d
Showing 37 changed files with 2,419 additions and 206 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -24,4 +24,4 @@ jobs:
        pip install black
      - name: Run Black
-       run: black --check --verbose fortuna
+       run: black --check --diff --verbose fortuna
8 changes: 8 additions & 0 deletions docs/source/methods.rst
@@ -29,6 +29,14 @@ Posterior approximation methods
taken by averaging checkpoints over the stochastic optimization trajectory. The covariance is also estimated
empirically along the trajectory, and it is made of a diagonal component and a low-rank non-diagonal one.

- **Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)** `[Chen et al., 2014] <http://proceedings.mlr.press/v32/cheni14.pdf>`_
SGHMC approximates the posterior as the steady-state distribution of a Markov chain driven by Hamiltonian dynamics with stochastic gradients.
After an initial "burn-in" phase, each step of the chain yields a sample from the approximate posterior.

- **Cyclical Stochastic Gradient Langevin Dynamics (Cyclical SGLD)** `[Zhang et al., 2020] <https://openreview.net/pdf?id=rkeS1RVtPS>`_
Cyclical SGLD adopts a cyclical cosine step size schedule and alternates between *exploration* and *sampling* stages to better
explore the multimodal posteriors of deep neural networks.

Parametric calibration methods
------------------------------
Fortuna supports parametric calibration by adding an output calibration model on top of the outputs of the model used for
1 change: 1 addition & 0 deletions docs/source/references/prob_model/posterior/posterior.rst
@@ -12,6 +12,7 @@ calibration parameters. We support several posterior approximations:
laplace
swag
sngp
sgmcmc

.. _posterior:

82 changes: 82 additions & 0 deletions docs/source/references/prob_model/posterior/sgmcmc.rst
@@ -0,0 +1,82 @@
Stochastic Gradient Markov Chain Monte Carlo (SG-MCMC)
------------------------------------------------------
SG-MCMC procedures approximate the posterior as the steady-state distribution of
a Markov chain that relies on noisy estimates of the gradient
computed on minibatches of data.
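
As a reference point, the simplest member of this family is SGLD (Welling and Teh, 2011), whose update, shown here only
as a sketch, perturbs a stochastic gradient step with Gaussian noise scaled by the step size :math:`\epsilon_t`:

.. math::

    \theta_{t+1} = \theta_t - \frac{\epsilon_t}{2} \nabla \tilde{U}(\theta_t) + \mathcal{N}(0, \epsilon_t),

where :math:`\tilde{U}` denotes the minibatch estimate of the negative log-posterior.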

Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)
===================================================

SGHMC `[Chen T. et al., 2014] <http://proceedings.mlr.press/v32/cheni14.pdf>`_
is a popular MCMC algorithm that uses stochastic gradient estimates to scale
to large datasets.
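
A sketch of the discretized dynamics (notation loosely follows Chen et al., 2014), with momenta :math:`v`,
friction :math:`\alpha`, and step size :math:`\epsilon`:

.. math::

    \theta_{t+1} = \theta_t + v_t, \qquad
    v_{t+1} = v_t - \epsilon \nabla \tilde{U}(\theta_t) - \alpha v_t + \mathcal{N}\big(0, 2 (\alpha - \hat{\beta}) \epsilon\big),

where :math:`\tilde{U}` is the minibatch estimate of the negative log-posterior and :math:`\hat{\beta}` accounts for the
noise introduced by the stochastic gradients.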

.. autoclass:: fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator.SGHMCPosteriorApproximator

.. autoclass:: fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_posterior.SGHMCPosterior
:show-inheritance:
:no-inherited-members:
:exclude-members: state
:members: fit, sample, load_state, save_state

.. autoclass:: fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_state.SGHMCState
:show-inheritance:
:no-inherited-members:
:inherited-members: init, init_from_dict
:members: convert_from_map_state
:exclude-members: params, mutable, calib_params, calib_mutable, replace, apply_gradients, encoded_name, create
:no-undoc-members:
:no-special-members:

Cyclical Stochastic Gradient Langevin Dynamics (CyclicalSGLD)
=============================================================

The Cyclical SGLD method `[Zhang R. et al., 2020] <https://openreview.net/pdf?id=rkeS1RVtPS>`_ is a simple and automatic
procedure that adopts a cyclical cosine step size schedule and alternates between
*exploration* and *sampling* stages to better explore the multimodal posteriors of deep neural networks.
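
As a sketch, the cyclical step size schedule from the paper, with initial step size :math:`\epsilon_0`, :math:`M` cycles,
and :math:`K` total steps, is

.. math::

    \epsilon_k = \frac{\epsilon_0}{2} \left[ \cos\!\left( \frac{\pi \, \mathrm{mod}(k - 1, \lceil K / M \rceil)}{\lceil K / M \rceil} \right) + 1 \right],

so that each cycle starts with large *exploration* steps and ends with small *sampling* steps.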

.. autoclass:: fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_approximator.CyclicalSGLDPosteriorApproximator

.. autoclass:: fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_posterior.CyclicalSGLDPosterior
:show-inheritance:
:no-inherited-members:
:exclude-members: state
:members: fit, sample, load_state, save_state

.. autoclass:: fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_state.CyclicalSGLDState
:show-inheritance:
:no-inherited-members:
:inherited-members: init, init_from_dict
:members: convert_from_map_state
:exclude-members: params, mutable, calib_params, calib_mutable, replace, apply_gradients, encoded_name, create
:no-undoc-members:
:no-special-members:


Step schedules
==============

Fortuna supports various step schedulers for SG-MCMC
algorithms. :class:`~fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule.StepSchedule`
is a function that takes the step count as input and returns a `float` step
size as output.
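
A minimal sketch of that contract follows; the constants are illustrative, and whether a plain Python function can be
passed directly wherever a ``StepSchedule`` is expected is an assumption here:

.. code-block:: python

    def polynomial_like_schedule(step: int) -> float:
        # The step size decays polynomially with the step count;
        # the constants are placeholders, not Fortuna defaults.
        init_step_size, decay = 1e-5, 0.55
        return init_step_size * float(1 + step) ** -decay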

.. automodule:: fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule


Preconditioners
===============

Fortuna provides implementations of preconditioners to improve the samplers' efficacy.

.. automodule:: fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner
:exclude-members: Preconditioner, PreconditionerState, RMSPropPreconditionerState, IdentityPreconditionerState
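
A rough sketch of how a preconditioner might be paired with an SG-MCMC approximator follows; the
``rmsprop_preconditioner`` factory and the ``preconditioner`` argument are assumptions made for illustration, so refer
to the rendered module reference above for the exact names:

.. code-block:: python

    # Hypothetical pairing, for illustration only.
    from fortuna.prob_model import SGHMCPosteriorApproximator
    from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import (
        rmsprop_preconditioner,  # assumed factory name
    )

    approximator = SGHMCPosteriorApproximator(
        burnin_length=300,
        step_schedule=4e-6,
        preconditioner=rmsprop_preconditioner(),  # assumed keyword argument
    )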


Diagnostics
===========

The library includes tools for diagnosing the convergence of
SG-MCMC sampling procedures.

.. automodule:: fortuna.prob_model.posterior.sgmcmc.sgmcmc_diagnostic
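
An illustrative sketch only: assuming the module exposes an effective-sample-size estimator (the
``effective_sample_size`` name below is an assumption), the mixing of a chain of parameter samples could be inspected
as follows.

.. code-block:: python

    # Hypothetical usage, for illustration only.
    import numpy as np

    from fortuna.prob_model.posterior.sgmcmc.sgmcmc_diagnostic import (
        effective_sample_size,  # assumed function name
    )

    samples = np.random.randn(1000, 10)  # stand-in chain of shape (n_samples, n_params)
    ess = effective_sample_size(samples)  # larger values indicate better mixing
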
2 changes: 2 additions & 0 deletions examples/index.rst
@@ -19,3 +19,5 @@ In this section we show some examples of how to use Fortuna in classification an
subnet_calibration
two_moons_classification_sngp
scaling_up_bayesian_inference
mnist_classification_sghmc
sgmcmc_diagnostics
204 changes: 204 additions & 0 deletions examples/mnist_classification_sghmc.pct.py
@@ -0,0 +1,204 @@
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.14.5
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %% [markdown]
# # MNIST Classification with Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)

# %% [markdown]
# In this notebook we demonstrate how to use Fortuna to obtain predictive uncertainty estimates from a simple neural network trained on the MNIST classification task, using the [SGHMC](http://proceedings.mlr.press/v32/cheni14.pdf) method.

# %% [markdown]
# ### Download MNIST data from TensorFlow
# Let us first download the MNIST data from [TensorFlow Datasets](https://www.tensorflow.org/datasets). Other sources would work just as well.

# %%
import tensorflow as tf
import tensorflow_datasets as tfds


def download(split_range, shuffle=False):
    ds = tfds.load(
        name="MNIST",
        split=f"train[{split_range}]",
        as_supervised=True,
        shuffle_files=True,
    ).map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))
    if shuffle:
        ds = ds.shuffle(10, reshuffle_each_iteration=True)
    return ds.batch(128).prefetch(1)


train_data_loader, val_data_loader, test_data_loader = (
    download(":80%", shuffle=True),
    download("80%:90%"),
    download("90%:"),
)

# %% [markdown]
# ### Convert data to a compatible data loader
# Fortuna helps you convert data and data loaders into a data loader that Fortuna can digest.

# %%
from fortuna.data import DataLoader

train_data_loader = DataLoader.from_tensorflow_data_loader(train_data_loader)
val_data_loader = DataLoader.from_tensorflow_data_loader(val_data_loader)
test_data_loader = DataLoader.from_tensorflow_data_loader(test_data_loader)

# %% [markdown]
# ### Build a probabilistic classifier
# Let us build a probabilistic classifier. This is an interface object containing several attributes that you can configure, e.g. `model`, `prior`, and `posterior_approximator`. In this example, we use a multilayer perceptron and an SGHMC posterior approximator. SGHMC (and SG-MCMC methods more broadly) lets you configure a step size schedule function. For simplicity, we use a constant step size here.

# %%
import flax.linen as nn

from fortuna.prob_model import ProbClassifier, SGHMCPosteriorApproximator
from fortuna.model import MLP

output_dim = 10
prob_model = ProbClassifier(
    model=MLP(output_dim=output_dim, activations=(nn.tanh, nn.tanh)),
    posterior_approximator=SGHMCPosteriorApproximator(
        burnin_length=300, step_schedule=4e-6
    ),
)


# %% [markdown]
# ### Train the probabilistic model: posterior fitting and calibration
# We can now train the probabilistic model. This includes fitting the posterior distribution and calibrating the probabilistic model. With the configuration above, the first 300 steps of the Markov chain are discarded as burn-in, after which samples are drawn from the approximate posterior over 30 training epochs.

# %%
from fortuna.prob_model import FitConfig, FitMonitor, FitOptimizer
from fortuna.metric.classification import accuracy

status = prob_model.train(
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    calib_data_loader=val_data_loader,
    fit_config=FitConfig(
        monitor=FitMonitor(metrics=(accuracy,)),
        optimizer=FitOptimizer(n_epochs=30),
    ),
)


# %% [markdown]
# ### Estimate predictive statistics
# We can now compute some predictive statistics by invoking the `predictive` attribute of the probabilistic classifier and the method of interest. Most predictive statistics, e.g. mean or mode, require a loader of input data points. You can easily get this from the data loader by calling its method `to_inputs_loader`.

# %% pycharm={"name": "#%%\n"}
test_log_probs = prob_model.predictive.log_prob(data_loader=test_data_loader)
test_inputs_loader = test_data_loader.to_inputs_loader()
test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader)
test_modes = prob_model.predictive.mode(
    inputs_loader=test_inputs_loader, means=test_means
)

# %% [markdown]
# ### Compute metrics
# In classification, the predictive mode is a prediction for labels, while the predictive mean is a prediction for the probability of each label. As such, we can use these to compute several metrics, e.g. the accuracy, the Brier score, the expected calibration error (ECE), etc.

# %%
from fortuna.metric.classification import (
    accuracy,
    expected_calibration_error,
    brier_score,
)

test_targets = test_data_loader.to_array_targets()
acc = accuracy(preds=test_modes, targets=test_targets)
brier = brier_score(probs=test_means, targets=test_targets)
ece = expected_calibration_error(
    preds=test_modes,
    probs=test_means,
    targets=test_targets,
    plot=True,
    plot_options=dict(figsize=(10, 2)),
)
print(f"Test accuracy: {acc}")
print(f"Brier score: {brier}")
print(f"ECE: {ece}")

# %% [markdown]
# ### Conformal prediction sets
# Fortuna allows you to produce conformal prediction sets, i.e. sets of likely labels that achieve a desired coverage probability. These can be computed starting from probability estimates obtained with or without Fortuna.

# %%
from fortuna.conformal import AdaptivePredictionConformalClassifier

val_means = prob_model.predictive.mean(inputs_loader=val_data_loader.to_inputs_loader())
conformal_sets = AdaptivePredictionConformalClassifier().conformal_set(
    val_probs=val_means,
    test_probs=test_means,
    val_targets=val_data_loader.to_array_targets(),
)

# %% [markdown]
# We can check that, on average, conformal sets for misclassified inputs are larger than for well-classified ones.

# %%
import numpy as np

avg_size = np.mean([len(s) for s in np.array(conformal_sets, dtype="object")])
avg_size_wellclassified = np.mean(
    [
        len(s)
        for s in np.array(conformal_sets, dtype="object")[test_modes == test_targets]
    ]
)
avg_size_misclassified = np.mean(
    [
        len(s)
        for s in np.array(conformal_sets, dtype="object")[test_modes != test_targets]
    ]
)
print(f"Average conformal set size: {avg_size}")
print(
    f"Average conformal set size over well-classified inputs: {avg_size_wellclassified}"
)
print(f"Average conformal set size over misclassified inputs: {avg_size_misclassified}")

# %% [markdown]
# Furthermore, we visualize some of the examples with the largest and the smallest conformal sets. Intuitively, these correspond to the inputs on which the model is most uncertain and most certain about its predictions, respectively.

# %%
from matplotlib import pyplot as plt

N_EXAMPLES = 10
images = test_data_loader.to_array_inputs()

def visualize_examples(indices, n_examples=N_EXAMPLES):
    # Plot up to `n_examples` test images, selected by `indices`, in a single row.
    n_rows = min(len(indices), n_examples)
    _, axs = plt.subplots(1, n_rows, figsize=(10, 2))
    axs = axs.flatten()
    for i, ax in enumerate(axs):
        ax.imshow(images[indices[i]], cmap="gray")
        ax.axis("off")
    plt.show()

# %%
indices = np.argsort([len(s) for s in np.array(conformal_sets, dtype="object")])

# %%
print("Examples with the smallest conformal sets:")
visualize_examples(indices[:N_EXAMPLES])

# %%
print("Examples with the largest conformal sets:")
visualize_examples(np.flip(indices[-N_EXAMPLES:]))