[float8] add perf benchmarks for float8 training with rowwise + tensorwise scaling #1793

Merged
merged 6 commits into from
Mar 12, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions torchao/float8/README.md
@@ -202,3 +202,42 @@ python test/float8/test_fsdp2/test_fsdp2.py
# make sure to turn on torch.compile to get the best performance
./benchmarks/float8/bench_linear_float8.py -o ../tmp/test.txt --compile
```

### Training benchmarks

[Torchtitan](https://github.com/pytorch/torchtitan) was used to benchmark float8 training performance for both rowwise
and tensorwise scaling. The training benchmarks were all run using the following setup (a sketch for recording the exact versions follows the list):

- Single-node training on 8xH100 GPUs
  > **Review comment (Contributor):** can we add PyTorch version, torchtitan version, torchao version? Ideally the script could display them.

- Batch size 1
- Sequence length 8192
- Steps 100
- `torch.compile`
- FSDP2
- pytorch version: `2.7.0a0+gitb98af95`
- torchao version: `0.10.0+git890e0ac8`
- torchtitan version: `0.0.2`
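
To capture these versions programmatically (as suggested in the review comment above), here is a minimal sketch, assuming pip-installed `torch`, `torchao`, and `torchtitan` packages that expose standard version metadata:

```python
# Minimal sketch for recording the environment of a benchmark run.
# Assumes torch/torchao expose __version__ and torchtitan is pip-installed
# with package metadata available.
import importlib.metadata

import torch
import torchao

print("pytorch version:   ", torch.__version__)
print("torchao version:   ", torchao.__version__)
print("torchtitan version:", importlib.metadata.version("torchtitan"))
```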


| Model     | Scaling                           | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline |
| --------- | --------------------------------- | ------------------------ | ---------------- | -------------------- | --------------------- |
| Llama3-8b | none (bfloat16)                   | per op SAC               | 47.65            | 6150                 | -                     |
| Llama3-8b | tensorwise with float8 all-gather | per op SAC               | 47.77            | 7689.5               | 25.03%                |
| Llama3-8b | rowwise with bfloat16 all-gather  | per op SAC               | 47.79            | 6768                 | 10.05%                |
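
The speedup column is the float8 run's median tokens/second relative to the bf16 baseline; a quick check against the numbers above:

```python
# "Speedup over baseline" = median tokens/sec relative to the bf16 run.
baseline   = 6150.0   # bf16, per op SAC
tensorwise = 7689.5   # float8 tensorwise with float8 all-gather
rowwise    = 6768.0   # float8 rowwise with bfloat16 all-gather

print(f"tensorwise: {(tensorwise / baseline - 1) * 100:.2f}%")  # 25.03%
print(f"rowwise:    {(rowwise / baseline - 1) * 100:.2f}%")     # 10.05%
```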

**Important notes**:
- E2E speedups increase as M, K, N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
- Rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs performance curve (see the configuration sketch below).
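
As a rough illustration of how the two recipes are selected in torchao (the benchmarks themselves go through torchtitan's float8 integration), here is a minimal sketch assuming the `Float8LinearConfig.from_recipe_name` / `convert_to_float8_training` APIs from the torchao version listed above and an H100-class GPU:

```python
import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Toy stand-in; in the benchmarks above the model is Llama3-8b inside torchtitan.
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()

# "tensorwise" favors raw throughput; "rowwise" scales each row independently,
# handling activation/gradient outliers better at some cost in speed.
config = Float8LinearConfig.from_recipe_name("rowwise")
model = convert_to_float8_training(model, config=config)

model = torch.compile(model)  # torch.compile is needed for competitive float8 performance
```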

**Reproducing training benchmarks**
To reproduce these benchmarks, you can follow these steps:

1. On a machine with 8 H100 GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation),
including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer).
2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
3. From the `torchao/float8/benchmarking/` directory, you can run the following commands to reproduce the benchmarks above:
- bf16 + compile: `TORCHTITAN_ROOT=<path> ./float8_training_benchmark.sh`
- float8 tensorwise with float8 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./float8_training_benchmark.sh`
- float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh`

See the float8 training benchmarking [guide](torchao/float8/benchmarking/README.md) for more details.