Enable torch.autocast with ZeRO #6993

Open: wants to merge 67 commits into master from tohtana/support_autocast

Commits (67, changes from all commits shown below)

a4fbc3a
Use ds-specific module id to avoid conflicts (#6847)
tjruwase Jan 31, 2025
9d48dc6
add autocast support and ds_config item
tohtana Jan 31, 2025
8957009
prepare ipg buckets for multiple dtypes
tohtana Feb 1, 2025
c390c76
switch communication data type
tohtana Feb 2, 2025
458797d
add gradscaler
tohtana Feb 3, 2025
2984415
Update A6000 workflows to use newer docker container - 24.09 vs 24.03…
loadams Jan 31, 2025
6b6a600
fix import and formatting
tohtana Feb 3, 2025
2817c02
convert comm type for z3
tohtana Feb 4, 2025
790c43a
Allow NVIDIA Blackwell (#6991)
fabiendupont Feb 4, 2025
fcdeda3
Update GH org references (#6998)
tjruwase Feb 5, 2025
b476d07
Update CNAME
loadams Feb 5, 2025
72f9687
Update CNAME
loadams Feb 5, 2025
a2425da
[XPU] max1100 workflow update for docker and softwares (#7003)
Liangliang-Ma Feb 5, 2025
a896693
autotp training(fix dco) (#7004)
inkcherry Feb 5, 2025
1c1d43a
cast dtype for allgather
tohtana Feb 28, 2025
a265cad
import triton files when triton is supported and installed (#6989)
oelayan7 Feb 6, 2025
359c85d
Update A6000 tests transformers version (#7016)
loadams Feb 8, 2025
735fc2c
Fix ds-chat CI regression (#7015)
tjruwase Feb 10, 2025
1e7888c
[Ulysses tutorial] typos (#7024)
stas00 Feb 11, 2025
7d15f26
fix hostname -I for macOS #6497 (#6990)
fitzjalen Feb 12, 2025
ef5a2a4
Update workflows to cuda 12.4 (#7000)
loadams Feb 12, 2025
f1aea5d
[ROCm] Enable fp_quantizer on ROCm (#7027)
rraminen Feb 13, 2025
8152824
add gds chinese blog (#7034)
GuanhuaWang Feb 13, 2025
c898ac5
Add chinese blog for deepspeed windows, and fix format (#7035)
hwchen2017 Feb 14, 2025
e3ea926
AIO on ROCM (#7023)
jomayeri Feb 14, 2025
e946615
Control trace cache warnings (#7039)
tjruwase Feb 18, 2025
38e9bf3
Update CUDA compute capability to support Blackwell (#7047)
hwchen2017 Feb 18, 2025
acc6a1e
Update setup.py handling of ROCm cupy (#7051)
loadams Feb 19, 2025
bda1430
nv-ds-chat breaks with latest transformers (#7052)
loadams Feb 19, 2025
c184b16
Rename aio_thread_count to intra_op_parallelism (#7056)
tjruwase Feb 19, 2025
a2b8219
add autoTP training zero2 tests (#7049)
inkcherry Feb 19, 2025
9f50cde
Fix, bf16 optimizer remove dup loop (#7054)
wukong1992 Feb 20, 2025
71bd64e
Update version.txt after 0.16.4 release (#7063)
loadams Feb 20, 2025
5c9fd4b
fix an outdated doc wrt CUDA_VISIBLE_DEVICES (#7058)
stas00 Feb 20, 2025
20a46b7
Tecorigin sdaa accelerator (#6903)
siqi654321 Feb 20, 2025
5f68587
Handle special case of libuv for Windows (#7064)
loadams Feb 20, 2025
3f5cd1a
Update README with info on newest accelerator (#7065)
loadams Feb 21, 2025
b80d2d4
Bug Fix for offload_states API (#7050)
U-rara Feb 21, 2025
877c30e
Fix TOCTOU issues, switch to fstat (#7067)
loadams Feb 24, 2025
59fe7f6
config torch to avoid graph breaks caused by logger (#6999)
ShellyNR Feb 24, 2025
060aa5a
Fix meta load tensor imcompatible issue (#7073)
Yejing-Lai Feb 24, 2025
817b31d
Replace calls to `python setup.py sdist` with `python -m build --sdis…
loadams Feb 24, 2025
f99605b
Revert "Handle special case of libuv for Windows (#7064)" (#7076)
loadams Feb 25, 2025
57805b2
Add DeepseekV3 AutoTP. (#7045)
Yejing-Lai Feb 26, 2025
697050e
Improve inference tutorial docs (#7083)
loadams Feb 26, 2025
83c9461
Pin transformers version on tests that use latest. (#7085)
loadams Feb 27, 2025
c6bf7fb
Update README.md with ICS '23 MoE paper link (#7087)
siddharth9820 Feb 27, 2025
2c36865
Update parallelism for nv-torch-latest/nightly tests due to more GPUs…
loadams Feb 27, 2025
f2b89ec
Remove workflows for very old torch versions (#7090)
loadams Feb 28, 2025
965cb2b
Merge branch 'master' into tohtana/support_autocast
tohtana Feb 28, 2025
981e8e2
clear reduce buffer
tohtana Mar 1, 2025
37f77ae
add config to set lower precision modules
tohtana Mar 1, 2025
3083d94
fix to use comm dtype in config when autocast is disabled
tohtana Mar 3, 2025
d688b75
Merge branch 'master' into tohtana/support_autocast
tohtana Mar 3, 2025
aa60eb3
add tests
tohtana Mar 4, 2025
9529830
sort dtypes
tohtana Mar 5, 2025
c8056a8
Merge branch 'master' into tohtana/support_autocast
tohtana Mar 5, 2025
c56339c
fix for cases where param and param.ds_tensor have different dtypes
tohtana Mar 6, 2025
1d6ed6e
Merge branch 'master' into tohtana/support_autocast
loadams Mar 7, 2025
aa10e11
fix moe tests
tohtana Mar 8, 2025
26e62e4
fix tests for opt state offloading
tohtana Mar 8, 2025
a74fa1e
fix var name
tohtana Mar 8, 2025
2016995
Merge branch 'master' into tohtana/support_autocast
tohtana Mar 10, 2025
9f5b8c0
Merge branch 'master' into tohtana/support_autocast
tohtana Mar 12, 2025
7973b88
fix arg order
tohtana Mar 12, 2025
15c436d
Merge branch 'master' into tohtana/support_autocast
tjruwase Mar 17, 2025
713d56f
Merge branch 'master' into tohtana/support_autocast
tohtana Mar 17, 2025
20 changes: 19 additions & 1 deletion deepspeed/runtime/base_optimizer.py
@@ -8,7 +8,8 @@

from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
-from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank
+from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, see_memory_usage
+from deepspeed.runtime.torch_autocast import get_autocast_dtype, is_autocast_initialized


class DeepSpeedOptimizer(object):
@@ -61,3 +62,20 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
            if key == 'params':
                continue
            param_group[key] = value

    def report_ipg_memory_usage(self, tag, param_elems, dtype=None):
        dtypes = self.ipg_buckets.keys() if dtype is None else [dtype]

        for dt in dtypes:
            bucket = self.ipg_buckets[dt]
            elem_count = bucket.elements + param_elems
            percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
            see_memory_usage(
                f"{tag}: elems in_bucket {dt} {bucket.elements} param {param_elems} max_percent {percent_of_bucket_size}"
            )

    def get_param_comm_dtype(self, param):
        if is_autocast_initialized():
            return get_autocast_dtype(param)
        else:
            return self.communication_data_type
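
A minimal sketch of the dtype-selection rule that get_param_comm_dtype implements (resolve_comm_dtype below is an illustrative stand-in, not part of this PR): once autocast parameters have been initialized, each parameter communicates in its tagged autocast_dtype, falling back to the parameter's own dtype; otherwise the optimizer-wide communication dtype is used.

import torch

def resolve_comm_dtype(param, autocast_initialized, comm_dtype=torch.float32):
    # Mirrors get_param_comm_dtype: per-parameter dtype under autocast,
    # optimizer-wide communication dtype otherwise.
    if autocast_initialized:
        return getattr(param, "autocast_dtype", param.dtype)
    return comm_dtype

p = torch.nn.Parameter(torch.zeros(4))  # fp32, as torch autocast requires
p.autocast_dtype = torch.bfloat16  # tag set by init_autocast_params for safe modules
assert resolve_comm_dtype(p, autocast_initialized=True) == torch.bfloat16
assert resolve_comm_dtype(p, autocast_initialized=False) == torch.float32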
32 changes: 32 additions & 0 deletions deepspeed/runtime/config.py
@@ -157,6 +157,35 @@ def get_amp_params(param_dict):
        return False


def get_torch_autocast_enabled(param_dict):
    if TORCH_AUTOCAST in param_dict.keys():
        return get_scalar_param(param_dict[TORCH_AUTOCAST], TORCH_AUTOCAST_ENABLED, TORCH_AUTOCAST_ENABLED_DEFAULT)
    else:
        return False


def get_torch_autocast_dtype(param_dict):
    if TORCH_AUTOCAST in param_dict:
        if TORCH_AUTOCAST_DTYPE in param_dict[TORCH_AUTOCAST]:
            try:
                return DtypeEnum(param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_DTYPE]).value
            except KeyError:
                raise ValueError(
                    f"Invalid dtype for torch autocast: {param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_DTYPE]}")
    return None


def get_lower_precision_safe_modules(param_dict):
    if TORCH_AUTOCAST in param_dict:
        if TORCH_AUTOCAST_LOWER_PRECISION_SAFE_MODULES in param_dict[TORCH_AUTOCAST]:
            module_names_with_package = param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_LOWER_PRECISION_SAFE_MODULES]
            if not all(isinstance(module_name, str) for module_name in module_names_with_package):
                raise ValueError(
                    f"Invalid module names for torch autocast: {module_names_with_package}. Expected list of strings.")
            return module_names_with_package
    return None


def get_fp16_enabled(param_dict):
    if FP16 in param_dict.keys():
        return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
@@ -836,6 +865,9 @@ def _initialize_params(self, param_dict):
        self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict)
        self.amp_enabled = get_amp_enabled(param_dict)
        self.amp_params = get_amp_params(param_dict)
        self.torch_autocast_enabled = get_torch_autocast_enabled(param_dict)
        self.torch_autocast_dtype = get_torch_autocast_dtype(param_dict)
        self.torch_autocast_lower_precision_safe_modules = get_lower_precision_safe_modules(param_dict)
        self.loss_scale = get_loss_scale(param_dict)
        self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict)
        self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)
17 changes: 17 additions & 0 deletions deepspeed/runtime/constants.py
@@ -202,6 +202,23 @@
AMP_ENABLED = "enabled"
AMP_ENABLED_DEFAULT = False

#########################################
# Torch AMP support
#########################################
TORCH_AUTOCAST_FORMAT = '''
PyTorch autocast config should be of the format:
"torch_autocast": {
"enabled": true,
"dtype": "bfloat16",
}
'''

Review comment (Collaborator): add "..." here as in other sections as it's incomplete.

Review comment (Collaborator): or update it with the rest of flags?
TORCH_AUTOCAST = "torch_autocast"

TORCH_AUTOCAST_ENABLED = "enabled"
TORCH_AUTOCAST_ENABLED_DEFAULT = False
TORCH_AUTOCAST_DTYPE = "dtype"
TORCH_AUTOCAST_LOWER_PRECISION_SAFE_MODULES = "lower_precision_safe_modules"

#########################################
# Gradient clipping
#########################################
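
Tying the constants above together, and as the review comments request, a sketch of a complete "torch_autocast" config section with all three keys (values are illustrative; when "lower_precision_safe_modules" is omitted, the engine falls back to the built-in safe-module list):

ds_config = {
    "torch_autocast": {
        "enabled": True,
        "dtype": "bfloat16",
        "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"]
    }
}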
21 changes: 19 additions & 2 deletions deepspeed/runtime/engine.py
@@ -19,7 +19,7 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from contextlib import contextmanager

-from typing import Callable, Dict, Union, Iterable, Container
+from typing import Callable, Dict, Union, Iterable, Container, List

import deepspeed

@@ -91,6 +91,7 @@

from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from deepspeed.runtime.torch_autocast import init_autocast_params, get_default_autocast_lower_precision_modules

from .pipe.module import PipelineModule
from .utils import get_ma_status
@@ -316,6 +317,9 @@ def __init__(self,
        if not isinstance(model_parameters, list):
            model_parameters = list(model_parameters)

        if self.torch_autocast_enabled():
            init_autocast_params(self, self.torch_autocast_dtype(), self.torch_autocast_lower_precision_safe_modules())

        if has_optimizer:
            self._configure_optimizer(optimizer, model_parameters)
            self._configure_lr_scheduler()
@@ -923,6 +927,16 @@ def amp_enabled(self):
    def amp_params(self):
        return self._config.amp_params

    def torch_autocast_enabled(self) -> bool:
        return self._config.torch_autocast_enabled

    def torch_autocast_dtype(self) -> torch.dtype:
        return self._config.torch_autocast_dtype

    def torch_autocast_lower_precision_safe_modules(self) -> List[str]:
        module_names = self._config.torch_autocast_lower_precision_safe_modules
        return get_default_autocast_lower_precision_modules() if module_names is None else module_names

    def fp16_auto_cast(self):
        return self._config.fp16_auto_cast

@@ -2027,7 +2041,10 @@ def forward(self, *inputs, **kwargs):
        if self.autotuning_profile_model_info():
            ma = get_ma_status()

-        loss = self.module(*inputs, **kwargs)
+        with torch.autocast(device_type=get_accelerator().device_name(),
+                            dtype=self.torch_autocast_dtype(),
+                            enabled=self.torch_autocast_enabled()):
+            loss = self.module(*inputs, **kwargs)

        if self.autotuning_profile_model_info():
            activation_mem = get_ma_status() - ma
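
Putting the engine changes together, a hedged end-to-end sketch of how a user would drive this path once the PR's "torch_autocast" config key is available (model, sizes, and optimizer settings are illustrative; run under the deepspeed launcher so distributed init succeeds). Because engine.forward now wraps the module call in torch.autocast, user code needs no explicit autocast context:

import torch
import deepspeed

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
ds_config = {
    "train_batch_size": 1,
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "torch_autocast": {"enabled": True, "dtype": "bfloat16"},
}
# Parameters stay fp32, as _validate_auto_cast_settings requires.
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)

x = torch.randn(1, 16).to(engine.device)
loss = engine(x).sum()  # forward runs under torch.autocast automatically
engine.backward(loss)
engine.step()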
81 changes: 81 additions & 0 deletions deepspeed/runtime/torch_autocast.py
@@ -0,0 +1,81 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Set, List, Union
import importlib

import torch

LOWER_PRECISION_SAFE_MODULES = [
    torch.nn.Linear,
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
]

TORCH_AUTOCAST_INITIALIZED = False


def _validate_auto_cast_settings(engine):

    assert not engine.fp16_enabled(), "Cannot enable both torch autocast and fp16"
    assert not engine.bfloat16_enabled(), "Cannot enable both torch autocast and bfloat16"
    assert not engine.zero_quantized_weights(), "Cannot enable both torch autocast and zero quantized weights"

    assert all(p.dtype == torch.float32
               for p in engine.parameters()), "All parameters must be float32 for torch autocast"
    assert engine.communication_data_type == torch.float32, "Communication data type must be float32 for torch autocast"


def init_autocast_params(engine, dtype: torch.dtype,
                         torch_autocast_lower_precision_safe_modules: Union[None, List[str]]) -> None:

    _validate_auto_cast_settings(engine)
    model = engine.module

    if torch_autocast_lower_precision_safe_modules is None:
        lower_precision_safe_module_classes = LOWER_PRECISION_SAFE_MODULES
    else:
        lower_precision_safe_module_classes = []
        for module_name in torch_autocast_lower_precision_safe_modules:
            try:
                package_name, class_name = module_name.rsplit('.', 1)
                module = importlib.import_module(package_name)
                class_ = getattr(module, class_name)
                lower_precision_safe_module_classes.append(class_)
            except Exception as e:
                raise ValueError(f"Failed to import lower precision safe module {module_name}: {e}")

    for module in model.modules():
        if module.__class__ in lower_precision_safe_module_classes:
            for p in module.parameters(recurse=False):
                p.autocast_dtype = dtype

    global TORCH_AUTOCAST_INITIALIZED
    TORCH_AUTOCAST_INITIALIZED = True


def is_autocast_initialized() -> bool:
    return TORCH_AUTOCAST_INITIALIZED


def get_default_autocast_lower_precision_modules() -> List[str]:
    return [f"{cls.__module__}.{cls.__name__}" for cls in LOWER_PRECISION_SAFE_MODULES]


def get_autocast_dtype(param: torch.nn.Parameter) -> torch.dtype:
    return param.autocast_dtype if hasattr(param, "autocast_dtype") else param.dtype


def has_autocast_dtype(param: torch.nn.Parameter) -> bool:
    return hasattr(param, "autocast_dtype")


def get_all_autocast_dtypes(params: Iterable) -> Set[torch.dtype]:
    return {get_autocast_dtype(p) for p in params}


def sort_dtypes(dtypes: List[torch.dtype]) -> List[torch.dtype]:
    return sorted(dtypes, key=str)
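
The dotted-name resolution inside init_autocast_params can be exercised on its own. A minimal sketch of the same rsplit / import_module / getattr sequence (resolve_class is illustrative, not part of the module above):

import importlib
import torch

def resolve_class(module_name: str) -> type:
    # "torch.nn.Linear" -> package "torch.nn", class name "Linear"
    package_name, class_name = module_name.rsplit('.', 1)
    module = importlib.import_module(package_name)
    return getattr(module, class_name)

assert resolve_class("torch.nn.Linear") is torch.nn.Linear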
5 changes: 2 additions & 3 deletions deepspeed/runtime/zero/offload_states.py
@@ -69,6 +69,5 @@ def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
        return set(safe_get_local_optimizer_state(p, "exp_avg").device for p in model.parameters()) | \
               set(safe_get_local_optimizer_state(p, "exp_avg_sq").device for p in model.parameters())
    elif state == OffloadStateTypeEnum.contiguous_grad_buffer:
-        if model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer == None:
-            return {}
-        return {model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer.device}
+        return set(bucket.buffer.device for bucket in model.optimizer.ipg_buckets.values()
+                   if bucket.buffer is not None)
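
For context, a hedged sketch of the per-dtype bucket layout this new query assumes (IPGBucket here is a hypothetical stand-in for the PR's bucket type; the real optimizer keys ipg_buckets by gradient dtype, with bucket.elements and bucket.buffer as used in the diffs above, and allocates each flat buffer lazily):

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class IPGBucket:  # hypothetical stand-in, not the PR's class
    buffer: Optional[torch.Tensor] = None
    elements: int = 0

ipg_buckets = {
    torch.bfloat16: IPGBucket(buffer=torch.empty(8, dtype=torch.bfloat16)),
    torch.float32: IPGBucket(),  # flat buffer not yet allocated
}
devices = {b.buffer.device for b in ipg_buckets.values() if b.buffer is not None}
print(devices)  # e.g. {device(type='cpu')}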