From 0641e747b9acbda9b0a407566667fe9d61c9be97 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Thu, 27 Feb 2025 22:03:11 -0800
Subject: [PATCH 1/6] add perf benchmarks to float8 training with rowwise scaling

---
 torchao/float8/README.md | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index 65105d1f89..53eff02346 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -202,3 +202,38 @@ python test/float8/test_fsdp2/test_fsdp2.py
# make sure to turn on torch.compile to get the best performance
./benchmarks/float8/bench_linear_float8.py -o ../tmp/test.txt --compile
```

### Training benchmarks

[Torchtitan](https://github.com/pytorch/torchtitan) was used to benchmark float8 training performance, for both rowwise
and tensorwise scaling. The training benchmarks were all run using:

- Single-node training on 8xH100 GPUs
- Batch size 1
- Sequence length 8192
- Steps 100
- `torch.compile`
- FSDP2

| Model     | Scaling     | Activation checkpointing | Median tokens/second | Peak Memory (GB) |
| --------- | ----------- | ------------------------ | -------------------- | ---------------- |
| Llama3-8b | none (bf16) | per op SAC               | 6019                 | 47.65            |
| Llama3-8b | tensorwise  | per op SAC               | 7190                 | 47.77            |
| Llama3-8b | rowwise     | per op SAC               | 6649                 | 47.79            |

In these benchmarks tensorwise scaling achieved ~8% higher tokens/second than rowwise scaling, and ~19.5% higher than the bf16 baseline.
However, it is important to note that rowwise scaling has been shown to yield improvements in training loss/accuracy due to reduced quantization error, particularly
when training large models for many steps.

**Reproducing training benchmarks**
To reproduce these benchmarks, you can follow these steps:

1. On a machine with 8 H100 GPUs, clone torchtitan and follow the local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation),
including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer).
2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
3. From the `torchao/float8/benchmarking/` directory, you can run the following commands to reproduce the benchmarks above:
    - bf16 + compile: `TORCHTITAN_ROOT= ./float8_training_benchmark.sh`
    - float8 tensorwise: `TORCHTITAN_ROOT= FLOAT8_RECIPE="tensorwise" ./float8_training_benchmark.sh`
    - float8 rowwise: `TORCHTITAN_ROOT= FLOAT8_RECIPE="rowwise" ./float8_training_benchmark.sh`

See the float8 training benchmarking [guide](.torchao/float8/benchmarking/README.md) for more details.
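For context on the two recipes compared in the patch above (this sketch is illustrative and not part of the patch itself), float8 training in torchao works by swapping `torch.nn.Linear` modules for float8 linears. The snippet below assumes torchao's documented `convert_to_float8_training` and `Float8LinearConfig.from_recipe_name` APIs; exact signatures may differ between torchao versions.

```python
# Minimal sketch: selecting the tensorwise (default) vs rowwise float8 scaling recipes.
# Actually executing float8 matmuls requires a recent CUDA GPU (e.g. H100); API names
# are taken from torchao's float8 docs and may differ between torchao versions.
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

def make_model() -> nn.Module:
    # Tiny stand-in for the Llama3-8b model used in the benchmarks above.
    return nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).to(torch.bfloat16)

# Tensorwise scaling: one scale per tensor (the default recipe).
tensorwise_model = make_model()
convert_to_float8_training(tensorwise_model)

# Rowwise scaling: one scale per row, trading some throughput for lower quantization error.
rowwise_model = make_model()
convert_to_float8_training(rowwise_model, config=Float8LinearConfig.from_recipe_name("rowwise"))

# The benchmarks above additionally apply FSDP2 sharding and torch.compile.
rowwise_model = torch.compile(rowwise_model)
```

The per-row scales are what make rowwise scaling more robust to activation outliers, which is the accuracy-vs-performance trade-off described in the patch.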
From c4e7e3b665c32971192a252c4f135fd4b3211b6b Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 4 Mar 2025 09:27:03 -0800
Subject: [PATCH 2/6] address comment

---
 torchao/float8/README.md | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index 53eff02346..0e8270ea50 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -215,15 +215,15 @@ and tensorwise scaling. The training benchmarks were all run using:
- `torch.compile`
- FSDP2

| Model     | Scaling     | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over basline
| --------- | ----------- | ------------------------ | ---------------- | -------------------- | ---------------------
| Llama3-8b | none (bf16) | per op SAC               | 47.65            | 6019                 | -
| Llama3-8b | tensorwise  | per op SAC               | 47.77            | 7190                 | 19.45%
| Llama3-8b | rowwise     | per op SAC               | 47.79            | 6649                 | 10.47%

**Important notes**:
- Speedups increase as M,K,N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
- Rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs performance curve.

**Reproducing training benchmarks**
To reproduce these benchmarks, you can follow these steps:

From 71ce7e62d6aaafbb240dc481886405e434d4d443 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 4 Mar 2025 12:46:06 -0700
Subject: [PATCH 3/6] fix typo

---
 torchao/float8/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index 0e8270ea50..f22e6a8a7f 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -215,7 +215,7 @@ and tensorwise scaling. The training benchmarks were all run using:
- `torch.compile`
- FSDP2

| Model     | Scaling     | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline
| --------- | ----------- | ------------------------ | ---------------- | -------------------- | ---------------------
| Llama3-8b | none (bf16) | per op SAC               | 47.65            | 6019                 | -
| Llama3-8b | tensorwise  | per op SAC               | 47.77            | 7190                 | 19.45%
| Llama3-8b | rowwise     | per op SAC               | 47.79            | 6649                 | 10.47%
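The "Speedup over baseline" column introduced in these patches is simply the relative change in median tokens/second versus the bf16 run. A quick check with the numbers from the table above (the helper function is hypothetical, not part of the benchmark scripts):

```python
# How the "Speedup over baseline" percentages follow from median tokens/second
# (numbers taken from the table in the patch above; small differences are rounding).
baseline_tps = 6019.0    # bf16
tensorwise_tps = 7190.0
rowwise_tps = 6649.0

def speedup_over_baseline(tps: float, baseline: float = baseline_tps) -> float:
    """Percentage improvement in tokens/second relative to the bf16 baseline."""
    return (tps / baseline - 1.0) * 100.0

print(f"tensorwise: {speedup_over_baseline(tensorwise_tps):.2f}%")  # ~19.5%
print(f"rowwise:    {speedup_over_baseline(rowwise_tps):.2f}%")     # ~10.5%
```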
From 9e655fc7d6891d86bbbb725f9c77b4d1dca970a0 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 11 Mar 2025 18:48:58 -0700
Subject: [PATCH 4/6] updated perf results

---
 torchao/float8/README.md | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index f22e6a8a7f..b359175c7f 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -215,15 +215,16 @@ and tensorwise scaling. The training benchmarks were all run using:
- `torch.compile`
- FSDP2

| Model     | Scaling                          | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline
| --------- | -------------------------------- | ------------------------ | ---------------- | -------------------- | ---------------------
| Llama3-8b | none (bf16)                      | per op SAC               | 47.65            | 6150                 | -
| Llama3-8b | tensorwise with optimal settings | per op SAC               | 47.77            | 7689.5               | 25.03%
| Llama3-8b | rowwise                          | per op SAC               | 47.79            | 6768                 | 10.05%

**Important notes**:
- Speedups increase as M,K,N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
- Rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs performance curve.
- Tensorwise scaling benchmarks were run with optimal settings, namely: `enable_fsdp_float8_all_gather`, `precompute_float8_dynamic_scale_for_fsdp`, `force_recompute_fp8_weight_in_bwd`.

**Reproducing training benchmarks**
To reproduce these benchmarks, you can follow these steps:
@@ -233,7 +234,7 @@ including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=re
2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
3. From the `torchao/float8/benchmarking/` directory, you can run the following commands to reproduce the benchmarks above:
    - bf16 + compile: `TORCHTITAN_ROOT= ./float8_training_benchmark.sh`
    - float8 tensorwise: `TORCHTITAN_ROOT= FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./float8_training_benchmark.sh`
    - float8 rowwise: `TORCHTITAN_ROOT= FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh`

See the float8 training benchmarking [guide](.torchao/float8/benchmarking/README.md) for more details.
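The three "optimal settings" named in the note above map onto torchao float8 training options. The sketch below shows roughly where each one is applied; the `Float8LinearConfig` fields and the `precompute_float8_dynamic_scale_for_fsdp` helper are assumed from torchao's float8 documentation, and in the benchmarks themselves torchtitan enables them through its own config flags rather than code like this.

```python
import torch
import torch.nn as nn
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# Settings 1 and 2: all-gather FSDP2 weight shards already cast to float8, and recompute
# the fp8-cast weight in the backward pass instead of saving it for backward.
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    force_recompute_fp8_weight_in_bwd=True,
)

model = nn.Sequential(nn.Linear(8192, 8192), nn.ReLU(), nn.Linear(8192, 8192))
convert_to_float8_training(model, config=config)
# ... wrap with FSDP2 (fully_shard), apply per-op SAC and torch.compile, as torchtitan does ...

def train_step(model: nn.Module, optimizer: torch.optim.Optimizer, batch: torch.Tensor) -> None:
    loss = model(batch).sum()  # placeholder loss for illustration
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Setting 3: after the optimizer step, precompute the float8 scales that the
    # float8 all-gather will use on the next iteration.
    precompute_float8_dynamic_scale_for_fsdp(model)
```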
From 73d93910ff4fa53ec08b8fa6cc044cc2a85eb944 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 12 Mar 2025 08:21:54 -0700
Subject: [PATCH 5/6] add pkg versions

---
 torchao/float8/README.md | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index b359175c7f..5c421cd548 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -214,17 +214,20 @@ and tensorwise scaling. The training benchmarks were all run using:
- Steps 100
- `torch.compile`
- FSDP2
- pytorch version: `2.7.0a0+gitb98af95`
- torchao version: `0.10.0+git890e0ac8`
- torchtitan version: `0.0.2`

| Model     | Scaling                           | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline
| --------- | --------------------------------- | ------------------------ | ---------------- | -------------------- | ---------------------
| Llama3-8b | none (bfloat16)                   | per op SAC               | 47.65            | 6150                 | -
| Llama3-8b | tensorwise with float8 all-gather | per op SAC               | 47.77            | 7689.5               | 25.03%
| Llama3-8b | rowwise with bfloat16 all-gather  | per op SAC               | 47.79            | 6768                 | 10.05%

**Important notes**:
- E2E speedups increase as M,K,N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
- Rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs performance curve.

**Reproducing training benchmarks**
To reproduce these benchmarks, you can follow these steps:

From 4610085634c1a199e0e5bce3e11112cff5391c45 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 12 Mar 2025 08:23:06 -0700
Subject: [PATCH 6/6] more details in repro steps

---
 torchao/float8/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index 5c421cd548..9e37bb001d 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -237,7 +237,7 @@ including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=re
2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
3. From the `torchao/float8/benchmarking/` directory, you can run the following commands to reproduce the benchmarks above:
    - bf16 + compile: `TORCHTITAN_ROOT= ./float8_training_benchmark.sh`
    - float8 tensorwise with float8 all-gather + compile: `TORCHTITAN_ROOT= FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./float8_training_benchmark.sh`
    - float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT= FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh`

See the float8 training benchmarking [guide](.torchao/float8/benchmarking/README.md) for more details.
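Because the results above are tied to the specific package versions listed in the "add pkg versions" patch, it can help to confirm a local environment matches them before attempting to reproduce. A small check, with the distribution names assumed (torchtitan is often used from a source checkout rather than an installed wheel):

```python
# Print locally installed versions to compare against the ones the benchmarks were run with:
#   pytorch 2.7.0a0+gitb98af95, torchao 0.10.0+git890e0ac8, torchtitan 0.0.2
from importlib.metadata import PackageNotFoundError, version

for pkg in ("torch", "torchao", "torchtitan"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed as a package (torchtitan may be a source checkout)")
```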