Commit 9e655fc

updated perf results
1 parent: 71ce7e6

1 file changed: +8 −7 lines
torchao/float8/README.md

@@ -215,15 +215,16 @@ and tensorwise scaling. The training benchmarks were all run using:
 - `torch.compile`
 - FSDP2
 
-| Model     | Scaling     | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline
-| --------- | ----------- | ------------------------ | ---------------- | -------------------- | ---------------------
-| Llama3-8b | none (bf16) | per op SAC               | 47.65            | 6019                 | -
-| Llama3-8b | tensorwise  | per op SAC               | 47.77            | 7190                 | 19.45%
-| Llama3-8b | rowwise     | per op SAC               | 47.79            | 6649                 | 10.47%
+| Model     | Scaling                          | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline
+| --------- | -------------------------------- | ------------------------ | ---------------- | -------------------- | ---------------------
+| Llama3-8b | none (bf16)                      | per op SAC               | 47.65            | 6150                 | -
+| Llama3-8b | tensorwise with optimal settings | per op SAC               | 47.77            | 7689.5               | 25.03%
+| Llama3-8b | rowwise                          | per op SAC               | 47.79            | 6768                 | 10.05%
 
 **Important notes**:
 - Speedups increase as M, K, N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
 - Rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs performance curve.
+- Tensorwise scaling benchmarks were run with optimal settings, namely: `enable_fsdp_float8_all_gather`, `precompute_float8_dynamic_scale_for_fsdp`, and `force_recompute_fp8_weight_in_bwd`.
 
 **Reproducing training benchmarks**
 To reproduce these benchmarks, you can follow these steps:
@@ -233,7 +234,7 @@ including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=re
 2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
 3. From the `torchao/float8/benchmarking/` directory, you can run the following commands to reproduce the benchmarks above:
    - bf16 + compile: `TORCHTITAN_ROOT=<path> ./float8_training_benchmark.sh`
-   - float8 tensorwise: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE="tensorwise" ./float8_training_benchmark.sh`
-   - float8 rowwise: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE="rowwise" ./float8_training_benchmark.sh`
+   - float8 tensorwise: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./float8_training_benchmark.sh`
+   - float8 rowwise: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh`
 
 See the float8 training benchmarking [guide](torchao/float8/benchmarking/README.md) for more details.
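
As a sanity check on the updated numbers: the "Speedup over baseline" column is the ratio of median tokens/second to the bf16 row, minus one. A quick verification in Python reproduces the table exactly:

```python
# Speedup over baseline = (median tokens/sec) / (bf16 median tokens/sec) - 1
baseline = 6150.0  # bf16 + per op SAC, median tokens/second

print(f"tensorwise: {7689.5 / baseline - 1:.2%}")  # -> 25.03%
print(f"rowwise:    {6768.0 / baseline - 1:.2%}")  # -> 10.05%
```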

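For readers wondering what the "optimal settings" in the new table row map to in code: below is a minimal sketch, assuming the `torchao.float8` training API around the time of this commit (`Float8LinearConfig`, `convert_to_float8_training`, and the `precompute_float8_dynamic_scale_for_fsdp` helper named in the note above). In the benchmark script itself these are toggled through torchtitan's job config rather than called directly.

```python
import torch.nn as nn

from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# Two of the three settings are config flags; the third is a per-step helper.
config = Float8LinearConfig(
    # All-gather FSDP2-sharded weights in float8 instead of the original dtype.
    enable_fsdp_float8_all_gather=True,
    # Recompute the fp8 weight cast in backward instead of saving it from forward.
    force_recompute_fp8_weight_in_bwd=True,
)

# Toy model for illustration; the benchmarks use Llama3-8b from torchtitan.
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
convert_to_float8_training(model, config=config)

# ... wrap with FSDP2 (fully_shard) and torch.compile, as torchtitan does ...

# In the training loop, once per iteration after optimizer.step():
#     precompute_float8_dynamic_scale_for_fsdp(model)
# This precomputes float8 scales for the FSDP parameters in a single fused
# pass rather than computing them per-parameter at all-gather time.
```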