[float8] add perf benchmarks for float8 training with rowwise + tensorwise scaling #1793

Merged: 6 commits into pytorch:main from the fp8readme branch, Mar 12, 2025

Conversation

danielvegamyhre (Contributor)

Summary

  • Add float8 training performance benchmarks for rowwise + tensorwise scaling.
  • Add repro steps for these benchmarks.
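For context, here is a minimal sketch of how the two scaling recipes compared in these benchmarks can be selected with torchao. This assumes torchao's `Float8LinearConfig.from_recipe_name` / `convert_to_float8_training` API at the time of this PR; the toy model and shapes are placeholders, and the actual benchmarks run Llama3-8b through torchtitan.

```python
# Minimal sketch (assumption: torchao's float8 recipe API at the time of this PR)
# of selecting the two scaling recipes compared in these benchmarks.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Placeholder model; the benchmarks use Llama3-8b via torchtitan.
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))

# Tensorwise scaling: one scale per tensor (the default recipe).
tensorwise_cfg = Float8LinearConfig.from_recipe_name("tensorwise")

# Rowwise scaling: one scale per row, better at handling outliers.
rowwise_cfg = Float8LinearConfig.from_recipe_name("rowwise")

# Swap nn.Linear modules for their float8 training counterparts.
convert_to_float8_training(model, config=rowwise_cfg)
```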

@danielvegamyhre added the `topic: documentation` label on Feb 28, 2025
pytorch-bot bot commented Feb 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1793

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 4 Pending

As of commit 4610085 with merge base 711fa08:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the `CLA Signed` label on Feb 28, 2025
@danielvegamyhre changed the title from "Add perf benchmarks to float8 training with rowwise + tensorwise scaling" to "Add perf benchmarks for float8 training with rowwise + tensorwise scaling" on Feb 28, 2025
@danielvegamyhre force-pushed the fp8readme branch 2 times, most recently from a1e5143 to 0e78699, on February 28, 2025 07:16
- FSDP2

| Model | Scaling | Activation checkpointing | Average tokens/second | Peak Memory (GB) |
| ------------- | ----------- | ------------------------ | ------------------------- | ---------------- |
Contributor

we should include the baseline (bf16 + compile) here so it's clear what the speedup is from baseline

Contributor Author

Done

@danielvegamyhre changed the title from "Add perf benchmarks for float8 training with rowwise + tensorwise scaling" to "[float8] add perf benchmarks for float8 training with rowwise + tensorwise scaling" on Feb 28, 2025
@danielvegamyhre force-pushed the fp8readme branch 2 times, most recently from 656814d to c227ac7, on February 28, 2025 23:42
| Llama3-8b | tensorwise | per op SAC | 7190 | 47.77 |
| Llama3-8b | rowwise | per op SAC | 6649 | 47.79 |

In these benchmarks tensorwise scaling achieved ~8% higher tokens/second over rowwise scaling, and ~19.5% higher than the bf16 baseline.
Contributor

nit: how about

  • add column with "speedup over baseline" instead of only explaining it in a sentence
  • saying something like "rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs performance curve"
  • saying that speedups increase as M,K,N increase, and pointing to blogs such as https://pytorch.org/blog/training-using-float8-fsdp2/ where e2e speedups as high as 1.5x are quoted. This is just to clarify that the 1.2x shown here is not the max speedup - it's just the speedup given the benchmark setup.

Contributor Author

done, let me know what you think
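(Illustrative only, not code from this PR: the requested "speedup over baseline" column can be derived directly from the median tokens/second figures. The numbers below are taken from the updated benchmark table later in this thread.)

```python
# Derive the "speedup over baseline" column from median tokens/second.
# Figures copied from the updated benchmark table later in this thread.
baseline_tps = 6150  # Llama3-8b, bf16 + torch.compile + FSDP2, per op SAC

results = {
    "tensorwise (float8 all-gather)": 7689.5,
    "rowwise": 6768,
}

for recipe, tps in results.items():
    speedup_pct = (tps / baseline_tps - 1) * 100
    print(f"{recipe}: +{speedup_pct:.2f}% over baseline")
# tensorwise -> ~25.03%, rowwise -> ~10.05%, matching the table.
```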

@danielvegamyhre (Contributor Author)

Looks like test failure is related to #1799

vkuzo (Contributor) commented Mar 4, 2025

Just to confirm, for tensorwise scaling, I see that https://github.com/pytorch/ao/blob/main/benchmarks/float8/training/float8_training_benchmark.sh is using recipe lookup by name.

Unfortunately, for tensorwise scaling, this is correct but not optimal. We also need to enable the following flags to enable float8 all-gather with FSDP2, and those flags are currently not supported when using titan's recipe-string-to-recipe lookup:

# source: https://github.com/pytorch/torchtitan/blob/main/docs/float8.md
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --training.compile

We have pytorch/torchtitan#901 to track making the torchtitan side of this better. Any chance we can update the tensorwise benchmark to include these flags, and also call out in the table that the tensorwise recipe has float8 all-gather for FSDP?
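For readers mapping these torchtitan flags back to torchao, here is a rough sketch of what they correspond to in the float8 API. The exact config fields and helper names are my assumption based on torchao at the time of this PR, not part of the PR itself.

```python
# Rough sketch (assumption: torchao's float8 API at the time of this PR) of what the
# three torchtitan flags above enable for tensorwise scaling.
import torch.nn as nn
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# --float8.enable_fsdp_float8_all_gather and --float8.force_recompute_fp8_weight_in_bwd
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    force_recompute_fp8_weight_in_bwd=True,
)

# Stand-in for the Llama3-8b model used in the benchmarks.
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
convert_to_float8_training(model, config=config)

# --float8.precompute_float8_dynamic_scale_for_fsdp corresponds to calling this helper
# once per training step, after optimizer.step(), when the model is wrapped with FSDP2:
#   optimizer.step()
#   precompute_float8_dynamic_scale_for_fsdp(model)
```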

- `torch.compile`
- FSDP2

| Model | Scaling | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over basline
Contributor

nit: baseline (typo)

@danielvegamyhre (Contributor Author)

> Any chance we can update the tensorwise benchmark to include these flags, and also call out in the table that the tensorwise recipe has float8 all-gather for FSDP?

Sure, but I thought enabling float8 all gather would make it less of a 1:1 comparison with rowwise? Or is the goal here just to showcase the peak achievable speedup using all optimal configs for each scaling strategy?

vkuzo (Contributor) commented Mar 4, 2025

> Or is the goal here just to showcase the peak achievable speedup using all optimal configs for each scaling strategy?

Yes, IMO that's what this should do. We should call out any features which are not implemented (or impossible to implement), but at the end we want the best speedup with each recipe, with the appropriate knobs turned on.

@danielvegamyhre (Contributor Author)

> Or is the goal here just to showcase the peak achievable speedup using all optimal configs for each scaling strategy?
>
> Yes, IMO that's what this should do. We should call out any features which are not implemented (or impossible to implement), but at the end we want the best speedup with each recipe, with the appropriate knobs turned on.

@vkuzo I managed to find a machine without perf regression and reran all benchmarks, using the optimal configs for tensorwise scaling this time, so this is ready for another look when you have a sec

| Model | Scaling | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline
| ------------- | --------------------------------- | ------------------------ | ------------------| -------------------- | ---------------------
| Llama3-8b | none (bf16) | per op SAC | 47.65 | 6150 | -
| Llama3-8b | tensorwise with optimal settings | per op SAC | 47.77 | 7689.5 | 25.03%
vkuzo (Contributor) commented Mar 12, 2025

how about something like "tensorwise with float8 all-gather" and "rowwise with bfloat16 all-gather"? "optimal settings" should be true for all the rows in this table, it's just that the actual settings change.

danielvegamyhre (Contributor Author) commented Mar 12, 2025

I thought about this but the thing is, tensorwise has 3 settings enabled, so listing all 3 would cause the column to become huge and make the table formatting clunky. So instead I listed the 3 settings in a bullet point below. What do you think?

Contributor

I think it's important to specify that float8 all-gather is used for tensorwise, and not important to list out the three specific settings used to enable that feature.


| Model | Scaling | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline
| ------------- | --------------------------------- | ------------------------ | ------------------| -------------------- | ---------------------
| Llama3-8b | none (bf16) | per op SAC | 47.65 | 6150 | -
Contributor

nit: bfloat16, to match how we spell dtypes in the rest of PyTorch?

[Torchtitan](https://github.com/pytorch/torchtitan) was used to benchmark float8 training performance, for both rowwise
and tensorwise scaling. The training benchmarks were all run using:

- Single-node training on 8xH100 GPUs
Contributor

can we add PyTorch version, torchtitan version, torchao version? Ideally the script could display them.
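One possible way the script could report those versions (a sketch, not code from this PR; whether torchtitan exposes a pip-installable version, and whether torchao exposes `__version__`, are assumptions):

```python
# Hypothetical addition to the benchmark setup: print the versions the reviewer asks
# for so that results are reproducible.
import importlib.metadata

import torch
import torchao

print(f"torch:      {torch.__version__}")
print(f"torchao:    {torchao.__version__}")
# Assumes torchtitan is pip-installed; otherwise read the version from the git checkout.
print(f"torchtitan: {importlib.metadata.version('torchtitan')}")
```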

| Llama3-8b | rowwise | per op SAC | 47.79 | 6768 | 10.05%

**Important notes**:
- Speedups increase as M,K,N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
Contributor

nit: "E2e speedups as high as 1.5x..."

vkuzo (Contributor) left a comment

looks great, thank you!

@danielvegamyhre merged commit 8c81863 into pytorch:main on Mar 12, 2025
17 of 18 checks passed