Summary:
This PR enables `MXLinear` with `mxfp8_cublas` recipe to use
torch.compile.
The current approach is a short-term workaround until
pytorch/pytorch#147873 is done. Since we can't
use e8m0 in torchinductor or triton yet, we create a custom op wrapper
around `torch._scaled_mm` which takes `uint8` scales and does the cast to
e8m0 inside the wrapper, where torchinductor can't see it.
Test Plan:
```
# this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas
# we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: f3ebd12edcb746b8abf992d00711ce2bdbb7fcf2
ghstack-comment-id: 2701679811
Pull Request resolved: #1841