-
Notifications
You must be signed in to change notification settings - Fork 81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add JAX kmeans implementation #371
Draft
gileshd
wants to merge
3
commits into
main
Choose a base branch
from
ghd/kmeans-refactor
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from functools import partial | ||
from jax import lax, jit | ||
from jax import numpy as jnp | ||
from jax import random as jr | ||
from jaxtyping import Array, Int, Float | ||
from typing import NamedTuple, Tuple | ||
|
||
|
||
def kmeans_sklearn( | ||
k: int, X: Float[Array, "num_samples state_dim"], key: Array | ||
) -> Tuple[Float[Array, "num_states state_dim"], Float[Array, "num_samples"]]: | ||
""" | ||
Compute the cluster centers and assignments using the sklearn K-means algorithm. | ||
|
||
Args: | ||
k (int): The number of clusters. | ||
X (Array(N, D)): The input data array. N samples of dimension D. | ||
key (Array): The random seed array. | ||
|
||
Returns: | ||
Array(k, D), Array(N,): The cluster centers and labels | ||
""" | ||
from sklearn.cluster import KMeans | ||
|
||
key, subkey = jr.split(key) # Create a random seed for SKLearn. | ||
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. | ||
km = KMeans(k, random_state=int(sklearn_key)).fit(X) | ||
return jnp.array(km.cluster_centers_), jnp.array(km.labels_) | ||
|
||
|
||
class KMeansState(NamedTuple): | ||
centroids: Float[Array, "num_states state_dim"] | ||
assignments: Int[Array, "num_samples"] | ||
prev_centroids: Float[Array, "num_states state_dim"] | ||
itr: int | ||
|
||
|
||
@partial(jit, static_argnums=(1, 3)) | ||
def kmeans_jax( | ||
X: Float[Array, "num_samples state_dim"], | ||
k: int, | ||
key: Array = jr.PRNGKey(0), | ||
max_iters: int = 1000, | ||
) -> KMeansState: | ||
""" | ||
Perform k-means clustering using JAX. | ||
|
||
K-means++ initialization is used to initialize the centroids. | ||
|
||
Args: | ||
X (Array): The input data array of shape (n_samples, n_features). | ||
k (int): The number of clusters. | ||
max_iters (int, optional): The maximum number of iterations. Defaults to 1000. | ||
key (PRNGKey, optional): The random key for initialization. Defaults to jr.PRNGKey(0). | ||
|
||
Returns: | ||
KMeansState: A named tuple containing the final centroids array of shape (k, n_features), | ||
the assignments array of shape (n_samples,) indicating the cluster index for each sample, | ||
the previous centroids array of shape (k, n_features), and the number of iterations. | ||
""" | ||
|
||
def _update_centroids(X: Array, assignments: Array): | ||
new_centroids = jnp.array([jnp.mean(X, axis=0, where=(assignments == i)[:, None]) for i in range(k)]) | ||
return new_centroids | ||
|
||
def _update_assignments(X, centroids): | ||
return jnp.argmin(jnp.linalg.norm(X[:, None] - centroids, axis=2), axis=1) | ||
|
||
def body(carry: KMeansState): | ||
centroids, assignments, *_ = carry | ||
new_centroids = _update_centroids(X, assignments) | ||
new_assignments = _update_assignments(X, new_centroids) | ||
return KMeansState(new_centroids, new_assignments, centroids, carry.itr + 1) | ||
|
||
def cond(carry: KMeansState): | ||
return jnp.any(carry.centroids != carry.prev_centroids) & (carry.itr < max_iters) | ||
|
||
def init(key): | ||
"""kmeans++ initialization of centroids | ||
|
||
Iteratively sample new centroids with probability proportional to the squared distance | ||
from the closest centroid. This initialization method is more stable than random | ||
initialization and leads to faster convergence. | ||
Ref: Arthur, D., & Vassilvitskii, S. (2006). | ||
""" | ||
centroids = jnp.zeros((k, X.shape[1])) | ||
centroids = centroids.at[0, :].set(jr.choice(key, X)) | ||
for i in range(1, k): | ||
squared_diffs = jnp.sum((X[:, None, :] - centroids[None, :i, :]) ** 2, axis=2) | ||
min_squared_dists = jnp.min(squared_diffs, axis=1) | ||
probs = min_squared_dists / jnp.sum(min_squared_dists) | ||
centroids = centroids.at[i, :].set(jr.choice(key, X, p=probs)) | ||
assignments = _update_assignments(X, centroids) | ||
# Perform one iteration to update centroids | ||
updated_centroids = _update_centroids(X, assignments) | ||
updated_assignments = _update_assignments(X, updated_centroids) | ||
return KMeansState(updated_centroids, updated_assignments, centroids, 1) | ||
|
||
init_state = init(key) | ||
state = lax.while_loop(cond, body, init_state) | ||
|
||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from jax import numpy as jnp | ||
from jax import random as jr | ||
from jax import vmap | ||
|
||
from dynamax.utils.cluster import kmeans_jax | ||
|
||
|
||
def test_kmeans_jax_toy(): | ||
"""Checks that kmeans works against toy example. | ||
|
||
Ref: scikit-learn tests | ||
""" | ||
|
||
key = jr.PRNGKey(101) | ||
x = jnp.array([[0, 0], [0.5, 0], [0.5, 1], [1, 1]]) | ||
|
||
centroids, assignments, *_ = kmeans_jax(x, 2, key) | ||
|
||
# There are two possible solutions for the centroids and assignments | ||
try: | ||
expected_labels = jnp.array([0, 0, 1, 1]) | ||
expected_centers = jnp.array([[0.25, 0], [0.75, 1]]) | ||
assert jnp.all(assignments == expected_labels) | ||
assert jnp.allclose(centroids, expected_centers) | ||
except AssertionError: | ||
expected_labels = jnp.array([1, 1, 0, 0]) | ||
expected_centers = jnp.array([[0.75, 1.0], [0.25, 0.0]]) | ||
assert jnp.all(assignments == expected_labels) | ||
assert jnp.allclose(centroids, expected_centers) | ||
|
||
|
||
def test_kmeans_jax_vmap(): | ||
"""Test that kmeans_jax works with vmap.""" | ||
|
||
def _gen_data(key): | ||
"""Generate 3 clusters of 10 samples each.""" | ||
subkeys = jr.split(key, 3) | ||
means = jnp.array([-2., 0., 2.]) | ||
_2D_normal = lambda key, mean: jr.normal(key, (10, 2))*0.2 + mean | ||
return vmap(_2D_normal)(subkeys, means).reshape(-1, 2) | ||
|
||
key = jr.PRNGKey(5) | ||
key, *data_subkeys = jr.split(key,3) | ||
# Generate 2 samples of the 3-cluster data | ||
x = vmap(_gen_data)(jnp.array(data_subkeys)) | ||
|
||
alg_subkeys = jr.split(key, 2) | ||
_, assignments, *_ = vmap(kmeans_jax, (0, None, 0))(x, 3, alg_subkeys) | ||
# Check that the assignments are the same for both samples (clusters are very distinct) | ||
assert jnp.all(assignments[0] == assignments[1]) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is required so that this chunk is friendly with JAX transformations - the old version has intermediate arrays with variable shape.