| Llama3-8b | none (bf16) | per op SAC | 47.65 | 6150 | -
| Llama3-8b | tensorwise with optimal settings | per op SAC | 47.77 | 7689.5 | 25.03%
| Llama3-8b | rowwise | per op SAC | 47.79 | 6768 | 10.05%
**Important notes**:

- Speedups increase as M, K, N (the GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
- Rowwise scaling handles outliers better than tensorwise scaling, so these recipes are different points on the accuracy vs. performance curve.
- Tensorwise scaling benchmarks were run with the optimal settings, namely `enable_fsdp_float8_all_gather`, `precompute_float8_dynamic_scale_for_fsdp`, and `force_recompute_fp8_weight_in_bwd` (see the sketch after this list).
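For reference, below is a minimal sketch of how these settings map onto torchao's float8 Python API; the benchmarks themselves set them through torchtitan's configuration, and the toy model, its dimensions, and the training-loop placement here are illustrative assumptions rather than the benchmark setup.

```python
# Minimal sketch (not the benchmark script): the three "optimal settings" above,
# expressed with torchao's float8 API. The toy model and dimensions are
# illustrative assumptions.
import torch.nn as nn

from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# Tensorwise scaling with the first two settings:
# - enable_fsdp_float8_all_gather: all-gather FSDP2 weights in float8
# - force_recompute_fp8_weight_in_bwd: recompute the float8 weight cast in the
#   backward pass instead of saving it for backward
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    force_recompute_fp8_weight_in_bwd=True,
)
# (The rowwise recipe would instead come from
#  Float8LinearConfig.from_recipe_name("rowwise").)

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
convert_to_float8_training(model, config=config)

# ... wrap `model` with FSDP2 (fully_shard) and create the optimizer here ...

# The third setting: after each optimizer.step(), precompute the dynamic float8
# scales used by the FSDP float8 all-gather.
# precompute_float8_dynamic_scale_for_fsdp(model)
```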
**Reproducing training benchmarks**

To reproduce these benchmarks, you can follow these steps:

1. Set up [torchtitan](https://github.com/pytorch/torchtitan), including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer).
2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
3. From the `torchao/float8/benchmarking/` directory, you can run the following commands to reproduce the benchmarks above: