Commit 7709cf3
enable torch.compile for mxfp8_cublas recipe

Summary:

This PR enables `MXLinear` with the `mxfp8_cublas` recipe to use torch.compile. The current approach is a short-term workaround until pytorch/pytorch#147873 is done. Since we can't use e8m0 in torchinductor or triton yet, we create a custom op wrapper around `torch._scaled_mm` which takes `uint8` scales and does the cast to e8m0 inside the wrapper, where torchinductor can't see it.

Test Plan:

```
// this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

// we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: f3ebd12edcb746b8abf992d00711ce2bdbb7fcf2
ghstack-comment-id: 2701679811
Pull Request resolved: #1841
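For illustration, here is a minimal usage sketch of the workaround (it mirrors the `test_scaled_mm_wrapper` unit test added below; the wrapper function name `mx_mm_fn` is hypothetical, and the sketch assumes a CUDA device with float8 support and a PyTorch build that exposes `torch.float8_e8m0fnu`). The caller views the e8m0 scales as `uint8` before crossing the op boundary, and the custom op casts them back internally, so torchinductor never sees the e8m0 dtype:

```python
import torch
from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales

M, K, N, BLOCK_SIZE = 128, 256, 512, 32
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)

def mx_mm_fn(a, b, a_scale_u8, b_scale_u8):
    # hypothetical wrapper for illustration: torchinductor only sees uint8
    # scales here; the view back to e8m0 happens inside the custom op,
    # which torch.compile treats as opaque
    return _scaled_mm_with_uint8_scales(
        a, b, a_scale_u8, b_scale_u8, out_dtype=torch.bfloat16
    )

compiled_fn = torch.compile(mx_mm_fn)
out = compiled_fn(
    a,
    b.t(),  # torch._scaled_mm expects the second operand in column-major layout
    a_scale.view(torch.uint8),
    b_scale.view(torch.uint8),
)
```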
1 parent 881631e commit 7709cf3

2 files changed: +91 / -7 lines
Diff for: test/prototype/mx_formats/test_mx_linear.py (+59, -4)

```diff
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 
+from torchao.float8.float8_utils import is_row_major
 from torchao.prototype.mx_formats.config import (
     MXLinearConfig,
     MXLinearRecipeName,
@@ -22,6 +23,7 @@
     swap_linear_with_mx_inference_linear,
     swap_linear_with_mx_linear,
 )
+from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
@@ -169,11 +171,18 @@ def test_activation_checkpointing():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+    is_sm_at_least_100(),
+    reason="triton does not work yet on CUDA capability 10.0",
 )
 @pytest.mark.parametrize(
     "recipe_name",
-    ["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
+    [
+        "mxfp8_emulated",
+        "mxfp4_emulated",
+        "mxfp8_cublas",
+        "mxfp8_cutlass",
+        "mxfp4_cutlass",
+    ],
 )
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
@@ -186,9 +195,9 @@ def test_linear_compile(recipe_name, bias):
     if not is_sm_at_least_89():
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
-    if bias and recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+    if bias and recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
         # TODO(future PR): fix this, things are clearly broken with bias=True
-        pytest.skip("this test is broken for cutlass recipes with bias=True")
+        pytest.skip("this test is broken for non-emulated recipes with bias=True")
 
     M, K, N = 128, 256, 512
     input_shape = (M, K)
@@ -281,6 +290,52 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 13.5
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_scaled_mm_wrapper():
+    # today, e8m0 isn't supported in torchinductor or triton
+    # for now, work around this by creating a wrapper around torch._scaled_mm
+    # which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper
+
+    M, K, N = 128, 256, 512
+    BLOCK_SIZE = 32
+    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
+    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
+
+    a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
+    b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
+
+    out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)
+
+    def wrapped(a, b, a_scale, b_scale, out_dtype):
+        if is_row_major(b.stride()):
+            b = b.t().contiguous().t()
+        res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype)
+        return res
+
+    wrapped = torch.compile(wrapped)
+
+    # correct memory format of `b`
+    out2 = wrapped(
+        a,
+        b.t(),
+        a_scale.view(torch.uint8),
+        b_scale.view(torch.uint8),
+        out_dtype=torch.bfloat16,
+    )
+    torch.testing.assert_close(out, out2, atol=0, rtol=0)
+
+    # incorrect memory format of `b`
+    b_col_major = b.t().contiguous().t()
+    out3 = wrapped(
+        a,
+        b_col_major.t(),
+        a_scale.view(torch.uint8),
+        b_scale.view(torch.uint8),
+        out_dtype=torch.bfloat16,
+    )
+    torch.testing.assert_close(out, out3, atol=0, rtol=0)
+
+
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
```

Diff for: torchao/prototype/mx_formats/mx_ops.py (+32, -3)

```diff
@@ -36,6 +36,35 @@
 MX_OPS_TABLE: Dict[Any, Any] = {}
 
 
+@torch.library.custom_op("mylib::_scaled_mm_with_uint8_scales", mutates_args=())
+def _scaled_mm_with_uint8_scales(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    out_dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Until https://github.com/pytorch/pytorch/issues/147873 is done, we need to
+    work around the lack of support for `torch.float8_e8m0fnu` in
+    torchinductor. We do so by hiding the cast of scales to e8m0 inside a
+    custom op.
+    """
+    # cast back to e8m0 where torchinductor can't see it
+    a_scale = a_scale.view(torch.float8_e8m0fnu)
+    b_scale = b_scale.view(torch.float8_e8m0fnu)
+    res = torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)
+    return res
+
+
+@_scaled_mm_with_uint8_scales.register_fake
+def _(a, b, a_scale, b_scale, out_dtype):
+    m, k = a.shape
+    k2, n = b.shape
+    res = torch.empty(m, n, dtype=out_dtype, device=a.device)
+    return res
+
+
 def implements(aten_ops):
     """Register aten ops to the mx op table"""
 
@@ -83,11 +112,11 @@ def mx_mm(aten_op, args, kwargs=None):
     if a._elem_dtype == torch.float8_e4m3fn:
         assert b._elem_dtype == torch.float8_e4m3fn
         if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
-            res = torch._scaled_mm(
+            res = _scaled_mm_with_uint8_scales(
                 a._data,
                 b._data,
-                a_scale_block.view(torch.float8_e8m0fnu),
-                b_scale_block.view(torch.float8_e8m0fnu),
+                a_scale_block,
+                b_scale_block,
                 out_dtype=torch.bfloat16,
             )
         else:
```
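A note on the `register_fake` registration above: torch.compile cannot trace into the body of the custom op, so it relies on the fake kernel to propagate output shape and dtype while tracing. Below is a small sanity-check sketch of that behavior (illustrative only; it relies on PyTorch's private `FakeTensorMode` API):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales

# under FakeTensorMode only the registered fake kernel runs: no real matmul
# is executed, but shape/dtype metadata is still propagated
with FakeTensorMode():
    a = torch.empty(128, 256, dtype=torch.float8_e4m3fn)
    b = torch.empty(256, 512, dtype=torch.float8_e4m3fn)
    a_scale_u8 = torch.empty(128, 8, dtype=torch.uint8)  # K // 32 scale columns
    b_scale_u8 = torch.empty(512, 8, dtype=torch.uint8)
    out = _scaled_mm_with_uint8_scales(
        a, b, a_scale_u8, b_scale_u8, out_dtype=torch.bfloat16
    )
    assert out.shape == (128, 512) and out.dtype == torch.bfloat16
```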
