Add support for training with NANOO fp8 GEMM on AMD MI300/MI325 GPUs #1262
Description
This PR adds support for training with NANOO fp8 GEMM on AMD MI300/MI325 GPUs.
There are several different families of fp8 formats in use across hardware vendors. Two popular families are the OCP fp8 formats and the NANOO fp8 formats (see the references below), and they behave very similarly; a rough sketch of the difference is shown after this paragraph. The NANOO fp8 GEMM training support in MaxText is built on the corresponding Flax PR: google/flax#3993
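As a rough illustration only (not code from this PR): the two families show up in JAX as different float8 dtypes, and the core of fp8 GEMM training is a scaled quantize/multiply/rescale pattern that XLA can lower to a native fp8 GEMM on supported hardware.

```python
# Illustration only, not code from this PR.
# OCP fp8 dtypes:   float8_e4m3fn / float8_e5m2       (e.g. NVIDIA H100)
# NANOO fp8 dtypes: float8_e4m3fnuz / float8_e5m2fnuz (e.g. AMD MI300/MI325)
# The *fnuz variants have no infinities and a single NaN encoding, so their
# dynamic range differs slightly from the OCP variants.
import jax.numpy as jnp

print(float(jnp.finfo(jnp.float8_e4m3fn).max))    # 448.0 (OCP e4m3)
print(float(jnp.finfo(jnp.float8_e4m3fnuz).max))  # 240.0 (NANOO e4m3)

def scaled_fp8_matmul(a, b, fp8_dtype=jnp.float8_e4m3fnuz):
    """Per-tensor scaled fp8 matmul: scale into the fp8 range, cast,
    multiply, then undo the scales. On MI300/MI325 this kind of
    quantize/dequantize graph is what XLA can rewrite into a native
    fp8 GEMM."""
    fp8_max = float(jnp.finfo(fp8_dtype).max)
    a_scale = jnp.maximum(jnp.max(jnp.abs(a)) / fp8_max, 1e-12)
    b_scale = jnp.maximum(jnp.max(jnp.abs(b)) / fp8_max, 1e-12)
    a_q = (a / a_scale).astype(fp8_dtype)
    b_q = (b / b_scale).astype(fp8_dtype)
    out = jnp.matmul(a_q.astype(jnp.bfloat16), b_q.astype(jnp.bfloat16))
    return out.astype(jnp.float32) * a_scale * b_scale
```

The actual fp8 training path uses delayed scaling with an amax history rather than recomputing per-tensor scales on every call; the sketch above only shows the shape of the computation.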
References:
OCP fp8 paper: https://arxiv.org/abs/2209.05433
NANOO fp8 paper: https://arxiv.org/abs/2206.02915
JAX PR: jax-ml/jax#21376
XLA PR: openxla/xla#9531
Flax PR: google/flax#3993
Tests
I ran llama2-7b with NANOO fp8 from this PR and verified that it is functional and that the loss goes down quickly on the synthetic dataset; a rough sketch of this kind of functional check is shown below. I was not able to run the full unit tests locally because they require Google Cloud API credentials.
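For reference, here is a minimal Flax-level sketch of the kind of functional check described above. This is not the PR's actual test: it uses the existing OCP op `nn.Fp8DotGeneralOp`, and the NANOO variant added in google/flax#3993 (whose exact class name is not asserted here) would be substituted in its place.

```python
# Minimal sketch, not the PR's test: route a Dense layer's GEMM through
# Flax's fp8 dot_general op and check that a forward and backward pass run.
import jax
import jax.numpy as jnp
from flax import linen as nn

layer = nn.Dense(features=64, dot_general_cls=nn.Fp8DotGeneralOp)
x = jax.random.uniform(jax.random.PRNGKey(0), (16, 32))
variables = layer.init(jax.random.PRNGKey(1), x)

def loss_fn(v):
    # Toy loss; in the real run this is the llama2-7b training loss.
    return jnp.mean(layer.apply(v, x) ** 2)

loss, grads = jax.value_and_grad(loss_fn)(variables)
print(loss)
```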
Checklist
Before submitting this PR, please make sure (put X in square brackets):