[TPU] Bug: Reverse is orders of magnitude slower on TPU #23191

Open · bjenik opened this issue Feb 27, 2025 · 5 comments
Labels: bug

bjenik commented Feb 27, 2025

The following sample gets around 1400 it/s on an H100 but only 11 it/s on a v6e, which suggests an issue with the TPU implementation of reverse. The example is already isolated; in practice the reverse gets generated when calling irfft.

import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=./hlo"  # dump HLO for inspection
from tqdm import tqdm
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(mesh_utils.create_device_mesh((num_devices,)), ("batch",))

@jax.jit
def rp1(data):
    # Reverse the last axis via a strided slice; the +1 gives the compiler
    # an elementwise op to fuse with.
    data = jax.lax.with_sharding_constraint(data, NamedSharding(mesh, PartitionSpec("batch")))
    data = data[:, :, :, :, :, ::-1] + 1
    data = jax.lax.with_sharding_constraint(data, NamedSharding(mesh, PartitionSpec("batch")))
    return data

@jax.jit
def make_data():
    # ~520 MB of float32 per device, sharded along the batch axis.
    data = jnp.ones((num_devices * 32, 64, 16, 16, 8, 31))
    data = jax.lax.with_sharding_constraint(data, NamedSharding(mesh, PartitionSpec("batch")))
    return data

data = make_data()
with jax.profiler.trace("./tensorboard"):
    for i in tqdm(range(1000)):
        data = rp1(data)

rdyro commented Feb 27, 2025

I experimented with two alternatives:

  • data = jnp.flip(data, axis=-1) + 1: ~50 ms
    [profiler trace screenshot]

  • a gather, data = data[..., jnp.arange(data.shape[-1])[::-1]] + 1: ~6 ms
    [profiler trace screenshot]

The gather generates a while loop, which seems to be much more efficient; a drop-in version for the repro is sketched below.
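
As a drop-in for the repro above (rp1_gather is just my name for it; it reuses the mesh and imports from the original snippet):

@jax.jit
def rp1_gather(data):
    data = jax.lax.with_sharding_constraint(data, NamedSharding(mesh, PartitionSpec("batch")))
    # Reverse the last axis with an explicit index gather instead of a
    # strided slice; the small index array is a compile-time constant.
    rev_idx = jnp.arange(data.shape[-1])[::-1]
    data = data[..., rev_idx] + 1
    data = jax.lax.with_sharding_constraint(data, NamedSharding(mesh, PartitionSpec("batch")))
    return data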

One speculation: maybe the reverse operation is optimized for VMEM, but your array is large (~500 MB per device), so it has to work out of HBM directly and ends up being inefficient. But I'm curious why the compiler replaces the gather with a more optimized version and not the reverse!

bjenik (author) commented Feb 27, 2025

I'd expect this to run at HBM speed (modulo maybe missing out on coalescing), so in the good case the 1400 it/s on the H100 seems about right. Given that a v6e has about half the HBM bandwidth, I'd still expect half that performance: each iteration shouldn't take much more than a millisecond to hit a target of around 700 it/s. Do you have any insight into what libtpu actually does internally for a reverse, in terms of algorithm and memory-access pattern?
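
For reference, the back-of-the-envelope arithmetic behind that estimate (assuming float32, one read plus one write per iteration, and roughly 1.6 TB/s peak HBM bandwidth on a v6e, which is my assumption here):

# Per-device shard of make_data(): (32, 64, 16, 16, 8, 31) in float32.
shard_bytes = 32 * 64 * 16 * 16 * 8 * 31 * 4   # ~520 MB
traffic_per_iter = 2 * shard_bytes             # one read + one write: ~1.04 GB
time_at_peak = traffic_per_iter / 1.6e12       # ~0.65 ms per iteration at peak BW

so even with some inefficiency, around a millisecond per iteration looks attainable.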

Also curious: how would I best integrate this (or any other) manual fix into the overall XLA lowering pipeline? The issue is that in practice I'm not calling reverse myself; XLA lowers irfft into it. I originally considered doing a matmul with a flipped identity matrix instead of the reverse (not the smartest approach, but probably still better than whatever is actually happening in the reverse), but I'd face the same integration problem there.
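
For concreteness, the flipped-identity idea would be something like this (naming mine; the small constant flip is folded at compile time, so no runtime reverse is emitted on the large array):

def reverse_last_via_matmul(x):
    # out[..., j] = x[..., n-1-j] via an anti-diagonal (flipped identity) matrix.
    n = x.shape[-1]
    flip = jnp.eye(n, dtype=x.dtype)[:, ::-1]  # tiny constant, folded away
    return x @ flip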

rdyro commented Feb 27, 2025

An important detail: so far I tested on a v5e, where the HBM bandwidth suggests ~2.5 ms, so I think the optimized gather is quite close to that (although it would ideally be fused with the scalar add in this repro).

Yes, that lowering is problematic... I'll work on creating an internal issue. In the meantime, what's the performance like when you implement irfft yourself with the gather trick and ifft?
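
Something along these lines might work (an untested sketch, assuming the reverse is along the last axis and the output length n = 2 * (x.shape[-1] - 1) is even):

def irfft_via_gather(x):
    # Rebuild the full Hermitian spectrum from the rfft half-spectrum; the
    # mirrored tail is conj(x[..., 1:-1]) in reverse order, expressed as a
    # gather so no reverse op is emitted.
    idx = jnp.arange(x.shape[-1] - 2, 0, -1)   # n//2 - 1, ..., 1
    tail = jnp.conj(x[..., idx])
    full = jnp.concatenate([x, tail], axis=-1)
    return jnp.fft.ifft(full, axis=-1).real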

bjenik (author) commented Mar 2, 2025

Re-implementing irfft to avoid the reverse works as a band-aid in the meantime; thanks. It's probably still better to have this fixed, as others may hit it as well. I'm also personally curious what is actually happening on the TPU to cause this level of overhead.

rdyro commented Mar 2, 2025

Re-implementing irfft to avoid the reverse works as a band-aid in the meantime; thanks.

It's good that an immediate workaround is possible. If you don't mind, do you have a short code snippet for this, just as a public reference?

It's probably still better to have this fixed, as others may hit it as well. I'm also personally curious what is actually happening on the TPU to cause this level of overhead.

Without a doubt! I'm following up on this internally, including seeing whether I can get an explanation for it.

aniruthraj added the bug label on Mar 12, 2025