Enable torch.autocast with ZeRO #6993
Open

tohtana wants to merge 67 commits into master from tohtana/support_autocast.
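For context, the sketch below illustrates the kind of setup this PR targets: ZeRO training driven by a DeepSpeed config while the forward pass runs under torch.autocast. The "torch_autocast" config section and its field names are assumptions inferred from the commit messages in this PR ("add autocast support and ds_config item", "add config to set lower precision modules"), not something shown on this page, and whether the engine or user code opens the autocast context is likewise not specified here.

# A hedged sketch only. The "torch_autocast" key and its sub-fields are assumptions
# inferred from commit messages in this PR, not a confirmed DeepSpeed API.
import torch
import deepspeed

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 2},
    "torch_autocast": {  # hypothetical config item ("add autocast support and ds_config item")
        "enabled": True,
        "dtype": "bfloat16",
        # hypothetical field ("add config to set lower precision modules")
        "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"],
    },
}

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

x = torch.randn(8, 16, device=engine.device)
# Parameters stay fp32 (ZeRO master weights); autocast picks the compute dtype per op.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = engine(x).sum()
engine.backward(loss)
engine.step()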
Changes from all commits (67 commits):
a4fbc3a  Use ds-specific module id to avoid conflicts (#6847)  (tjruwase)
9d48dc6  add autocast support and ds_config item  (tohtana)
8957009  prepare ipg buckets for multiple dtypes  (tohtana)
c390c76  switch communication data type  (tohtana)
458797d  add gradscaler  (tohtana)
2984415  Update A6000 workflows to use newer docker container - 24.09 vs 24.03…  (loadams)
6b6a600  fix import and formatting  (tohtana)
2817c02  convert comm type for z3  (tohtana)
790c43a  Allow NVIDIA Blackwell (#6991)  (fabiendupont)
fcdeda3  Update GH org references (#6998)  (tjruwase)
b476d07  Update CNAME  (loadams)
72f9687  Update CNAME  (loadams)
a2425da  [XPU] max1100 workflow update for docker and softwares (#7003)  (Liangliang-Ma)
a896693  autotp training(fix dco) (#7004)  (inkcherry)
1c1d43a  cast dtype for allgather  (tohtana)
a265cad  import triton files when triton is supported and installed (#6989)  (oelayan7)
359c85d  Update A6000 tests transformers version (#7016)  (loadams)
735fc2c  Fix ds-chat CI regression (#7015)  (tjruwase)
1e7888c  [Ulysses tutorial] typos (#7024)  (stas00)
7d15f26  fix hostname -I for macOS #6497 (#6990)  (fitzjalen)
ef5a2a4  Update workflows to cuda 12.4 (#7000)  (loadams)
f1aea5d  [ROCm] Enable fp_quantizer on ROCm (#7027)  (rraminen)
8152824  add gds chinese blog (#7034)  (GuanhuaWang)
c898ac5  Add chinese blog for deepspeed windows, and fix format (#7035)  (hwchen2017)
e3ea926  AIO on ROCM (#7023)  (jomayeri)
e946615  Control trace cache warnings (#7039)  (tjruwase)
38e9bf3  Update CUDA compute capability to support Blackwell (#7047)  (hwchen2017)
acc6a1e  Update setup.py handling of ROCm cupy (#7051)  (loadams)
bda1430  nv-ds-chat breaks with latest transformers (#7052)  (loadams)
c184b16  Rename aio_thread_count to intra_op_parallelism (#7056)  (tjruwase)
a2b8219  add autoTP training zero2 tests (#7049)  (inkcherry)
9f50cde  Fix, bf16 optimizer remove dup loop (#7054)  (wukong1992)
71bd64e  Update version.txt after 0.16.4 release (#7063)  (loadams)
5c9fd4b  fix an outdated doc wrt CUDA_VISIBLE_DEVICES (#7058)  (stas00)
20a46b7  Tecorigin sdaa accelerator (#6903)  (siqi654321)
5f68587  Handle special case of libuv for Windows (#7064)  (loadams)
3f5cd1a  Update README with info on newest accelerator (#7065)  (loadams)
b80d2d4  Bug Fix for offload_states API (#7050)  (U-rara)
877c30e  Fix TOCTOU issues, switch to fstat (#7067)  (loadams)
59fe7f6  config torch to avoid graph breaks caused by logger (#6999)  (ShellyNR)
060aa5a  Fix meta load tensor imcompatible issue (#7073)  (Yejing-Lai)
817b31d  Replace calls to `python setup.py sdist` with `python -m build --sdis…  (loadams)
f99605b  Revert "Handle special case of libuv for Windows (#7064)" (#7076)  (loadams)
57805b2  Add DeepseekV3 AutoTP. (#7045)  (Yejing-Lai)
697050e  Improve inference tutorial docs (#7083)  (loadams)
83c9461  Pin transformers version on tests that use latest. (#7085)  (loadams)
c6bf7fb  Update README.md with ICS '23 MoE paper link (#7087)  (siddharth9820)
2c36865  Update parallelism for nv-torch-latest/nightly tests due to more GPUs…  (loadams)
f2b89ec  Remove workflows for very old torch versions (#7090)  (loadams)
965cb2b  Merge branch 'master' into tohtana/support_autocast  (tohtana)
981e8e2  clear reduce buffer  (tohtana)
37f77ae  add config to set lower precision modules  (tohtana)
3083d94  fix to use comm dtype in config when autocast is disabled  (tohtana)
d688b75  Merge branch 'master' into tohtana/support_autocast  (tohtana)
aa60eb3  add tests  (tohtana)
9529830  sort dtypes  (tohtana)
c8056a8  Merge branch 'master' into tohtana/support_autocast  (tohtana)
c56339c  fix for cases where param and param.ds_tensor have different dtypes  (tohtana)
1d6ed6e  Merge branch 'master' into tohtana/support_autocast  (loadams)
aa10e11  fix moe tests  (tohtana)
26e62e4  fix tests for opt state offloading  (tohtana)
a74fa1e  fix var name  (tohtana)
2016995  Merge branch 'master' into tohtana/support_autocast  (tohtana)
9f5b8c0  Merge branch 'master' into tohtana/support_autocast  (tohtana)
7973b88  fix arg order  (tohtana)
15c436d  Merge branch 'master' into tohtana/support_autocast  (tjruwase)
713d56f  Merge branch 'master' into tohtana/support_autocast  (tohtana)
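Several commits above ("prepare ipg buckets for multiple dtypes", "cast dtype for allgather", "sort dtypes") deal with the ZeRO reduce/allgather path having to handle more than one gradient dtype once autocast tags some parameters with a lower precision. The actual bucketing code is not shown on this page; the following is only a small illustration of the general idea, grouping tensors by dtype with a deterministic ordering like the sort_dtypes() helper in the new file below.

from collections import defaultdict
import torch

def group_by_dtype(tensors):
    # Collect tensors into per-dtype buckets so each collective call can use one
    # consistent communication dtype; order buckets deterministically across ranks.
    buckets = defaultdict(list)
    for t in tensors:
        buckets[t.dtype].append(t)
    return {dt: buckets[dt] for dt in sorted(buckets, key=str)}

grads = [torch.randn(4), torch.randn(4, dtype=torch.bfloat16), torch.randn(2)]
print({str(dt): len(ts) for dt, ts in group_by_dtype(grads).items()})
# {'torch.bfloat16': 1, 'torch.float32': 2}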
One new file is shown in the diff (81 added lines):

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Set, List, Union
import importlib

import torch

LOWER_PRECISION_SAFE_MODULES = [
    torch.nn.Linear,
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
]

TORCH_AUTOCAST_INITIALIZED = False


def _validate_auto_cast_settings(engine):

    assert not engine.fp16_enabled(), "Cannot enable both torch autocast and fp16"
    assert not engine.bfloat16_enabled(), "Cannot enable both torch autocast and bfloat16"
    assert not engine.zero_quantized_weights(), "Cannot enable both torch autocast and zero quantized weights"

    assert all(p.dtype == torch.float32
               for p in engine.parameters()), "All parameters must be float32 for torch autocast"
    assert engine.communication_data_type == torch.float32, "Communication data type must be float32 for torch autocast"


def init_autocast_params(engine, dtype: torch.dtype,
                         torch_autocast_lower_precision_safe_modules: Union[None, List[str]]) -> None:

    _validate_auto_cast_settings(engine)
    model = engine.module

    if torch_autocast_lower_precision_safe_modules is None:
        lower_precision_safe_module_classes = LOWER_PRECISION_SAFE_MODULES
    else:
        lower_precision_safe_module_classes = []
        for module_name in torch_autocast_lower_precision_safe_modules:
            try:
                package_name, class_name = module_name.rsplit('.', 1)
                module = importlib.import_module(package_name)
                class_ = getattr(module, class_name)
                lower_precision_safe_module_classes.append(class_)
            except Exception as e:
                raise ValueError(f"Failed to import lower precision safe module {module_name}: {e}")

    for module in model.modules():
        if module.__class__ in lower_precision_safe_module_classes:
            for p in module.parameters(recurse=False):
                p.autocast_dtype = dtype

    global TORCH_AUTOCAST_INITIALIZED
    TORCH_AUTOCAST_INITIALIZED = True


def is_autocast_initialized() -> bool:
    return TORCH_AUTOCAST_INITIALIZED


def get_default_autocast_lower_precision_modules() -> List[str]:
    return [f"{cls.__module__}.{cls.__name__}" for cls in LOWER_PRECISION_SAFE_MODULES]


def get_autocast_dtype(param: torch.nn.Parameter) -> torch.dtype:
    return param.autocast_dtype if hasattr(param, "autocast_dtype") else param.dtype


def has_autocast_dtype(param: torch.nn.Parameter) -> bool:
    return hasattr(param, "autocast_dtype")


def get_all_autocast_dtypes(params: Iterable) -> Set[torch.dtype]:
    return {get_autocast_dtype(p) for p in params}


def sort_dtypes(dtypes: List[torch.dtype]) -> List[torch.dtype]:
    return sorted(dtypes, key=str)
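To show how these helpers fit together, here is a small self-contained example (not part of the diff) that tags parameters the way init_autocast_params() does for lower-precision-safe modules and then queries the effective dtypes; a DeepSpeed engine is not needed for this illustration.

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

# Mimic init_autocast_params(): tag parameters of "lower precision safe" modules
# (e.g. Linear) with the requested autocast dtype; LayerNorm parameters are left alone.
for module in model.modules():
    if module.__class__ in (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d):
        for p in module.parameters(recurse=False):
            p.autocast_dtype = torch.bfloat16

def get_autocast_dtype(param):
    # Same logic as the helper above: fall back to the parameter's own dtype.
    return param.autocast_dtype if hasattr(param, "autocast_dtype") else param.dtype

# Linear params report bfloat16, LayerNorm params keep float32, so the ZeRO side
# now has two dtypes to handle (cf. get_all_autocast_dtypes / sort_dtypes).
print({get_autocast_dtype(p) for p in model.parameters()})  # {torch.bfloat16, torch.float32}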
Review comments:

add ... here as in other sections as it's incomplete.

or update it with the rest of flags?