Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Discrete MCMC with JAX and Numpyro #29

Draft
wants to merge 1 commit into
base: generative_models
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions src/qml_benchmarks/mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Markov-chain Monte Carlo (MCMC) implementation with JAX and Numpyro."""

from collections import namedtuple
from functools import partial

import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from numpyro.infer.mcmc import MCMCKernel

MHState = namedtuple("MHState", ["state", "rng_key"])


class MetropolisHastings(MCMCKernel):
"""A simple Metropolis-Hastings MCMC kernel.

Args:
potential_fn (callable):
A callable representing the energy function.
"""

sample_field = "state"

def __init__(self, potential_fn):
"""_summary_

Args:
potential_fn (callable): Potenital energy function.
"""
self.potential_fn = potential_fn

def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
"""_summary_

Args:
rng_key (_type_): _description_
num_warmup (_type_): _description_
init_params (_type_): _description_
model_args (_type_): _description_
model_kwargs (_type_): _description_

Returns:
_type_: _description_
"""
return MHState(init_params, rng_key)

def sample(self, state, model_args, model_kwargs):
"""_summary_

Args:
state (_type_): _description_
model_args (_type_): _description_
model_kwargs (_type_): _description_

Returns:
_type_: _description_
"""
u, rng_key = state
rng_key, key_proposal, key_accept = random.split(rng_key, 3)
u_proposal = self.proposal(key_proposal, u)

accept_prob = jnp.exp(
self.potential_fn(u, *model_args, **model_kwargs)
- self.potential_fn(u_proposal, *model_args, **model_kwargs)
)
u_new = jnp.where(
dist.Uniform().sample(key_accept) < accept_prob, u_proposal, u
)
return MHState(u_new, rng_key)
# spins, rng_key = state
# num_spins = spins.size

# def mh_step(i, val):
# spins, rng_key = val
# rng_key, subkey = random.split(rng_key)
# flip_index = random.randint(subkey, (), 0, num_spins)
# spins_proposal = spins.at[flip_index].set(-spins[flip_index])

# current_energy = self.potential_fn(
# spins, *model_args, **model_kwargs)
# proposed_energy = self.potential_fn(
# spins_proposal, *model_args, **model_kwargs)
# delta_energy = proposed_energy - current_energy
# accept_prob = jnp.exp(-delta_energy)

# rng_key, subkey = random.split(rng_key)
# accept = random.uniform(subkey) < accept_prob
# spins = jnp.where(accept, spins_proposal, spins)
# return spins, rng_key

# spins, rng_key = jax.lax.fori_loop(0, num_spins, mh_step, (spins, rng_key))
# return MHState(spins, rng_key)

def proposal(self, key, u):
"""Make a new proposal by flipping spins randomly.

Args:
key (_type_): _description_
u (_type_): _description_

Returns:
_type_: _description_
"""
num_spins = u.size
# Generate a number of indices for flipping spins
flip_indices = random.randint(key, (num_spins,), 0, num_spins)
# Flip the selected spins
u_proposal = u.at[flip_indices].set(-u[flip_indices])
return u_proposal
84 changes: 84 additions & 0 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Tests for the MCMC implementation using JAX and Numpyro."""

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import random
from numpyro.infer import MCMC
from qml_benchmarks.mcmc import MetropolisHastings


# Simple energy functions for which we should know the posteriors for testing
@jax.jit
def energy_sum(x: jnp.array) -> float:
"""Simple energy function as the sum of the spins.

Args:
x (jnp.array): State of the system.

Returns:
float: Energy value.
"""
return jnp.sum(x)


@pytest.mark.parametrize("num_samples", [10, 100, 1000])
@pytest.mark.parametrize("num_chains", [2, 4])
@pytest.mark.parametrize("dim", [1, 3, 5])
def test_mcmc_runs(num_samples, num_chains, dim):
"""Testing that the MCMC implementation runs."""

kernel = MetropolisHastings(energy_sum)

# Create initial state with random values in [-1, 1]
key = random.PRNGKey(0)
init_params = random.choice(key, jnp.array([-1, 1]), shape=(num_chains, dim))

mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples, num_chains=num_chains)

mcmc.run(random.PRNGKey(0), init_params=init_params)
posterior_samples = mcmc.get_samples()
assert posterior_samples.shape == (num_samples * num_chains, dim)
assert jnp.all(jnp.isin(posterior_samples, jnp.array([1, -1])))


@pytest.mark.parametrize(
"energy_fn",
[
energy_sum,
],
)
@pytest.mark.parametrize("num_samples", [10000])
@pytest.mark.parametrize("dim", [3, 4, 5])
def test_mcmc(energy_fn, num_samples, dim):
"""Test MCMC sampling with different energy functions"""

# Initialize the kernel with the potential function
kernel = MetropolisHastings(energy_fn)
num_chains = 4

# Create initial state with random values in [-1, 1]
key = random.PRNGKey(0)
init_params = random.choice(key, jnp.array([-1, 1]), shape=(num_chains, dim))

# Run the MCMC
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples, num_chains=num_chains)
mcmc.run(random.PRNGKey(0), init_params=init_params)

samples = mcmc.get_samples()
assert jnp.all(jnp.isin(samples, jnp.array([1, -1])))

energies = jax.vmap(energy_fn)(samples)
energy_values, counts = np.unique(np.array(energies), return_counts=True)
energy_hist = dict(zip(tuple(counts), tuple(energy_values)))

# Check that the lowest energy value appears most frequently
assert energy_hist[np.max(counts)] == jnp.min(energies)

# Check that the highest energy value appears least frequently
assert energy_hist[np.min(counts)] == jnp.max(energies)


if __name__ == "__main__":
pytest.main(["-v", __file__])