Improve overflow handling in ZeRO #6976

Open

wants to merge 90 commits into base: master from olruwase/ds_5241

Changes from 1 commit

Commits (90)
a3a18f7  Improve overflow handling in ZeRO (tjruwase, Jan 28, 2025)
19431f8  Unit test and pydantic configuration (tjruwase, Jan 28, 2025)
406cf26  Formatting fixes (tjruwase, Jan 28, 2025)
35570f5  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Jan 29, 2025)
cb78444  Remove unused symbol (tjruwase, Jan 29, 2025)
ee1c1fd  Fix typo (tjruwase, Jan 29, 2025)
0b2cf73  Pydantic fp16 config (tjruwase, Jan 29, 2025)
c7a90f9  Fix more typos (tjruwase, Jan 29, 2025)
3694e07  Address #4986 (tjruwase, Jan 29, 2025)
2bbcf00  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Jan 29, 2025)
c1b87ea  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Jan 30, 2025)
5da6cd0  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Jan 30, 2025)
a65d20c  Merge branch 'master' into olruwase/ds_5241 (loadams, Jan 30, 2025)
ae039b2  Fix typo (tjruwase, Jan 30, 2025)
0446192  Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int… (tjruwase, Jan 30, 2025)
5d48745  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Jan 31, 2025)
05c362d  Merge branch 'master' into olruwase/ds_5241 (loadams, Jan 31, 2025)
5e17ed6  Merge branch 'master' into olruwase/ds_5241 (loadams, Feb 1, 2025)
06bb3a6  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 5, 2025)
0d0ab3d  Fix min loss scale (tjruwase, Feb 5, 2025)
cccd5b1  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 5, 2025)
2c6f630  Fix UTs (tjruwase, Feb 6, 2025)
21bfca0  Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int… (tjruwase, Feb 6, 2025)
5fe5810  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 6, 2025)
732ceb7  Using explicit GPU upcast for ZeRO-Offload (#6962) (xylian86, Jan 21, 2025)
db9aff9  Update version.txt after 0.16.3 release (#6965) (loadams, Jan 21, 2025)
4edeb03  Precisely track nvme optimizer offload (#6963) (tjruwase, Jan 23, 2025)
f00f4ea  Update build_win.bat script to exclue GDS op as it lacks Windows supp… (loadams, Jan 24, 2025)
c3846fa  Improve overflow handling in ZeRO (tjruwase, Jan 28, 2025)
7d56ffa  Unit test and pydantic configuration (tjruwase, Jan 28, 2025)
6ca11ef  Formatting fixes (tjruwase, Jan 28, 2025)
49f3df8  Add CUDA 12.8 support and comment on CUDA 12.7 (#6975) (loadams, Jan 28, 2025)
8364b12  Update torch versions to support 2.6 (#6977) (loadams, Jan 29, 2025)
ea9b473  Remove unused symbol (tjruwase, Jan 29, 2025)
d2425a2  Fix typo (tjruwase, Jan 29, 2025)
7d5be07  Pydantic fp16 config (tjruwase, Jan 29, 2025)
e8fc098  Fix more typos (tjruwase, Jan 29, 2025)
2bbb7b4  Address #4986 (tjruwase, Jan 29, 2025)
3ab5e88  generalize deepspeed linear and implement it for non cuda systems (#6… (oelayan7, Jan 29, 2025)
271db94  Fix typo (tjruwase, Jan 30, 2025)
b1900af  Update recommended Windows whl building versions (#6983) (loadams, Jan 30, 2025)
e3d10e5  Title: Fix setup_env_ranks to Properly Set Environment Variables Inst… (fabiosanger, Jan 30, 2025)
b8d8e39  Specify torchvision in nv-ds-chat workflow (prevents errors with torc… (loadams, Jan 30, 2025)
fde7df1  Remove assumption that padding only occurs on last rank (#6974) (xylian86, Jan 31, 2025)
b0b0132  Use ds-specific module id to avoid conflicts (#6847) (tjruwase, Jan 31, 2025)
353ab08  Update A6000 workflows to use newer docker container - 24.09 vs 24.03… (loadams, Jan 31, 2025)
14189a7  Allow NVIDIA Blackwell (#6991) (fabiendupont, Feb 4, 2025)
75996f8  Update GH org references (#6998) (tjruwase, Feb 5, 2025)
b23c545  Fix min loss scale (tjruwase, Feb 5, 2025)
7cd3a9f  Fix UTs (tjruwase, Feb 6, 2025)
2c5629e  Update CNAME (loadams, Feb 5, 2025)
6b15688  Update CNAME (loadams, Feb 5, 2025)
3773d83  [XPU] max1100 workflow update for docker and softwares (#7003) (Liangliang-Ma, Feb 5, 2025)
64c4b04  autotp training(fix dco) (#7004) (inkcherry, Feb 5, 2025)
5fa2910  Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int… (tjruwase, Feb 6, 2025)
1f5a672  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 7, 2025)
9882116  Fix ds-chat CI regression (tjruwase, Feb 7, 2025)
97d7915  Merge branch 'olruwase/ds_7014' of github.com:microsoft/DeepSpeed int… (tjruwase, Feb 7, 2025)
4a1dd0f  Fix bug (tjruwase, Feb 7, 2025)
0ac4457  Avoid naming collision on partition() (tjruwase, Feb 7, 2025)
1597d48  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 8, 2025)
2ae2062  Use new API (tjruwase, Feb 8, 2025)
9fb73a4  Merge branch 'master' into olruwase/ds_7014 (tjruwase, Feb 8, 2025)
26fa8af  Merge branch 'olruwase/ds_7014' of github.com:microsoft/DeepSpeed int… (tjruwase, Feb 8, 2025)
b565d77  Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int… (tjruwase, Feb 8, 2025)
d098c32  Merge branch 'master' into olruwase/ds_5241 (loadams, Feb 10, 2025)
990a5ad  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 13, 2025)
9b1b030  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 21, 2025)
1953c38  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 25, 2025)
2ea182e  Code cleanup (tjruwase, Feb 25, 2025)
9aff208  Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int… (tjruwase, Feb 25, 2025)
36c55d2  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 26, 2025)
80fcb83  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 27, 2025)
776385f  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 27, 2025)
e5f64af  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Feb 28, 2025)
61685dc  Use new dlpack api; Formatting fixes (tjruwase, Mar 3, 2025)
75ac86c  Merge branch 'olruwase/new_dlpack_api' of github.com:microsoft/DeepSp… (tjruwase, Mar 3, 2025)
6b9736c  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Mar 4, 2025)
83850ad  Triage pytest --forked cupy failure (tjruwase, Mar 4, 2025)
4d56c99  Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int… (tjruwase, Mar 4, 2025)
5e76c7d  Revert pytest debugging (tjruwase, Mar 4, 2025)
a59cb55  Merge branch 'master' into olruwase/ds_5241 (loadams, Mar 7, 2025)
f10a2f2  Merge branch 'master' into olruwase/ds_5241 (loadams, Mar 7, 2025)
919f538  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Mar 11, 2025)
4b58326  Merge branch 'master' of github.com:microsoft/DeepSpeed into olruwase… (tjruwase, Mar 11, 2025)
75203d7  Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int… (tjruwase, Mar 11, 2025)
08a07cb  UT workaround (tjruwase, Mar 13, 2025)
728dd38  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Mar 17, 2025)
2ac9211  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Mar 19, 2025)
2d6913a  Merge branch 'master' into olruwase/ds_5241 (tjruwase, Mar 21, 2025)
Improve overflow handling in ZeRO
tjruwase committed Jan 28, 2025
commit a3a18f7288e85e34a6a470cdbb808f87928d6257
7 changes: 5 additions & 2 deletions deepspeed/runtime/bf16_optimizer.py
@@ -36,6 +36,7 @@ class BF16_Optimizer(ZeROOptimizer):
     def __init__(self,
                  init_optimizer,
                  param_names,
+                 bfloat16_config,
                  mpu=None,
                  clip_grad=0.0,
                  norm_type=2,
@@ -44,7 +45,6 @@ def __init__(self,
                  timers=None,
                  grad_acc_dtype=None,
                  graph_harvesting=False,
-                 immediate_grad_update=False,
                  has_moe_layers=False):
         super().__init__()
         see_memory_usage('begin bf16_optimizer', force=True)
@@ -53,10 +53,13 @@ def __init__(self,
         self.param_names = param_names
         self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

+        assert bfloat16_config.enabled, f"BF16Optimizer: requires bfloat16 to be enabled"
         assert grad_acc_dtype in [torch.float32, torch.bfloat16
                                   ], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}"
         self.grad_acc_dtype = grad_acc_dtype
-        self.immediate_grad_update = immediate_grad_update
+
+        self.immediate_grad_update = bfloat16_config.immediate_grad_update
+        self.check_overflow = bfloat16_config.check_overflow

         self.clip_grad = clip_grad
         self.norm_type = norm_type
17 changes: 3 additions & 14 deletions deepspeed/runtime/config.py
@@ -31,6 +31,7 @@
 from ..comm.config import DeepSpeedCommsConfig
 from ..monitor.config import get_monitor_config
 from ..inference.config import WeightQuantConfig
+from .validate_config import get_bfloat16_config

 from deepspeed import comm as dist
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -163,13 +164,6 @@ def get_fp16_enabled(param_dict):
     return False


-def get_bfloat16_enabled(param_dict):
-    for key in [BFLOAT16, BFLOAT16_OLD]:
-        if key in param_dict.keys():
-            return get_scalar_param(param_dict[key], BFLOAT16_ENABLED, BFLOAT16_ENABLED_DEFAULT)
-    return False
-
-
 def get_bfloat16_immediate_grad_update(param_dict):
     for key in [BFLOAT16, BFLOAT16_OLD]:
         if key in param_dict.keys():
@@ -193,8 +187,6 @@ def get_fp16_auto_cast(param_dict):
 def get_loss_scale(param_dict):
     if get_fp16_enabled(param_dict):
         return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT)
-    elif get_bfloat16_enabled(param_dict):
-        return 1.0
     else:
         return FP16_LOSS_SCALE_DEFAULT

@@ -203,8 +195,6 @@
     if get_fp16_enabled(param_dict):
         initial_scale_power = get_scalar_param(param_dict[FP16], FP16_INITIAL_SCALE_POWER,
                                                FP16_INITIAL_SCALE_POWER_DEFAULT)
-    elif get_bfloat16_enabled(param_dict):
-        initial_scale_power = 0
     else:
         initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT

@@ -828,10 +818,9 @@ def _initialize_params(self, param_dict):
         self.gradient_clipping = get_gradient_clipping(param_dict)
         self.fp16_enabled = get_fp16_enabled(param_dict)
         self.fp16_auto_cast = get_fp16_auto_cast(param_dict)
-        self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
-        self.bfloat16_immediate_grad_update = get_bfloat16_immediate_grad_update(param_dict)
+        self.bfloat16_config = get_bfloat16_config(param_dict)
         assert not (self.fp16_enabled
-                    and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
+                    and self.bfloat16_config.enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
         self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict)
         self.amp_enabled = get_amp_enabled(param_dict)
         self.amp_params = get_amp_params(param_dict)
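
In this commit the standalone get_bfloat16_enabled getter is removed and the bfloat16 fields are consolidated into one config object returned by get_bfloat16_config, which is imported from validate_config.py but not shown in this diff. As a rough sketch of what such a pydantic-style model could look like, reusing DeepSpeedConfigModel and the BFLOAT16/BFLOAT16_OLD keys that already exist in constants.py (field defaults mirror the *_DEFAULT constants; the real implementation may differ):

# Hypothetical sketch only; the actual get_bfloat16_config lives in
# deepspeed/runtime/validate_config.py and is not part of this diff.
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.constants import BFLOAT16, BFLOAT16_OLD

class BFloat16Config(DeepSpeedConfigModel):
    # Mirrors the "bf16" section documented in constants.py.
    enabled: bool = False
    immediate_grad_update: bool = False
    check_overflow: bool = False

def get_bfloat16_config(param_dict):
    # Accept both the current "bf16" key and the legacy "bfloat16" key.
    for key in (BFLOAT16, BFLOAT16_OLD):
        if key in param_dict:
            return BFloat16Config(**param_dict[key])
    return BFloat16Config()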
7 changes: 6 additions & 1 deletion deepspeed/runtime/constants.py
@@ -117,7 +117,9 @@
 BFLOAT16_FORMAT = '''
 BFLOAT16 parameters should be of the format:
 "bf16": {
-  "enabled": true
+  "enabled": true,
+  "immediate_grad_update": false,
+  "check_overflow": false
 }
 '''
 BFLOAT16 = "bf16"
@@ -126,6 +128,9 @@
 BFLOAT16_ENABLED = "enabled"
 BFLOAT16_ENABLED_DEFAULT = False

+CHECK_OVERFLOW = "check_overflow"
+BFLOAT16_CHECK_OVERFLOW_DEFAULT = False
+
 # BFLOAT16 optimizer immediate gradient update
 BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update"
 BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False
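
The new check_overflow knob is user-facing and sits next to the existing enabled and immediate_grad_update fields. A minimal illustrative config fragment is sketched below; the keys outside the "bf16" block are placeholder values chosen for the example, not part of this PR:

# Illustrative ds_config fragment only; values outside the "bf16" block are
# placeholders and not taken from this PR.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 2},
    "bf16": {
        "enabled": True,
        "immediate_grad_update": False,
        "check_overflow": True,   # option introduced by this PR
    },
}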
14 changes: 11 additions & 3 deletions deepspeed/runtime/engine.py
@@ -839,7 +839,7 @@ def fp16_enabled(self):
         return self._config.fp16_enabled

     def bfloat16_enabled(self):
-        return self._config.bfloat16_enabled
+        return self._config.bfloat16_config.enabled

     def fp16_master_weights_and_gradients(self):
         return self._config.fp16_master_weights_and_gradients
@@ -1527,14 +1527,14 @@ def _configure_bf16_optimizer(self, optimizer):
         timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
         optimizer = BF16_Optimizer(optimizer,
                                    self.param_names,
+                                   bfloat16_config=self._config.bfloat_config,
                                    mpu=self.mpu,
                                    clip_grad=clip_grad,
                                    allgather_bucket_size=self.zero_allgather_bucket_size(),
                                    dp_process_group=self.seq_data_parallel_group,
                                    timers=timers,
                                    grad_acc_dtype=self.get_data_types()[1],
                                    graph_harvesting=self.graph_harvesting(),
-                                   immediate_grad_update=self._config.bfloat16_immediate_grad_update,
                                    has_moe_layers=self.has_moe_layers)

         return optimizer
@@ -1545,6 +1545,13 @@ def _configure_zero_optimizer(self, optimizer):
         mics_shard_size = self.mics_shard_size()
         model_dtype, gradient_accumulation_dtype = self.get_data_types()

+        if self.bfloat16_enabled():
+            check_grad_overflow = self._config.bfloat16_config.check_grad_overflow
+        elif self.fp16_enabled():
+            check_grad_overflow = True
+        else:
+            check_grad_overflow = False
+
         timers = self.timers if self.wall_clock_breakdown() else NoopTimer()

         if optimizer is None:
@@ -1596,7 +1603,8 @@
                 fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(),
                 gradient_accumulation_dtype=gradient_accumulation_dtype,
                 communication_data_type=self.communication_data_type,
-                elastic_checkpoint=self.zero_elastic_checkpoint())
+                elastic_checkpoint=self.zero_elastic_checkpoint(),
+                check_grad_overflow=check_grad_overflow)

         elif zero_stage == ZeroStageEnum.weights:
             assert not self.has_moe_layers, "MoE not supported with Stage 3"
4 changes: 3 additions & 1 deletion deepspeed/runtime/fp16/loss_scaler.py
@@ -196,7 +196,9 @@ def update_scale(self, overflow):
                 hysteresis_msg = f"Consecutive hysteresis is enabled. Restoring hysteresis to {self.delayed_shift}"
                 logger.info(hysteresis_msg)
                 self.cur_hysteresis = self.delayed_shift
-            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
+
+            stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
+            if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                 if not self.consecutive_hysteresis:
                     self.cur_hysteresis = self.delayed_shift
                 self.cur_scale *= self.scale_factor
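
The old condition counted the overflow iteration itself toward the scale window, so the loss scale could grow one step earlier than intended; the new stable_interval appears to count only fully overflow-free iterations, and the > 0 guard blocks growth right after an overflow. A toy comparison of the two conditions (not DeepSpeed code), assuming scale_window = 2 and an overflow at iteration 10:

# Toy illustration of the old vs. new scale-growth condition.
scale_window = 2
last_overflow_iter = 10
for cur_iter in range(11, 16):
    old_grow = (cur_iter - last_overflow_iter) % scale_window == 0
    stable_interval = (cur_iter - last_overflow_iter) - 1
    new_grow = stable_interval > 0 and stable_interval % scale_window == 0
    print(cur_iter, old_grow, new_grow)
# old_grow first fires at iteration 12; new_grow first fires at iteration 13,
# i.e. only after scale_window overflow-free iterations have completed.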
2 changes: 1 addition & 1 deletion deepspeed/runtime/hybrid_engine.py
@@ -80,7 +80,7 @@ def new_inference_container(self, orig_layer, policy_cls, layer_id):

         if self._config.fp16_enabled:
             inference_dtype = torch.float16
-        elif self._config.bfloat16_enabled:
+        elif self._config.bfloat16_config.enabled:
             inference_dtype = torch.bfloat16
         else:
             inference_dtype = torch.float32
2 changes: 1 addition & 1 deletion deepspeed/runtime/pipe/engine.py
@@ -863,7 +863,7 @@ def _exec_backward_pass(self, buffer_id):

         if self.using_bf16_optimizer and not self.is_last_stage():
             # manually call because we don't call optimizer.backward()
-            if not self._config.bfloat16_immediate_grad_update:
+            if not self._config.bfloat16_config.immediate_grad_update:
                 self.optimizer.update_hp_grads(clear_lp_grads=False)

         # Free up the memory from the output of forward()
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -358,10 +358,10 @@ def _post_init_method(self, module):

     def _set_dtype(self, ds_config, dtype):
         if ds_config is not None and dtype is None:
-            if ds_config.bfloat16_enabled and ds_config.fp16_enabled:
+            if ds_config.bfloat16_config.enabled and ds_config.fp16_enabled:
                 raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")

-            if ds_config.bfloat16_enabled:
+            if ds_config.bfloat16_config.enabled:
                 self.dtype = torch.bfloat16
             elif ds_config.fp16_enabled:
                 self.dtype = torch.half
63 changes: 34 additions & 29 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -136,7 +136,8 @@ def __init__(self,
                  round_robin_gradients=False,
                  has_moe_layers=False,
                  fp16_master_weights_and_gradients=False,
-                 elastic_checkpoint=False):
+                 elastic_checkpoint=False,
+                 check_grad_overflow=True):

         if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none:
             self.cpu_offload = True
@@ -155,6 +156,7 @@
         # 2. keep common stuff here in case we need to add ne552w fused optimizer later

         self.elastic_checkpoint = elastic_checkpoint
+        self.check_grad_overflow = check_grad_overflow
         self.param_names = param_names
         self.mpu = mpu
         # differences from apex.fp16_utils:
@@ -552,6 +554,8 @@

         self._enable_universal_checkpoint()
         self._param_slice_mappings = self._create_param_mapping()
+        if self.cpu_offload:
+            self._create_optimizer_mapping()

     def destroy(self):
         for i, _ in enumerate(self.optimizer.param_groups):
@@ -578,6 +582,12 @@ def _create_param_mapping(self):

         return param_mapping

+    def _create_optimizer_mapping(self):
+        for i, _ in enumerate(self.optimizer.param_groups):
+            for lp in self.bit16_groups[i]:
+                if lp._hp_mapping is not None:
+                    lp._zero_optimizer = self
+
     def _link_all_hp_params(self):
         if self.cpu_offload:
             self._get_offload_gradient_dict()
@@ -1170,11 +1180,14 @@ def get_grad_position(self, group_id, tensor_list, first_offset, partition_size)
         ]
         current_offset += num_elements

-    def update_overflow_tracker_for_param_grad(self, param):
-        grad_accum = self.get_param_gradient_attribute(param)
-        if grad_accum is not None and self._has_inf_or_nan(grad_accum.data):
+    def update_offload_overflow_tracker(self, grad):
+        if grad is not None and self._has_inf_or_nan(grad.data):
             self.local_overflow = True

+    def update_offload_overflow_tracker_for_param_grad(self, param):
+        grad_accum = self.get_param_gradient_attribute(param)
+        self.update_offload_overflow_tracker(grad_accum)
+
     def _get_offload_gradient_dict(self):
         for param_group_index, _ in enumerate(self.optimizer.param_groups):
             self.offload_gradient_dict[param_group_index] = []
@@ -1276,7 +1289,7 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
             src_tensor = src_tensor.float()

         dest_tensor.copy_(src_tensor, non_blocking=True)
-        param.grad = None  #offload only
+        self.clear_grad_attribute(param)  #offload only

     def complete_grad_norm_calculation_for_cpu_offload(self, params):
         total_norm = 0.0
@@ -1308,17 +1321,17 @@
         """

         # Sum across all model parallel GPUs.
-        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
-        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)
+        total_dev_norm = get_accelerator().FloatTensor([float(total_norm)])
+        dist.all_reduce(total_dev_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group)

-        self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
+        self._model_parallel_all_reduce(tensor=total_dev_norm, op=dist.ReduceOp.SUM)

-        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
+        total_norm = total_dev_norm[0].item()**(1. / norm_type)

         if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
-            total_norm = -1
+            total_norm = -1.0

-        return total_norm
+        return torch.tensor(total_norm, device=self.device, dtype=torch.float)

     ############################################################################################
     def copy_grads_in_partition(self, param):
@@ -1330,7 +1343,7 @@ def copy_grads_in_partition(self, param):
         if self.is_gradient_accumulation_boundary:
             self.set_norm_for_param_grad_in_gpu(param)

-            self.update_overflow_tracker_for_param_grad(param)
+            self.update_offload_overflow_tracker_for_param_grad(param)

             self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)

@@ -1784,10 +1797,7 @@ def scaled_global_norm(self, norm_type=2):
         norm_groups = []
         for i, group in enumerate(self.bit16_groups):
             if self.cpu_offload:
-                # complete complete_grad_norm_calculation_for_cpu_offload return python float, moving back to
-                # torch.tensor as else statement returns tensor as well
-                norm = torch.tensor(self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]),
-                                    device=self.device)
+                norm = self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i])
                 norm_groups.append(norm)
             else:
                 norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i]))
@@ -1827,8 +1837,8 @@ def step(self, closure=None):
         see_memory_usage(f"In step before checking overflow")

         # First compute norm for all group so we know if there is overflow
-        if self.dtype == torch.float16:
-            self.check_overflow()
+        if self.check_grad_overflow:
+            self.check_overflow(partition_gradients=self.partition_gradients)

         prev_scale = self.loss_scale
         self._update_scale(self.overflow)
@@ -1838,7 +1848,8 @@
             if self.cpu_offload:
                 self.reset_cpu_buffers()
             else:
-                self.averaged_gradients = {}
+                for k in self.averaged_gradients.keys():
+                    self.averaged_gradients[k] = None

             see_memory_usage('After overflow after clearing gradients')

@@ -1999,21 +2010,15 @@ def has_overflow_partitioned_grads_serial(self):
         return invalid_grad_count.bool()

     def has_overflow(self, partition_gradients=True):
+        overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
+        overflow_gpu = get_accelerator().ByteTensor([overflow]) if self.cpu_offload else overflow.byte().to(
+            get_accelerator().current_device_name())
+
         if partition_gradients:
-            overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
-            overflow_gpu = get_accelerator().ByteTensor([overflow]) if self.cpu_offload else overflow.byte().to(
-                get_accelerator().current_device_name())
             '''This will capture overflow across all data parallel and expert parallel process
             Since expert parallel process are a subset of data parallel process'''
             dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)

-        else:
-            params = []
-            for group in self.bit16_groups:
-                for param in group:
-                    params.append(param)
-            overflow_gpu = self.has_overflow_serial(params).byte().to(get_accelerator().current_device_name())
-
         # Since each model parallel GPU carries only part of the model,
         # make sure overflow flag is synced across all the model parallel GPUs
         self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX)
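
In step(), the overflow check is now gated on check_grad_overflow rather than on dtype == torch.float16, and has_overflow() builds the device-side flag before the data-parallel reduction. The underlying pattern is: detect inf/NaN locally, then MAX-reduce across the data-parallel group so every rank makes the same skip-or-step decision. A stand-alone sketch of that pattern (an assumed helper, not DeepSpeed code; it presumes a CUDA device and an initialized process group):

# Minimal sketch of local inf/NaN detection followed by a MAX all-reduce,
# the same pattern has_overflow() uses above. Assumes torch.distributed is
# already initialized and gradients live on a CUDA device.
import torch
import torch.distributed as dist

def any_rank_has_overflow(local_grads, dp_group=None):
    local_overflow = any(g is not None and not torch.isfinite(g).all() for g in local_grads)
    flag = torch.tensor([1 if local_overflow else 0], dtype=torch.uint8, device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=dp_group)
    return bool(flag.item())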
2 changes: 2 additions & 0 deletions deepspeed/utils/tensor_fragment.py
@@ -127,6 +127,8 @@ def set_full_hp_grad(self, value):
         lp_frag_address = self._hp_mapping.lp_fragment_address
         value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
         lp_grad_fragment.data.copy_(value_fragment.data.reshape_as(lp_grad_fragment.data))
+        if hasattr(self, '_zero_optimizer'):
+            self._zero_optimizer.update_offload_overflow_tracker(value)


 def safe_get_full_fp32_param(param):
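
Together with _create_optimizer_mapping in stage_1_and_2.py above, this hook closes the loop for CPU offload: the ZeRO optimizer stamps itself onto each low-precision parameter, and a gradient written later through the fragment API is reported back for overflow tracking. A schematic toy version of the pattern (toy names, not DeepSpeed code):

# Toy illustration of the back-pointer pattern added in this commit.
import torch

class ToyZeroOptimizer:
    def __init__(self, params):
        self.local_overflow = False
        for p in params:
            p._zero_optimizer = self  # what _create_optimizer_mapping does

    def update_offload_overflow_tracker(self, grad):
        if grad is not None and not torch.isfinite(grad).all():
            self.local_overflow = True

def toy_set_full_hp_grad(param, value):
    # ... copy `value` into the hp gradient fragment ...
    if hasattr(param, '_zero_optimizer'):  # the hook added above
        param._zero_optimizer.update_offload_overflow_tracker(value)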