Enable torch.autocast with ZeRO #6993

Open · wants to merge 67 commits into master from tohtana/support_autocast
Conversation

@tohtana tohtana (Contributor) commented Feb 3, 2025

DeepSpeed supports mixed-precision training, but its behavior differs from torch.autocast. DeepSpeed keeps parameters and gradients in both FP32 and a lower precision (FP16/BF16), in the style of NVIDIA Apex AMP, and computes all modules in the lower precision, while torch.autocast keeps parameters in FP32 and computes only certain operators in the lower precision.
This leads to differences in (a short PyTorch sketch illustrating the contrast follows the list):

  • performance: torch.autocast needs to downcast inputs in forward/backward
  • memory usage: DeepSpeed needs more memory to keep copies of parameters and gradients in the lower precision
  • accuracy: torch.autocast maintains a list of operators that can safely be computed in the lower precision; precision-sensitive operators (e.g. softmax) are kept in FP32.
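
The contrast shows up in a few lines of plain PyTorch (a minimal sketch; the exact per-operator casting decisions follow torch.autocast's policy for the device in use):

import torch

lin = torch.nn.Linear(8, 8)              # parameters are created and kept in FP32
x = torch.randn(4, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = lin(x)                           # linear/matmul is on the lower-precision list -> BF16
    p = torch.softmax(y, dim=-1)         # precision-sensitive ops may be kept in FP32 by the policy

print(lin.weight.dtype)                  # torch.float32: autocast never downcasts the parameters
print(y.dtype)                           # torch.bfloat16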

To align DeepSpeed's behavior with torch.autocast when necessary, this PR adds torch.autocast integration with ZeRO. Here is an example of the configuration:

"torch_autocast": {
  "enabled": true,
  "dtype": "bfloat16",
  "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"]
}

Each configuration item works as follows:

  • enabled: Enables the torch.autocast integration when set to true. You don't need to call torch.autocast in your code; the gradient scaler is also applied inside the DeepSpeed optimizer.
  • dtype: The lower-precision dtype passed to torch.autocast. Gradients for allreduce (reduce-scatter) and parameters for allgather (ZeRO3 only) of lower_precision_safe_modules are also downcast to this dtype.
  • lower_precision_safe_modules: Downcasting for allreduce (reduce-scatter) and allgather (ZeRO3) is applied only to modules specified in this list. (The precision of PyTorch operators in forward/backward follows torch.autocast's policy, not this list.) Specify class names with their module paths. If this item is not set, DeepSpeed uses the default list: [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d].

Note that only FP32 parameters are maintained when this feature is enabled. For consistency, you cannot enable fp16 or bf16 in the DeepSpeed config.
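
For reference, here is a minimal end-to-end sketch of a training loop under this configuration (the torch_autocast block follows the schema above; the model, optimizer settings, and batch sizes are placeholders, and the script assumes it is launched with the deepspeed launcher):

import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 2},
    "torch_autocast": {
        "enabled": True,
        "dtype": "bfloat16",
        "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"]
    }
    # No "fp16" or "bf16" section: they must stay disabled when torch_autocast is enabled.
}

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

for _ in range(10):
    x = torch.randn(4, 16, device=engine.device)
    loss = engine(x).pow(2).mean()       # no explicit torch.autocast() context is needed
    engine.backward(loss)                # loss scaling, if any, is handled inside the engine
    engine.step()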

tjruwase and others added 30 commits February 28, 2025 22:53
Fix #6772

---------

Co-authored-by: Logan Adams <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
…#6967)

- Issues with nv-sd updates, will follow up with a subsequent PR

Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
The NVIDIA Blackwell GPU generation has number 10. The SM code and architecture should be `100`, but the current code generates `1.` because it expects a 2-character string.

This change modifies the logic to treat the value as a string containing a `.`, splitting it and using the resulting array of strings.

Signed-off-by: Fabien Dupont <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Olatunji Ruwase <[email protected]>
Signed-off-by: Logan Adams <[email protected]>
Signed-off-by: Fabien Dupont <[email protected]>
Co-authored-by: Fabien Dupont <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
1. update intel oneAPI basekit to 2025.0
2. update torch/ipex/oneccl to 2.5

Signed-off-by: Masahiro Tanaka <[email protected]>
Same as [this PR](#6922).
[affeb88](affeb88)
I noticed the CI updated the DCO check recently. Using the suggested
rebase method for sign-off would reintroduce many conflicts, so I opted
for a squash merge with sign-off instead. Thanks :)

Signed-off-by: inkcherry <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Those files have code that runs at import time, so on systems that don't support Triton but have it installed, this causes issues.

In general, I think it is better to import triton only when it is installed and supported.

Signed-off-by: Omar Elayan <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Logan Adams <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Fix #7014
Avoid naming collision on `partition()`

---------

Signed-off-by: Olatunji Ruwase <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Fix typos

Signed-off-by: Masahiro Tanaka <[email protected]>
BUGFIX for Apple Silicon hostname
#6497

---------

Signed-off-by: Fabien Dupont <[email protected]>
Signed-off-by: Olatunji Ruwase <[email protected]>
Signed-off-by: Logan Adams <[email protected]>
Signed-off-by: inkcherry <[email protected]>
Signed-off-by: Roman Fitzjalen <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Fabien Dupont <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Liangliang Ma <[email protected]>
Co-authored-by: inkcherry <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
- Update existing workflows that use cu121 to cu124. Note, this means that where we download the latest torch, we will now get torch 2.6 rather than torch 2.5, the latest provided with CUDA 12.1.
- Note, nv-nightly is currently failing in master due to unrelated errors, so this can be ignored in this PR (nv-nightly was tested locally, where it passes with both 12.1 and 12.4).

---------

Signed-off-by: Fabien Dupont <[email protected]>
Signed-off-by: Logan Adams <[email protected]>
Signed-off-by: Olatunji Ruwase <[email protected]>
Signed-off-by: inkcherry <[email protected]>
Signed-off-by: Omar Elayan <[email protected]>
Co-authored-by: Fabien Dupont <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Liangliang Ma <[email protected]>
Co-authored-by: inkcherry <[email protected]>
Co-authored-by: Omar Elayan <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
This change is required to successfully build fp_quantizer extension on
ROCm.

---------

Co-authored-by: Logan Adams <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
cc @tjruwase @jomayeri

---------

Co-authored-by: root <root@ftqtmec25000000.taxzvufipdhelhupulxcbvr15f.ux.internal.cloudapp.net>
Signed-off-by: Masahiro Tanaka <[email protected]>
Fix #7029
- Add Chinese blog for deepspeed windows
- Fix format in README.md

Co-authored-by: Logan Adams <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Adding compile support for AIO library on AMD GPUs.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Make trace cache warnings configurable, and disable them by default.

Fix #6985, #4081, #5033, #5006, #5662

---------

Signed-off-by: Olatunji Ruwase <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Update CUDA compute capability for cross compile according to wiki page.
https://en.wikipedia.org/wiki/CUDA#GPUs_supported

---------

Signed-off-by: Hongwei <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Logan Adams <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Propagate API change.

Signed-off-by: Olatunji Ruwase <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Yejing-Lai and others added 6 commits February 28, 2025 22:53
Add deepseekv3 autotp.

Signed-off-by: Lai, Yejing <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Fixes: #7082

---------

Signed-off-by: Logan Adams <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
The latest transformers release causes failures in the cpu-torch-latest test, so we pin it for now to unblock other PRs.

---------

Signed-off-by: Logan Adams <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
…/runner (#7086)

Signed-off-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
These jobs haven't been run in a long time and were originally used when
compatibility with torch <2 was more important.

Signed-off-by: Logan Adams <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
@tohtana tohtana force-pushed the tohtana/support_autocast branch from 453cc16 to f2b89ec on February 28, 2025 22:54
@tohtana tohtana marked this pull request as ready for review March 5, 2025 23:08
@tohtana tohtana requested review from tjruwase and loadams as code owners March 5, 2025 23:08
@stas00 stas00 (Collaborator) commented Mar 21, 2025

I have a question about lower_precision_safe_modules

https://pytorch.org/docs/stable/amp.html#torch.autocast doesn't have an option to specify lower_precision_safe_modules - why do we then put the onus on the deepspeed user? (and I'm aware that it's optional and there is a default list).

My question is - can we automatically retrieve that safe list from PyTorch? Or is it because what PyTorch considers safe isn't necessarily what DeepSpeed considers safe?

@stas00 stas00 (Collaborator) left a comment

If the training loop uses torch.autocast

  1. could deepspeed detect that and assert if the ds config isn't set up correctly? (a rough sketch of such a check follows this list)
  2. or should it actually assert whenever this happens, because the behavior is then unsupported?
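
For item 1, a rough sketch of the kind of guard that could run early in the forward pass (hypothetical helper, not part of this PR; torch.is_autocast_enabled() only reports whether an autocast context is currently active):

import torch

def warn_on_external_autocast(ds_torch_autocast_enabled: bool) -> None:
    # Hypothetical guard: the user opened a torch.autocast() context around the
    # forward pass while DeepSpeed's torch_autocast integration is configured
    # differently; flag the mismatch instead of silently mixing the two.
    if torch.is_autocast_enabled() and not ds_torch_autocast_enabled:
        raise RuntimeError(
            "torch.autocast is active in user code but 'torch_autocast' is not "
            "enabled in the DeepSpeed config; this combination is unsupported."
        )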

"torch_autocast": {
"enabled": true,
"dtype": "bfloat16",
}
Collaborator

add ... here as in other sections as it's incomplete.

Collaborator

or update it with the rest of the flags?
