
Pmap slower with new CPU runtime #23110

Open · lockwo opened this issue Feb 25, 2025 · 0 comments
Labels: CPU (Related to XLA on CPU)

lockwo commented Feb 25, 2025

I was advised to post this in XLA (rather than JAX) since it deals with the new CPU runtime (original issue: jax-ml/jax#26616).

Description

Something I noticed while using diffrax is that adaptive solvers are much slower when pmap-ing the integration with the new CPU runtime (pmap is used instead of sharding for the reason described in jax-ml/jax#26586). I adapted the code from the aforementioned issue to show that pmap-ing is slower on the new runtime.

```python
import os
import multiprocessing as mp

# Expose one XLA host device per CPU core so pmap has devices to map over.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    mp.cpu_count()
)

import jax
import jax.numpy as jnp


def solve(init, key):
    # Nested data-dependent while loops, mimicking the core subroutine of an
    # adaptive SDE solver: an outer loop over steps, and an inner loop that
    # runs until a tolerance-like condition is met.
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 2

    def inner_loop_body(state):
        t, y, theta = state
        return (t + 0.1, y + 0.1, theta)

    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5000

    def outer_loop_body(state):
        t, y, theta, count = state
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
        return (new_t, new_y, theta, count + 1)

    inner_while_loop = jax.lax.while_loop
    outer_while_loop = jax.lax.while_loop
    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial))
    return final_state[1]


batch_size = 30
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

num_devices = len(jax.devices())

# Split the batch across devices for pmap.
inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

fn = jax.jit(jax.vmap(solve))
pmap_fn = jax.pmap(fn)

# Warm-up call to trigger compilation.
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()

import time

# Time a single post-compilation call.
start_time = time.time()
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()
end_time = time.time()

elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.6f} seconds")
```

Results (single timed call after warm-up):

| JAX version | Elapsed time |
| ----------- | ------------ |
| 0.4.31      | 0.002367 s   |
| 0.4.33      | 0.015590 s   |
| 0.5.0       | 0.018911 s   |

This example is of course trivial, but it represents the core subroutine of adaptive SDE solvers. For now I can work around this by disabling the new CPU thunk runtime, but I'm reporting it in the hope that it can be fixed in the future :).
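For completeness, here is a minimal sketch of the workaround I mean, assuming the `--xla_cpu_use_thunk_runtime` flag mentioned in the JAX 0.4.32 changelog; the flag must be set before jax is first imported:

```python
# Workaround sketch: disable the new CPU thunk runtime before importing jax.
# Assumes the --xla_cpu_use_thunk_runtime flag from the JAX 0.4.32 changelog.
import os
import multiprocessing as mp

os.environ["XLA_FLAGS"] = " ".join([
    f"--xla_force_host_platform_device_count={mp.cpu_count()}",
    "--xla_cpu_use_thunk_runtime=false",  # fall back to the old CPU runtime
])

import jax  # import only after XLA_FLAGS is set
```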

System info (python version, jaxlib version, accelerator, etc.)

Multiple JAX versions (0.4.31, 0.4.33, 0.5.0); CPU only; tested on macOS and Colab.
