
Added support for training with NANOO fp8 GEMM on AMD MI300/MI325 GPUs #1262

Open · wants to merge 2 commits into main

Conversation

@wenchenvincent commented on Feb 11, 2025

Description

This PR adds support for training with NANOO fp8 GEMM on AMD MI300/MI325 GPUs.

Different hardware vendors natively support different fp8 formats. Two popular variants are:

  • OCP fp8, which is used natively on NVIDIA H100.
  • NANOO fp8, which is used natively on AMD MI300/MI325.

These two fp8 variants behave very similarly; see the dtype-level comparison sketched below. The support for training with NANOO fp8 GEMM in MaxText is based on this PR in Flax: google/flax#3993
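The practical difference between the two variants shows up in their dtype properties. As a minimal sketch (mine, not code from this PR), the two families can be compared directly in JAX, which exposes both:

```python
# Minimal sketch comparing OCP vs. NANOO fp8 dtypes in JAX (not from this PR).
# The "fnuz" suffix marks the NANOO variants: finite-only, no negative zero.
import jax.numpy as jnp

for name, dtype in [
    ("OCP   e4m3fn  ", jnp.float8_e4m3fn),
    ("NANOO e4m3fnuz", jnp.float8_e4m3fnuz),
    ("OCP   e5m2    ", jnp.float8_e5m2),
    ("NANOO e5m2fnuz", jnp.float8_e5m2fnuz),
]:
    info = jnp.finfo(dtype)
    print(f"{name}: max={float(info.max)}, smallest normal={float(info.tiny)}")
```

For example, e4m3fn tops out at 448 while e4m3fnuz tops out at 240, which is why the two families need different scaling bounds even though the overall training recipe is the same.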

References:

  • OCP fp8 paper: https://arxiv.org/abs/2209.05433
  • NANOO fp8 paper: https://arxiv.org/abs/2206.02915
  • JAX PR: jax-ml/jax#21376
  • XLA PR: openxla/xla#9531
  • Flax PR: google/flax#3993

Tests

I ran llama2-7b with NANOO fp8 from this PR and verified that it is functional and that the loss decreases quickly on the synthetic dataset. I was not able to run the full unit test suite locally because it requires Google Cloud API credentials.
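For reference, a hypothetical version of that functional run is sketched below; the quantization=nanoo_fp8 override name and the exact flags are assumptions based on this PR and MaxText's usual key=value config style, not a command copied from it:

```sh
# Hypothetical invocation (flag values assumed, not taken from this PR):
# train llama2-7b on the synthetic dataset with NANOO fp8 GEMMs enabled.
python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=llama2_7b_nanoo_fp8 \
    model_name=llama2-7b \
    quantization=nanoo_fp8 \
    dataset_type=synthetic \
    steps=100
```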

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the documentation if needed.

@shralex requested a review from @yangyuwei on Feb 20, 2025
@anfals (Collaborator) commented:

LGTM. This also uses FP8Ops, just with different types; it should work the same as everything else we have with _overwrite_with_gradient, so no additional lift is needed.
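To make the reviewer's point concrete, here is a minimal sketch (assumptions mine; the NANOOFp8DotGeneralOp name is taken from google/flax#3993) of how an fp8 dot-general op plugs into a Flax layer. The scale and amax-history variables these ops create live in the _overwrite_with_gradient collection mentioned above, so the existing handling applies unchanged regardless of which fp8 dtype family is used:

```python
# Minimal sketch (not code from this PR). nn.Dense accepts a dot_general_cls,
# and Flax's fp8 ops slot in there; the NANOO variant is assumed to be
# fp8_ops.NANOOFp8DotGeneralOp per google/flax#3993.
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import fp8_ops

class Fp8Dense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # Swap Fp8DotGeneralOp (OCP fp8) for the NANOO variant to target the
        # AMD-native formats; both keep their scales and amax histories in
        # the "_overwrite_with_gradient" collection.
        return nn.Dense(self.features, dot_general_cls=fp8_ops.Fp8DotGeneralOp)(x)

variables = Fp8Dense(features=16).init(jax.random.key(0), jnp.ones((4, 8)))
```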
