[float8] add perf benchmarks for float8 training with rowwise + tensorwise scaling #1793
Conversation
Force-pushed from a1e5143 to 0e78699
torchao/float8/README.md
Outdated
- FSDP2

| Model | Scaling | Activation checkpointing | Average tokens/second | Peak Memory (GB) |
| ------------- | ----------- | ------------------------ | --------------------- | ---------------- |
we should include the baseline (bf16 + compile) here so it's clear what the speedup is from baseline
Done
Force-pushed from 656814d to c227ac7
torchao/float8/README.md
Outdated
| Llama3-8b | tensorwise | per op SAC | 7190 | 47.77 |
| Llama3-8b | rowwise | per op SAC | 6649 | 47.79 |

In these benchmarks, tensorwise scaling achieved ~8% higher tokens/second than rowwise scaling, and ~19.5% higher than the bf16 baseline.
nit: how about
- adding a column with "speedup over baseline" instead of only explaining it in a sentence
- saying something like "rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs performance curve"
- saying that speedups increase as M,K,N increase, and pointing to blogs such as https://pytorch.org/blog/training-using-float8-fsdp2/ where e2e speedups as high as 1.5x are quoted. This is just to clarify that the 1.2x shown here is not the max speedup - it's just the speedup given the benchmark setup.
done, let me know what you think
Looks like the test failure is related to #1799.
Force-pushed from 781a2c4 to 1adf995
Just to confirm: I see that https://github.com/pytorch/ao/blob/main/benchmarks/float8/training/float8_training_benchmark.sh is using recipe-lookup-by-name for tensorwise scaling. Unfortunately, that is correct but not optimal. We also need to enable the following flags to turn on float8 all-gather with FSDP2, and those flags are currently not supported when using titan's recipe-string-to-recipe lookup:
We have pytorch/torchtitan#901 to track making the torchtitan side of this better. Any chance we can update the tensorwise benchmark to include these flags, and also call out in the table that the tensorwise recipe uses float8 all-gather for FSDP?
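For context, here is a minimal sketch of what enabling float8 all-gather for tensorwise scaling looks like when configuring torchao directly instead of going through torchtitan's recipe-name lookup. The `enable_fsdp_float8_all_gather` field and the `precompute_float8_dynamic_scale_for_fsdp` helper reflect my reading of the torchao float8 API and should be verified against the installed release; the corresponding torchtitan flag spellings may differ.

```python
# Sketch: tensorwise float8 training with float8 all-gather under FSDP2,
# configured via torchao directly rather than torchtitan's recipe-name lookup.
# Field/helper names are assumptions based on the torchao float8 API; verify
# against the installed torchao version.
import torch.nn as nn
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))

# Tensorwise (dynamic) scaling is the default; the extra flag tells FSDP2 to
# all-gather weights in float8 instead of bfloat16.
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(model, config=config)

# ... wrap with FSDP2 (fully_shard), then in the training loop:
#     loss.backward(); optimizer.step()
#     precompute_float8_dynamic_scale_for_fsdp(model)  # amortize scale
#     # computation for the next iteration's float8 all-gather
```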
torchao/float8/README.md
Outdated
- `torch.compile`
- FSDP2

| Model | Scaling | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over basline |
nit: baseline (typo)
Sure, but I thought enabling float8 all-gather would make it less of a 1:1 comparison with rowwise? Or is the goal here just to showcase the peak achievable speedup using the optimal configs for each scaling strategy?
Yes, IMO that's what this should do. We should call out any features which are not implemented (or impossible to implement), but in the end we want the best speedup with each recipe, with the appropriate knobs turned on.
Force-pushed from ec5febb to 9e655fc
@vkuzo I managed to find a machine without a perf regression and reran all benchmarks, using the optimal configs for tensorwise scaling this time, so this is ready for another look when you have a sec.
torchao/float8/README.md
Outdated
| Model | Scaling | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline |
| ------------- | --------------------------------- | ------------------------ | ---------------- | -------------------- | --------------------- |
| Llama3-8b | none (bf16) | per op SAC | 47.65 | 6150 | - |
| Llama3-8b | tensorwise with optimal settings | per op SAC | 47.77 | 7689.5 | 25.03% |
how about something like "tensorwise with float8 all-gather" and "rowwise with bfloat16 all-gather"? "optimal settings" should be true for all the rows in this table, it's just that the actual settings change.
I thought about this but the thing is, tensorwise has 3 settings enabled, so listing all 3 would cause the column to become huge and make the table formatting clunky. So instead I listed the 3 settings in a bullet point below. What do you think?
I think it's important to specify that float8 all-gather is used for tensorwise, and not important to list out the three specific settings used to enable that feature.
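For contrast with the tensorwise sketch above, the rowwise rows in the table correspond to the recipe-name lookup, which keeps the FSDP2 weight all-gather in bfloat16. A minimal sketch; the "rowwise" recipe name is an assumption based on the torchao float8 docs and should be checked against the installed version.

```python
# Sketch: rowwise scaling selected by recipe name; the FSDP2 weight
# all-gather stays in bfloat16 in this configuration.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
rowwise_config = Float8LinearConfig.from_recipe_name("rowwise")
convert_to_float8_training(model, config=rowwise_config)
```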
torchao/float8/README.md
Outdated
| Model | Scaling | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline |
| ------------- | --------------------------------- | ------------------------ | ---------------- | -------------------- | --------------------- |
| Llama3-8b | none (bf16) | per op SAC | 47.65 | 6150 | - |
nit: bfloat16, to match how we spell dtypes in the rest of PyTorch?
[Torchtitan](https://github.com/pytorch/torchtitan) was used to benchmark float8 training performance, for both rowwise
and tensorwise scaling. The training benchmarks were all run using:

- Single-node training on 8xH100 GPUs
can we add PyTorch version, torchtitan version, torchao version? Ideally the script could display them.
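One possible way for the benchmark script to report these is a small Python snippet it could invoke (e.g. via `python -c ...`); `importlib.metadata` is used because torchtitan is often run from a source checkout and may lack version metadata. This is only a sketch of the idea, not part of the existing script.

```python
# Sketch: print the library versions used for a benchmark run.
import importlib.metadata

import torch

print(f"torch: {torch.__version__}")
for pkg in ("torchao", "torchtitan"):
    try:
        print(f"{pkg}: {importlib.metadata.version(pkg)}")
    except importlib.metadata.PackageNotFoundError:
        print(f"{pkg}: version metadata unavailable (installed from source?)")
```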
torchao/float8/README.md
Outdated
| Llama3-8b | rowwise | per op SAC | 47.79 | 6768 | 10.05% |

**Important notes**:
- Speedups increase as M,K,N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
nit: "E2e speedups as high as 1.5x..."
looks great, thank you!