Enable torch.autocast with ZeRO #6993

Open: wants to merge 67 commits into base: master from tohtana/support_autocast

Changes from 1 commit

Commits (67)
a4fbc3a
Use ds-specific module id to avoid conflicts (#6847)
tjruwase Jan 31, 2025
9d48dc6
add autocast support and ds_config item
tohtana Jan 31, 2025
8957009
prepare ipg buckets for multiple dtypes
tohtana Feb 1, 2025
c390c76
switch communication data type
tohtana Feb 2, 2025
458797d
add gradscaler
tohtana Feb 3, 2025
2984415
Update A6000 workflows to use newer docker container - 24.09 vs 24.03…
loadams Jan 31, 2025
6b6a600
fix import and formatting
tohtana Feb 3, 2025
2817c02
convert comm type for z3
tohtana Feb 4, 2025
790c43a
Allow NVIDIA Blackwell (#6991)
fabiendupont Feb 4, 2025
fcdeda3
Update GH org references (#6998)
tjruwase Feb 5, 2025
b476d07
Update CNAME
loadams Feb 5, 2025
72f9687
Update CNAME
loadams Feb 5, 2025
a2425da
[XPU] max1100 workflow update for docker and softwares (#7003)
Liangliang-Ma Feb 5, 2025
a896693
autotp training(fix dco) (#7004)
inkcherry Feb 5, 2025
1c1d43a
cast dtype for allgather
tohtana Feb 28, 2025
a265cad
import triton files when triton is supported and installed (#6989)
oelayan7 Feb 6, 2025
359c85d
Update A6000 tests transformers version (#7016)
loadams Feb 8, 2025
735fc2c
Fix ds-chat CI regression (#7015)
tjruwase Feb 10, 2025
1e7888c
[Ulysses tutorial] typos (#7024)
stas00 Feb 11, 2025
7d15f26
fix hostname -I for macOS #6497 (#6990)
fitzjalen Feb 12, 2025
ef5a2a4
Update workflows to cuda 12.4 (#7000)
loadams Feb 12, 2025
f1aea5d
[ROCm] Enable fp_quantizer on ROCm (#7027)
rraminen Feb 13, 2025
8152824
add gds chinese blog (#7034)
GuanhuaWang Feb 13, 2025
c898ac5
Add chinese blog for deepspeed windows, and fix format (#7035)
hwchen2017 Feb 14, 2025
e3ea926
AIO on ROCM (#7023)
jomayeri Feb 14, 2025
e946615
Control trace cache warnings (#7039)
tjruwase Feb 18, 2025
38e9bf3
Update CUDA compute capability to support Blackwell (#7047)
hwchen2017 Feb 18, 2025
acc6a1e
Update setup.py handling of ROCm cupy (#7051)
loadams Feb 19, 2025
bda1430
nv-ds-chat breaks with latest transformers (#7052)
loadams Feb 19, 2025
c184b16
Rename aio_thread_count to intra_op_parallelism (#7056)
tjruwase Feb 19, 2025
a2b8219
add autoTP training zero2 tests (#7049)
inkcherry Feb 19, 2025
9f50cde
Fix, bf16 optimizer remove dup loop (#7054)
wukong1992 Feb 20, 2025
71bd64e
Update version.txt after 0.16.4 release (#7063)
loadams Feb 20, 2025
5c9fd4b
fix an outdated doc wrt CUDA_VISIBLE_DEVICES (#7058)
stas00 Feb 20, 2025
20a46b7
Tecorigin sdaa accelerator (#6903)
siqi654321 Feb 20, 2025
5f68587
Handle special case of libuv for Windows (#7064)
loadams Feb 20, 2025
3f5cd1a
Update README with info on newest accelerator (#7065)
loadams Feb 21, 2025
b80d2d4
Bug Fix for offload_states API (#7050)
U-rara Feb 21, 2025
877c30e
Fix TOCTOU issues, switch to fstat (#7067)
loadams Feb 24, 2025
59fe7f6
config torch to avoid graph breaks caused by logger (#6999)
ShellyNR Feb 24, 2025
060aa5a
Fix meta load tensor imcompatible issue (#7073)
Yejing-Lai Feb 24, 2025
817b31d
Replace calls to `python setup.py sdist` with `python -m build --sdis…
loadams Feb 24, 2025
f99605b
Revert "Handle special case of libuv for Windows (#7064)" (#7076)
loadams Feb 25, 2025
57805b2
Add DeepseekV3 AutoTP. (#7045)
Yejing-Lai Feb 26, 2025
697050e
Improve inference tutorial docs (#7083)
loadams Feb 26, 2025
83c9461
Pin transformers version on tests that use latest. (#7085)
loadams Feb 27, 2025
c6bf7fb
Update README.md with ICS '23 MoE paper link (#7087)
siddharth9820 Feb 27, 2025
2c36865
Update parallelism for nv-torch-latest/nightly tests due to more GPUs…
loadams Feb 27, 2025
f2b89ec
Remove workflows for very old torch versions (#7090)
loadams Feb 28, 2025
965cb2b
Merge branch 'master' into tohtana/support_autocast
tohtana Feb 28, 2025
981e8e2
clear reduce buffer
tohtana Mar 1, 2025
37f77ae
add config to set lower precision modules
tohtana Mar 1, 2025
3083d94
fix to use comm dtype in config when autocast is disabled
tohtana Mar 3, 2025
d688b75
Merge branch 'master' into tohtana/support_autocast
tohtana Mar 3, 2025
aa60eb3
add tests
tohtana Mar 4, 2025
9529830
sort dtypes
tohtana Mar 5, 2025
c8056a8
Merge branch 'master' into tohtana/support_autocast
tohtana Mar 5, 2025
c56339c
fix for cases where param and param.ds_tensor have different dtypes
tohtana Mar 6, 2025
1d6ed6e
Merge branch 'master' into tohtana/support_autocast
loadams Mar 7, 2025
aa10e11
fix moe tests
tohtana Mar 8, 2025
26e62e4
fix tests for opt state offloading
tohtana Mar 8, 2025
a74fa1e
fix var name
tohtana Mar 8, 2025
2016995
Merge branch 'master' into tohtana/support_autocast
tohtana Mar 10, 2025
9f5b8c0
Merge branch 'master' into tohtana/support_autocast
tohtana Mar 12, 2025
7973b88
fix arg order
tohtana Mar 12, 2025
15c436d
Merge branch 'master' into tohtana/support_autocast
tjruwase Mar 17, 2025
713d56f
Merge branch 'master' into tohtana/support_autocast
tohtana Mar 17, 2025
autotp training(fix dco) (#7004)
Same as [this PR](#6922) (commit affeb88).
I noticed the CI updated the DCO check recently. Using the suggested
rebase method for sign-off would reintroduce many conflicts, so I opted
for a squash merge with sign-off instead. Thanks :)

Signed-off-by: inkcherry <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
inkcherry authored and tohtana committed Feb 28, 2025

commit a896693af777cadb85e2e9e001bb0aee07f0f87b
33 changes: 32 additions & 1 deletion deepspeed/__init__.py
@@ -37,7 +37,7 @@
from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_transformer_layer, revert_transformer_layer
from .module_inject import replace_transformer_layer, revert_transformer_layer, set_autotp_mode

from .utils import log_dist, OnDevice, logger
from .comm.comm import init_distributed
@@ -364,3 +364,34 @@ def init_inference(model, config=None, **kwargs):
engine = InferenceEngine(model, config=ds_inference_config)

return engine


def tp_model_init(model, tp_size, dtype):
"""
Initialize the model for tensor parallelism.

Args:
model (torch.nn.Module): The model to be initialized.
tp_size (int): The tensor parallelism size.
dtype (torch.dtype): The data type to be used for the model.

Returns:
torch.nn.Module: The initialized model with tensor parallelism.
"""
# avoid re-entry
assert not hasattr(
model, 'ds_autotp_parsed'), "'ds_autotp_parsed' attribute already exists in the model; re-entry is not allowed."

set_autotp_mode(training=True)

from deepspeed.runtime.tensor_parallel import TpTrainingManager
# The expected usage here is for it to be invoked by the transformers package.

# TODO: We should provide a custom TP mapping solution without using autoTP,
# as modifying the autoTP logic may be more difficult for users than configuring it.

model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module

setattr(model, 'ds_autotp_parsed', True)

return model
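A minimal usage sketch of the new entry point, assuming a distributed launch with at least `tp_size` ranks; the model name and sizes below are illustrative, and the comment in the diff notes the expected caller is the transformers package rather than user code.

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM

deepspeed.init_distributed()

# Illustrative model; any torch.nn.Module the autoTP parser understands works here.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Shard the model across 2 tensor-parallel ranks before building the DeepSpeed engine.
model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)
```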
6 changes: 6 additions & 0 deletions deepspeed/comm/comm.py
@@ -224,6 +224,12 @@ def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='bro
return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)


@timed_op
def broadcast_object_list(object_list, src, group=None, device=None):
global cdb
return cdb.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)


@timed_op
def all_gather(tensor_list,
tensor,
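The new wrapper mirrors `torch.distributed.broadcast_object_list`. A small, hedged sketch of its use follows; the payload contents are illustrative.

```python
import deepspeed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

deepspeed.init_distributed()

# Rank 0 supplies the objects; every other rank receives them in place.
payload = [{"step": 0, "lr": 1e-4}] if dist.get_rank() == 0 else [None]
dist.broadcast_object_list(payload, src=0, device=get_accelerator().current_device())
```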
4 changes: 4 additions & 0 deletions deepspeed/comm/torch.py
@@ -205,6 +205,10 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

@disable_compiler_collective
def broadcast_object_list(self, object_list, src, group=None, device=None):
return torch.distributed.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)

@disable_compiler_collective
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
1 change: 0 additions & 1 deletion deepspeed/inference/engine.py
@@ -15,7 +15,6 @@
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.runtime.compiler import is_compile_supported

from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject import replace_transformer_layer, generic_injection
2 changes: 1 addition & 1 deletion deepspeed/module_inject/__init__.py
@@ -6,5 +6,5 @@
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
from .module_quantize import quantize_transformer_layer
from .replace_policy import HFBertLayerPolicy
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode
from .policy import DSPolicy
89 changes: 30 additions & 59 deletions deepspeed/module_inject/auto_tp.py
@@ -11,10 +11,12 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode


def move(tensor, device, copy=True):
@@ -333,10 +335,18 @@ def tp_parser(model):
return policy_list

def set_tensor_parallel_config(self, mp_size, mp_group):

if is_autotp_training_mode():
self.mp_group = groups.get_tensor_model_parallel_group()
self.mp_size = groups.get_tensor_model_parallel_world_size()
return

self.mp_size = mp_size
self.mp_group = mp_group

def _replace(self, child, name, conv_linear_layer):
# This function should clearly define the routing rules for specific layers
# and avoid any complex shard-related logic.
if getattr(child, "replaced", False) == True:
return
device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
@@ -352,80 +362,41 @@ def _replace(self, child, name, conv_linear_layer):
# For Yuan model
if 'Yuan' in str(self.module):
if 'v_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), True)
return LinearLayer(weight=weight, bias=bias)
return Yuan_LinearLayer(child, self.mp_group)

elif 'o_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), False)
return LinearAllreduce(weight, bias, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
return Yuan_LinearAllreduce(child, self.mp_group)

# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
return GateUpPack_LinearLayer(child, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
# For MoE MLP model, e.g., deepseek and jamba
down_proj = False
if 'down_proj' in name:
down_proj = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

setattr(child, "replaced", True)
if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
dim=1)
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data
return Conv_LinearALlreduce(child, self.mp_group, name=name)
elif name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(child, self.mp_group)

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(child, self.mp_group, name=name)
else:

# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0] // mp_size, weight_shape[1]]
setattr(child, "replaced", True)
if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()

if require_tp_fused_qkvw(name, self.mp_size):
conv_LinearLayer(child, self.mp_group)
elif require_tp_fused_qkvw(name, self.mp_size):
#Check and handle fused qkv for TP
#The copy is a regular copy, The shape of dst and src is the same
data_dc = move(
prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
device_name, return_new_copy)

bias_data_dc = None if child.bias is None else move(
prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
device_name, return_new_copy)
else:
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
dim=1 if self.conv_linear_layer else 0)
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
dim=0)
bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
else:
bias_data_dc = None
return fused_LinearLayer(child, self.mp_group, fused_module=self.module)

setattr(child, "replaced", True)
return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
return LinearLayer(child, self.mp_group, name=name)

def _slice_embedding(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
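After this refactor, `_replace` simply hands the original child module and the TP process group to a wrapper class (for example `LinearAllreduce(child, self.mp_group, name=name)`), and the sharding details live in layers.py. A rough, self-contained sketch of that row-parallel wrapper idea follows; it is an illustration under simplifying assumptions (even sharding, bias added once after the reduction), not the actual layers.py implementation.

```python
import torch
import deepspeed.comm as dist


class RowParallelLinearSketch(torch.nn.Module):
    """Illustrative stand-in for an all-reduce style linear wrapper: each rank keeps
    a column shard of the weight and the partial outputs are summed across the group."""

    def __init__(self, child: torch.nn.Linear, mp_group):
        super().__init__()
        self.mp_group = mp_group
        rank = dist.get_rank(group=mp_group)
        world = dist.get_world_size(group=mp_group)
        # Shard the input dimension (columns of the weight) evenly across ranks.
        shard = child.weight.data.chunk(world, dim=1)[rank].contiguous()
        self.weight = torch.nn.Parameter(shard)
        self.bias = child.bias  # added once, after the all-reduce

    def forward(self, x):
        # x holds this rank's shard of the input features.
        out = torch.nn.functional.linear(x, self.weight)
        dist.all_reduce(out, group=self.mp_group)
        if self.bias is not None:
            out = out + self.bias
        return out
```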
661 changes: 569 additions & 92 deletions deepspeed/module_inject/layers.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions deepspeed/module_inject/load_checkpoint.py
@@ -236,7 +236,7 @@ def load_module_recursive(module, prefix='', level=0):
child.weight.ds_id in all_ds_ids):
prefix1 = all_ds_ids[child.weight.ds_id]
if child.__class__ is nn.Linear:
child = LinearLayer(weight=all_ds_ids[child.weight.ds_id])
child = LinearLayer.from_weights(weight=all_ds_ids[child.weight.ds_id])
setattr(module, name, child)
continue
child_params = list(child.parameters())
@@ -249,7 +249,9 @@ def load_module_recursive(module, prefix='', level=0):
child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
setattr(module, name, child)
elif child.__class__ in [nn.Linear, ColumnParallelLinear, RowParallelLinear]:
child = LinearLayer(weight_shape=child.weight.shape, dtype=child.weight.dtype, bias=child.bias)
child = LinearLayer.from_weights(weight_shape=child.weight.shape,
dtype=child.weight.dtype,
bias=child.bias)
setattr(module, name, child)
elif child.__class__ is OPTLearnedPositionalEmbedding:
child = OPTEmbedding(weight_shape=ds_shape)
6 changes: 3 additions & 3 deletions deepspeed/module_inject/replace_module.py
@@ -15,7 +15,7 @@
from .replace_policy import replace_policies, generic_policies
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d

from deepspeed.module_inject.layers import is_autotp_training_mode
from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size

@@ -323,7 +323,7 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):

else:
# copy relevant state from child -> new module
if config.replace_with_kernel_inject:
if not is_autotp_training_mode() and config.replace_with_kernel_inject:
new_module = replace_with_policy(child,
_policy,
config.triangular_masking,
@@ -475,7 +475,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
set_lm_head(replaced_module)
print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

if config.save_mp_checkpoint_path is not None:
if not is_autotp_training_mode() and config.save_mp_checkpoint_path is not None:
from collections import OrderedDict
import json
num_partitions = 8
2 changes: 2 additions & 0 deletions deepspeed/runtime/config.py
@@ -62,6 +62,7 @@
from ..compression.constants import *
from .swap_tensor.aio_config import get_aio_config

from .tensor_parallel import get_tensor_parallel_config
from .data_pipeline.config import get_data_efficiency_enabled, get_data_efficiency_config, get_curriculum_enabled_legacy, get_curriculum_params_legacy
from .data_pipeline.constants import *

@@ -933,6 +934,7 @@ def _initialize_params(self, param_dict):
**param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None

self.timers_config = get_timers_config(param_dict)
self.tensor_parallel_config = get_tensor_parallel_config(param_dict)

def _batch_assertion(self):

120 changes: 118 additions & 2 deletions deepspeed/runtime/engine.py
@@ -37,6 +37,7 @@
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer

from deepspeed.linear.optimized_linear import LoRAOptimizedLinear
from deepspeed.module_inject.layers import GatherReplacedLayerParams

from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
@@ -75,7 +76,7 @@
from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names
from deepspeed.monitor.monitor import MonitorMaster
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from deepspeed.runtime.utils import clip_grad_norm_
from deepspeed.runtime.utils import clip_grad_norm_, compare_tensors_in_structures
from deepspeed.runtime.eigenvalue import Eigenvalue
from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \
DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \
@@ -231,7 +232,6 @@ def __init__(self,
self._step_applied = False
self._global_grad_norm = None
self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend.

self.checkpoint_engine = None

self._is_gradient_accumulation_boundary = None
@@ -248,6 +248,8 @@ def __init__(self,
self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
if self.autotp_size() > 1:
self._configure_tensor_parallel_states(model)
see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
if mpu is not None:
if self.elasticity_enabled():
@@ -415,6 +417,71 @@ def _optimized_linear_offload_setup(self):
else:
p.ds_offload = False

def _configure_tensor_parallel_states(self, model):
"""
Configures the tensor parallel states for the model.
This includes setting up the tensor parallel groups, initializing the TP mesh,
and registering a pre-hook to ensure that the Dataloader inputs are consistent across ranks.
"""
self._set_client_model(model)

# sanity check
# currently, the compatibility between 'autotp' and 'zero > 1' has not been validated
assert self.zero_optimization_stage(
) <= 1, "Currently, the compatibility between 'autotp' and 'zero_stage > 1' has not been validated"

self.mpu = groups
self.mpu._init_tp_mesh_device(tensor_model_parallel_size=self.autotp_size())

self.first_dataloader_check = None

def check_dataloader_inputs_same_across_ranks(module, args, kwargs):

def broadcast_and_check(args, bcast_rank, bcast_group):
if isinstance(args, tuple):
args = list(args)
if len(args) > 0:
if self.mpu.get_tensor_model_parallel_rank() == 0:
_src_args = [args]
dist.broadcast_object_list(object_list=_src_args,
src=bcast_rank,
group=bcast_group,
device=get_accelerator().current_device())
# Rank 0 does not need to compare with itself
is_equal = True
else:
_src_args = [None]
dist.broadcast_object_list(object_list=_src_args,
src=bcast_rank,
group=bcast_group,
device=get_accelerator().current_device())

is_equal = compare_tensors_in_structures(args, _src_args[0])

equal_tensor = torch.tensor(is_equal,
dtype=self.communication_data_type,
device=get_accelerator().current_device())
dist.all_reduce(equal_tensor, group=bcast_group)
assert torch.equal(
equal_tensor,
torch.tensor(groups.get_tensor_model_parallel_world_size(),
dtype=self.communication_data_type,
device=get_accelerator().current_device())
), "Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency."

bcast_rank = self.mpu.get_tensor_model_parallel_src_rank()
bcast_group = self.mpu.get_tensor_model_parallel_group()

broadcast_and_check(args, bcast_rank, bcast_group)
broadcast_and_check(kwargs, bcast_rank, bcast_group)

logger.info("The Dataloader has passed the TP group consistency check.")
self.first_dataloader_check.remove()

self.first_dataloader_check = self.module.register_forward_pre_hook(check_dataloader_inputs_same_across_ranks,
prepend=True,
with_kwargs=True)
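The hook registered here asserts that every rank in a TP group receives identical inputs on the first forward pass. A hedged sketch of a DataLoader setup that satisfies the check, sharding only by data-parallel rank, is shown below; `train_dataset` is a placeholder.

```python
from torch.utils.data import DataLoader, DistributedSampler
from deepspeed.utils import groups

# Shard by data-parallel rank only, so all ranks inside a TP group see the same batches.
# `train_dataset` is a placeholder for any map-style torch Dataset.
sampler = DistributedSampler(train_dataset,
                             num_replicas=groups.get_data_parallel_world_size(),
                             rank=groups.get_data_parallel_rank(),
                             shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)
```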

def destroy(self):
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
self.optimizer.destroy()
@@ -836,6 +903,9 @@ def zero_legacy_stage1(self):
def zero_ignore_unused_parameters(self):
return self._config.zero_config.ignore_unused_parameters

def autotp_size(self):
return self._config.tensor_parallel_config.autotp_size

def graph_harvesting(self):
return self._config.graph_harvesting

@@ -3582,6 +3652,52 @@ def _save_zero_checkpoint(self, save_path, tag):
ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero'
logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}')

def _replace_module_consolidated_state_dict(self):
"""
Get a full non-partitioned state_dict with fp16 weights on cpu.
Important: this function must be called on all ranks and not just rank 0.
This is similar to nn.Module.state_dict (modelled after _save_to_state_dict).
This method is used for tensor parallel training.
Returns:
OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None.
"""
#TODO: If we use both Zero3 and tensor parallel simultaneously
# we need to consolidate the gather mechanisms of both.
state_dict = OrderedDict() if dist.get_rank() == 0 else None

def get_layer_state_dict(module, prefix=""):
with GatherReplacedLayerParams(list(module.parameters(recurse=False)), module, enabled=True):
for name, param in module.named_parameters(recurse=False):
if param is None:
continue
key = prefix + name
if (dist.get_rank() == 0):
state_dict[key] = param.detach().cpu()
# print(key,module, param.detach().cpu().shape)

for name, child in module.named_children():
if child is not None:
get_layer_state_dict(child, prefix + name + ".")

get_layer_state_dict(self.module, prefix="")

# ensure that all GPU communication tasks are completed before the process exits
get_accelerator().synchronize()
return state_dict

def _consolidated_16bit_state_dict(self, exclude_frozen_parameters=False):
"""
Consolidate the 16-bit state dictionary.
"""
if self.zero_optimization_stage() == ZeroStageEnum.weights:
return self._zero3_consolidated_16bit_state_dict(exclude_frozen_parameters)
elif self.autotp_size() > 1:
return self._replace_module_consolidated_state_dict()

raise ValueError("consolidated_16bit_state_dict is only applicable to cases where weights are partitioned, "
"including Zero Stage 3 and tensor parallelism.")

def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False):
"""
Get a full non-partitioned state_dict with fp16 weights on cpu.
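A hypothetical end-to-end use of the new consolidation path; `engine` stands for the object returned by `deepspeed.initialize`, and calling the private method directly is shown only for illustration.

```python
import torch
import deepspeed.comm as dist

# Must run on all ranks; only rank 0 receives the gathered 16-bit weights.
state_dict = engine._consolidated_16bit_state_dict()
if dist.get_rank() == 0:
    torch.save(state_dict, "consolidated_16bit_model.pt")
```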
7 changes: 7 additions & 0 deletions deepspeed/runtime/tensor_parallel/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .config import AUTOTP_MODE, get_tensor_parallel_config
from .tp_manager import TpTrainingManager
81 changes: 81 additions & 0 deletions deepspeed/runtime/tensor_parallel/config.py
@@ -0,0 +1,81 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from enum import Enum
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
import torch
from pydantic import Field
from typing import Optional


class AUTOTP_MODE(Enum):
TRAINING = "TRAINING"
INFERENCE = "INFERENCE"


class TPConfig(DeepSpeedConfigModel):
""" Configure tensor parallelism settings """

tp_size: int = 1
""" Number of devices to split the model across using tensor parallelism. """

tp_grain_size: int = 1
"""
Desired MLP/lm_head TP sharding granularity. DNN libraries favor tensor sizes that are
powers of two (64 is a typical choice). This option is not yet active in training, as it
depends on gather logic that supports uneven partitioning.
"""

mpu: object = None
"""
A model parallelism unit object that implements
``get_{model,data}_parallel_{rank,group,world_size}()``.
"""

tp_group: object = None


class TPTrainingConfig(DeepSpeedConfigModel):

dtype: torch.dtype = torch.float16
"""
Desired model data type, will convert model to this type.
"""

autotp_size: int = 0
"""
Tensor-parallel world size ('tensor_parallel_size') for automatic tensor-parallelism training.
When set to 0, automatic tensor parallelism is disabled.
"""
tensor_parallel: TPConfig = Field({}, alias="tp")
"""
Configuration for tensor parallelism used to split the model across several
GPUs. Expects a dictionary containing values for :any:`DeepSpeedTPConfig`.
"""

injection_policy_tuple: Optional[tuple] = None
# The following parameters are required by the autoTP parser.
########################################
keep_module_on_host: bool = False
"""
When loading checkpoints to model parameters, they are moved to the device. In very large models
this might fill the device and cause OOM. Setting this flag to true will keep checkpoints on
the host and not move them directly to the device (giving an option, for example, to quantize
checkpoint data before moving it to the device).
"""

replace_with_kernel_inject: bool = Field(False, alias="kernel_inject")
"""
Set to true to inject inference kernels for models such as Bert, GPT2,
GPT-Neo and GPT-J. Otherwise, the injection_dict provides the names of two
linear layers as a tuple:
`(attention_output projection, transformer output projection)`
"""
########################################


def get_tensor_parallel_config(ds_config):

if 'tensor_parallel' in ds_config:
return TPTrainingConfig(**ds_config['tensor_parallel'])
return TPTrainingConfig()
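A hedged example of the corresponding DeepSpeed config section, expressed as a Python dict. The key names follow `get_tensor_parallel_config` and `TPTrainingConfig` above; the other values are illustrative, and the engine currently asserts ZeRO stage <= 1 when autoTP is enabled.

```python
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},  # autoTP training is validated only for ZeRO stage <= 1
    "tensor_parallel": {
        "autotp_size": 2,               # 0 (the default) disables automatic tensor parallelism
    },
}
```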
66 changes: 66 additions & 0 deletions deepspeed/runtime/tensor_parallel/tp_manager.py
@@ -0,0 +1,66 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from .config import TPTrainingConfig, TPConfig
from deepspeed.utils import groups
import deepspeed.comm as dist


class TpTrainingManager():

def __init__(self, model, tp_size, dtype):
self.module = model
self.config = self._initialize_config(dtype)

from deepspeed.module_inject.auto_tp import AutoTP
from deepspeed import get_accelerator

# Parse model configuration
parser_dict = AutoTP.tp_parser(model)
print("AutoTP: ", parser_dict)

# Initialize TP configuration and model
self._initialize_tp_config(tp_size)
self._get_model_config_generate()

# Synchronize random number generator state across devices
_rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
dist.broadcast(_rng_state, groups.get_tensor_model_parallel_src_rank(), self.tp_config.tp_group)
get_accelerator().set_rng_state(_rng_state.cpu())

# Apply injection policies
self._apply_policies(parser_dict)

def _initialize_config(self, dtype):
"""Initialize and return the DeepSpeed TP training configuration."""
config = TPTrainingConfig()
config.dtype = dtype
return config

def _apply_policies(self, parser_dict):
"""Apply injection policies to the parsed modules."""
for client_module, injection_policy in parser_dict:
self.config.injection_policy_tuple = injection_policy
self._apply_injection_policy(self.config, client_module)

def _apply_injection_policy(self, config, client_module=None):
"""Apply the given injection policy to a client module."""
from deepspeed.module_inject import replace_transformer_layer
if isinstance(self.module, torch.nn.Module):
replace_transformer_layer(client_module, self.module, None, self.config, self.model_config)

def _initialize_tp_config(self, tp_size):
"""Perform TP configuration initialization."""
self.tp_config = TPConfig()
self.tp_config.tp_size = tp_size

groups._init_tp_mesh_device(tp_size)
self.tp_config.tp_group = groups.get_tensor_model_parallel_group()
self.config.tensor_parallel = self.tp_config

def _get_model_config_generate(self):
"""Generate and apply HF model configuration."""
self.model_config = getattr(self.module, 'config', None)
45 changes: 44 additions & 1 deletion deepspeed/runtime/utils.py
@@ -22,7 +22,7 @@
from torch._six import inf
except ModuleNotFoundError:
from torch import inf

from typing import Union, List, Dict
from deepspeed import comm as dist
from deepspeed.moe.utils import is_moe_param
from deepspeed.utils import groups, logger
@@ -1101,3 +1101,46 @@ def move_back_key(state, key):
move_back_key(state, "exp_avg")
if "exp_avg_sq" in state:
move_back_key(state, "exp_avg_sq")


def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[List, Dict]) -> bool:
"""
Compare two lists or dictionaries for equality, including any tensors they may contain.
Args:
inputs1: First input, either a list or a dictionary.
inputs2: Second input, either a list or a dictionary.
Returns:
True if inputs1 and inputs2 are equal; False otherwise.
"""
if type(inputs1) != type(inputs2): # Ensure types match
return False

if isinstance(inputs1, list) and isinstance(inputs2, list):
if len(inputs1) != len(inputs2):
return False
for val1, val2 in zip(inputs1, inputs2):
if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
val1 = val1.to(get_accelerator().current_device())
val2 = val2.to(get_accelerator().current_device())
if not torch.equal(val1, val2):
return False
elif val1 != val2:
return False
return True

elif isinstance(inputs1, dict) and isinstance(inputs2, dict):
if inputs1.keys() != inputs2.keys():
return False
for key in inputs1:
val1, val2 = inputs1[key], inputs2[key]
if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
val1 = val1.to(get_accelerator().current_device())
val2 = val2.to(get_accelerator().current_device())
if not torch.equal(val1, val2):
return False
elif val1 != val2:
return False
return True

return False
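A quick illustration of the helper's contract (inputs must have matching types, and tensors are compared element-wise), assuming an accelerator device is available since the helper moves tensors to it; the dictionaries are illustrative.

```python
import torch
from deepspeed.runtime.utils import compare_tensors_in_structures

a = {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([0, 1, 0])}
b = {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([0, 1, 0])}
c = {"input_ids": torch.tensor([9, 9, 9]), "labels": torch.tensor([0, 1, 0])}

assert compare_tensors_in_structures(a, b) is True
assert compare_tensors_in_structures(a, c) is False
```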
123 changes: 121 additions & 2 deletions deepspeed/utils/groups.py
@@ -46,8 +46,6 @@
# All to All quantized gradient communication groups
_ALL_TO_ALL_GROUP = {}

_DATA_PARALLEL_GROUP = None

mesh_device = None


@@ -64,6 +62,127 @@ def _ensure_divisibility(numerator, denominator):
assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)


# ======== Start: Tensor Parallel Group Attributes ========

# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None

# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None


def _init_tp_mesh_device(tensor_model_parallel_size=1, data_parallel_size=None):
"""Initialize model data parallel groups."""

global _DATA_PARALLEL_GROUP
global _MODEL_PARALLEL_GROUP
global _TENSOR_MODEL_PARALLEL_GROUP

if _TENSOR_MODEL_PARALLEL_GROUP is not None:
return

if data_parallel_size is None:
data_parallel_size = dist.get_world_size() // tensor_model_parallel_size

mesh_device = dist.initialize_mesh_device((data_parallel_size, tensor_model_parallel_size),
("data_parallel", "tensor_parallel"))
_TENSOR_MODEL_PARALLEL_GROUP = mesh_device.get_group(mesh_dim="tensor_parallel")
_DATA_PARALLEL_GROUP = mesh_device.get_group(mesh_dim="data_parallel")

# _MODEL_PARALLEL_GROUP is assigned the same value as _TENSOR_MODEL_PARALLEL_GROUP;
# the two are equal only in 2D (DP + TP) parallelism. Keeping them separate allows
# for potential future changes.
_MODEL_PARALLEL_GROUP = _TENSOR_MODEL_PARALLEL_GROUP

return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""

assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
return _TENSOR_MODEL_PARALLEL_GROUP


def get_model_parallel_group():
"""Get the model parallel group the caller rank belongs to."""

assert _MODEL_PARALLEL_GROUP is not None, \
'model parallel group is not initialized'
return _MODEL_PARALLEL_GROUP


def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP


def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
return dist.get_world_size(group=get_tensor_model_parallel_group())


def get_model_parallel_world_size():
return get_tensor_model_parallel_world_size()


def set_tensor_model_parallel_rank(rank):
"""Set tensor model parallel rank."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank


def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
return _MPU_TENSOR_MODEL_PARALLEL_RANK
return dist.get_rank(group=get_tensor_model_parallel_group())


def get_model_parallel_rank():
return get_tensor_model_parallel_rank()


def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = dist.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size


def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return dist.get_world_size(group=get_data_parallel_group())


def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return dist.get_rank(group=get_data_parallel_group())


# ======== End: Tensor Parallel Group Attributes ========
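The source-rank computation in `get_tensor_model_parallel_src_rank` can be sanity-checked with a tiny standalone reproduction of the arithmetic (no process groups needed).

```python
def tp_src_rank(global_rank: int, tp_world_size: int) -> int:
    # First global rank of the TP group that `global_rank` belongs to.
    return (global_rank // tp_world_size) * tp_world_size

# 8 ranks split into TP groups of size 2: {0,1}, {2,3}, {4,5}, {6,7}.
assert tp_src_rank(5, 2) == 4
assert tp_src_rank(4, 2) == 4
assert tp_src_rank(7, 2) == 6
```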


# Not currently used. Helper function to create a model (tensor) parallel group.
def _create_model_parallel(model_parallel_size_):
"""
574 changes: 574 additions & 0 deletions tests/unit/model_parallelism/test_autotp_training.py

Large diffs are not rendered by default.