From bfb0d3af4c4dcd387a715e9048da2c355371f707 Mon Sep 17 00:00:00 2001
From: zhaoting <37293445+CaitinZhao@users.noreply.github.com>
Date: Fri, 21 Feb 2025 10:17:49 +0800
Subject: [PATCH 1/2] zero doc & checkpoint save adapt (#834)

* ckpt save when zero
* add docs
* fix bugs
* op_group->optimizer_parallel_group
* ms_class -> jit_class
* transform_checkpoints->convert_checkpoints
* typo
* fix bugs

---------

Co-authored-by: zhaoting
---
 docs/index.md                                 |   8 +
 docs/tools/_toctree.yml                       |   4 +
 docs/tools/zero.md                            | 155 +++++++++++
 .../opensora/train/train_t2v_diffusers.py     |   2 +-
 .../opensora_pku/tools/ckpt/combine_ckpt.py   |   2 +-
 mindone/diffusers/training_utils.py           | 207 ++++++++++++++-
 mindone/models/modules/parallel/__init__.py   |  12 +-
 mindone/models/modules/parallel/conv.py       |  70 ++++-
 mindone/models/modules/parallel/dense.py      |  41 ++-
 .../models/modules/parallel/param_wrapper.py  |  17 +-
 mindone/trainers/callback.py                  | 245 ++++++++++++------
 mindone/trainers/zero.py                      | 107 ++++----
 tests/others/test_zero.py                     |   7 +-
 13 files changed, 703 insertions(+), 174 deletions(-)
 create mode 100644 docs/tools/_toctree.yml
 create mode 100644 docs/tools/zero.md

diff --git a/docs/index.md b/docs/index.md
index 3e60f5014a..04ae4aac16 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -32,6 +32,14 @@ hide:

     [:octicons-arrow-right-24: Start tuning!](peft/index.md)

+- :star2: __Tools__
+
+    ---
+
+    Training tools, including the Trainer, ZeRO, image/video data filtering strategies, and more.
+
+    [:octicons-arrow-right-24: Use them!](tools/zero.md)
+
 - > :rocket: __Accelerate__

   > ---

diff --git a/docs/tools/_toctree.yml b/docs/tools/_toctree.yml
new file mode 100644
index 0000000000..8c398ae5ae
--- /dev/null
+++ b/docs/tools/_toctree.yml
@@ -0,0 +1,4 @@
+- sections:
+  - local: zero
+    title: ZeRO
+  title: Get started

diff --git a/docs/tools/zero.md b/docs/tools/zero.md
new file mode 100644
index 0000000000..ca96ea43cc
--- /dev/null
+++ b/docs/tools/zero.md
@@ -0,0 +1,155 @@
# Zero Redundancy Optimizer (ZeRO) on MindOne

Zero Redundancy Optimizer (ZeRO) is a method for reducing memory usage under the data parallelism strategy, introduced in the paper [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/pdf/1910.02054.pdf).

ZeRO eliminates memory redundancies in data- and model-parallel training while retaining low communication volume and high computational
granularity, allowing the model size to scale in proportion to the number of devices with sustained high efficiency.

This tutorial walks you through how to train larger models with less device memory by enabling ZeRO on MindOne.

## Build Train Network With ZeRO

Build a train network with ZeRO as follows.

```python
import mindspore as ms
from mindspore import nn
from mindspore.communication import init
from mindspore.communication.management import GlobalComm

from mindone.trainers.zero import prepare_train_network

# Initialize the distributed environment
def init_env(mode, distribute):
    ms.set_context(mode=mode)
    if distribute:
        init()
        # ZeRO only takes effect under the DATA_PARALLEL mode
        ms.set_auto_parallel_context(
            parallel_mode=ms.ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
        )

init_env(ms.GRAPH_MODE, True)

# Net is your train network
net = Net()
# opt must be a subclass of the MindSpore Optimizer
opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)

# build the train network with ZeRO
train_net = prepare_train_network(net, opt, zero_stage=2, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
```

!!! tip
    `optimizer_parallel_group` does not have to be `GlobalComm.WORLD_COMM_GROUP`. Use [create_group](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.communication.html#mindspore.communication.create_group) to create your own `optimizer_parallel_group`.

More details:

::: mindone.trainers.zero.prepare_train_network

[Here](https://github.com/mindspore-lab/mindone/blob/master/tests/others/test_zero.py) is an example.

## Memory Analysis

The memory consumption during training can be divided into two main parts:

- Residual states: mainly activations, temporary buffers, and unusable memory fragments.
- Model states: mainly the optimizer states (AdamW, fp32), parameters (fp16), and gradients (fp16), abbreviated as OPG. Assuming the number of model parameters is Φ,
the total model states are 2Φ (parameters) + 2Φ (gradients) + (4Φ + 4Φ + 4Φ) (optimizer states) = 16Φ, with the AdamW states accounting for 75%.

Residual states can be greatly reduced through [recompute](https://www.mindspore.cn/docs/en/master/model_train/parallel/recompute.html) and [model parallel](https://www.mindspore.cn/docs/en/master/model_train/parallel/strategy_select.html).
The ZeRO algorithm can then be used to reduce the model states.

To optimize the model states (i.e. remove the redundancy), ZeRO partitions them so that each card only stores 1/N of the data.

ZeRO has three main optimization stages (as depicted in Figure 1 of the ZeRO paper), which correspond to the partitioning of optimizer states, gradients, and parameters. When enabled cumulatively:

1) Optimizer State Partitioning (Pos): only 1/N of the optimizer states is kept, while the model parameters and gradients are still kept in full on each card. The model states per card become 4Φ + 12Φ/N; when N is very large this tends to 4Φ, i.e. 1/4 of the original memory.
2) Add Gradient Partitioning (Pos+g): the gradients are additionally partitioned to 1/N. The model states per card become 2Φ + (2Φ + 12Φ)/N; when N is very large this tends to 2Φ, i.e. 1/8 of the original memory.
3) Add Parameter Partitioning (Pos+g+p): the parameters are additionally partitioned to 1/N. The model states per card become 16Φ/N; when N is very large this tends to 0.

Pos corresponds to ZeRO-1, Pos+g to ZeRO-2, and Pos+g+p to ZeRO-3.

## Communication Analysis

The most commonly used AllReduce implementation today is Ring AllReduce, which consists of two steps: ReduceScatter and AllGather. The communication volume (send + receive) per card is approximately 2Φ.

| zero stage | forward + backward | gradient          | optimizer update | communication |
|------------|--------------------|-------------------|------------------|---------------|
| 0          | NA                 | AllReduce         | NA               | 2Φ            |
| 1          | NA                 | 1/N ReduceScatter | 1/N AllGather    | 2Φ            |
| 2          | NA                 | 1/N ReduceScatter | 1/N AllGather    | 2Φ            |
| 3          | 2 AllGather        | ReduceScatter     | NA               | 3Φ            |

ZeRO-3 therefore introduces extra communication. However, computation and communication run on parallel streams in MindSpore, so when the computation that follows a communication is large enough, ZeRO-3 may still be faster.

## Checkpoint Saving & Loading

Because the model parameters have been split, each card needs to save its own shard of the parameters.
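A minimal sketch of rank-aware saving, assuming the `train_net` built by `prepare_train_network` above and using `mindspore.communication.get_rank` to suffix the file names:

```python
import mindspore as ms
from mindspore.communication import get_rank

rank_id = get_rank()  # index of this card in the communication group

# Each card saves its own shard; the rank id in the file name keeps the shards apart.
ms.save_checkpoint(train_net.network, f"network_{rank_id}.ckpt")
ms.save_checkpoint(train_net.optimizer, f"optimizer_{rank_id}.ckpt")
```

The Resume section below shows how these per-rank files are loaded back.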
+ +### Resume + +checkpoint save: + +| zero stage | parameters | optimizer states | ema | +|------------|------------| --- | --- | +| 0 | one card | one card | one card | +| 1 | one card | each card | each card | +| 2 | one card | each card | each card | +| 3 | each card | each card | each card | + +!!! tip + + 💡 Recommend using rank_id to distinguish checkpoint saved on different cards. + +```python +rank_id = get_rank_id() +zero_stage=2 +train_net = prepare_train_network(net, opt, zero_stage=zero_stage, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP) +if resume: + network_ckpt = "network.ckpt" if zero_stage != 3 else f"network_{rank_id}.ckpt" + ms.load_checkpoint(network_ckpt, net=train_net.network) + optimizer_ckpt = "optimizer.ckpt" if zero_stage == 0 else f"optimizer_{rank_id}.ckpt" + ms.load_checkpoint(optimizer_ckpt, net=train_net.optimizer) + ema_ckpt = "ema.ckpt" if zero_stage == 0 else f"ema_{rank_id}.ckpt" + ms.load_checkpoint(ema_ckpt, net=train_net.ema) +``` + +### Inference + +Inference need complete model parameters when use zero3. There are two ways(online & offline) to get the complete model parameters. + +#### Online Checkpoint Combile + +```python +def do_ckpt_combine_online(net_to_save, optimizer_parallel_group): + new_net_to_save = [] + all_gather_op = ops.AllGather(optimizer_parallel_group) + for param in net_to_save: + if param.parallel_optimizer: + new_data = ms.Tensor(all_gather_op(param).asnumpy()) + else: + new_data = ms.Tensor(param.asnumpy()) + new_net_to_save.append({"name": param.name, "data": new_data}) + return new_net_to_save + +net_to_save = [{"name": p.name, "data": p} for p in network.trainable_params()] +net_to_save = net_to_save if zero_stage != 3 else do_ckpt_combine_online(net_to_save, optimizer_parallel_group) +ms.save_checkpoint(net_to_save, "network.ckpt") +``` + +Add the code when need save model parameters. + +#### Offline Checkpoint Combile + +Parameters split infomation will be save when using ZereHelper, could use it to combile the checkpoints offline. + +```python +from mindone.trainers.zero import convert_checkpoints + +src_checkpoint = "save_checkpoint_dir/ckpt_{}.ckpt" +src_param_split_info_json = "params_info/params_split_info_{}.json" +group_size = 2 +convert_checkpoints(src_checkpoint, src_param_split_info_json, group_size) +``` + +And get the complete model parameters checkpoint at `save_checkpoint_dir/ckpt_all_2.ckpt`. diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index abba8c44e2..6731c4dc58 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -531,7 +531,7 @@ def main(args): latent_diffusion_with_loss, optimizer, zero_stage=args.zero_stage, - op_group=GlobalComm.WORLD_COMM_GROUP, + optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP, comm_fusion=comm_fusion_dict, scale_sense=loss_scaler, drop_overflow_update=args.drop_overflow_update, diff --git a/examples/opensora_pku/tools/ckpt/combine_ckpt.py b/examples/opensora_pku/tools/ckpt/combine_ckpt.py index 57fe2d1fe4..b4110ea86d 100644 --- a/examples/opensora_pku/tools/ckpt/combine_ckpt.py +++ b/examples/opensora_pku/tools/ckpt/combine_ckpt.py @@ -25,7 +25,7 @@ def main(): else args.strategy_ckpt ) assert os.path.exists(strategy_file), f"{strategy_file} does not exist!" 
- ms.transform_checkpoints(args.src, args.dest, "full_", strategy_file, None) + ms.convert_checkpoints(args.src, args.dest, "full_", strategy_file, None) output_path = os.path.join(args.dest, "rank_0", "full_0.ckpt") assert os.path.isfile(output_path) diff --git a/mindone/diffusers/training_utils.py b/mindone/diffusers/training_utils.py index 77d803f6f4..b14642955a 100644 --- a/mindone/diffusers/training_utils.py +++ b/mindone/diffusers/training_utils.py @@ -1,3 +1,4 @@ +import contextlib import copy import logging import math @@ -14,9 +15,17 @@ import mindspore as ms from mindspore import context, nn, ops from mindspore.amp import DynamicLossScaler, StaticLossScaler, all_finite +from mindspore.common import dtype as mstype +from mindspore.common.api import _pynative_executor from mindspore.communication import get_group_size, get_local_rank, get_rank, init +from mindspore.communication.management import GlobalComm +from mindspore.context import ParallelMode +from mindspore.parallel._utils import _get_parallel_mode from mindone.diffusers._peft import set_peft_model_state_dict +from mindone.diffusers.models.model_loading_utils import silence_mindspore_logger +from mindone.trainers.train_step import TrainOneStepWrapper +from mindone.trainers.zero import ZeroHelper, prepare_network from .models import UNet2DConditionModel from .schedulers import SchedulerMixin @@ -104,6 +113,12 @@ def compute_dream_and_update_latents( def cast_training_params(model: Union[nn.Cell, List[nn.Cell]], dtype=ms.float32): + """ + Casts the training parameters of the model to the specified data type. + Args: + model: The PyTorch model whose parameters will be cast. + dtype: The data type to which the model parameters will be cast. + """ if not isinstance(model, list): model = [model] for m in model: @@ -133,7 +148,8 @@ def _set_state_dict_into_text_encoder(lora_state_dict: Dict[str, ms.Tensor], pre def compute_density_for_timestep_sampling( weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None ): - """Compute the density for sampling the timesteps when doing SD3 training. + """ + Compute the density for sampling the timesteps when doing SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. @@ -152,7 +168,8 @@ def compute_density_for_timestep_sampling( def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """Computes loss weighting scheme for SD3 training. + """ + Computes loss weighting scheme for SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. @@ -310,7 +327,9 @@ def pin_memory(self) -> None: raise NotImplementedError("Not Implemeneted for `pin_memory`.") def to(self, dtype=None, non_blocking=False) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`.""" + r""" + Move internal buffers of the ExponentialMovingAverage to `device`. + """ # .to() on the tensors handles None correctly raise NotImplementedError("Not Implemeneted for `to`.") @@ -335,9 +354,10 @@ def state_dict(self) -> dict: def load_state_dict(self, state_dict: dict) -> None: r""" - Args: Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the ema state dict. + + Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`. 
""" @@ -385,6 +405,10 @@ def is_master(args): return args.rank == 0 +def is_local_master(args): + return args.local_rank == 0 + + def init_distributed_device(args): # Distributed training = training on more than one GPU. # Works in both single and multi-node scenarios. @@ -806,3 +830,178 @@ def construct(self, *inputs): loss = self.unscale_loss(outputs[0]) outputs = (loss,) + outputs[1:] return outputs + + +@ms.jit_class +class pynative_no_grad(contextlib.ContextDecorator): + """ + Context Manager to disable gradient calculation. When enter this context, we will disable calculate + gradient. When exit this context, we will resume its prev state. + Currently, it can use both in Pynative and Graph mode. It also can be used as decorator. + + For mindone.diffusers, it is used in PyNative training to decorate the part of calculation that + does not require gradients, e.g. vae.encode_images or text_encoder.encode_prompts where does not + need to train VAE or text-encoders. + """ + + def __init__(self): + self.is_pynative_mode = context.get_context("mode") == context.PYNATIVE_MODE or os.getenv("MS_JIT") == "0" + self.prev_state = False + + def __enter__(self): + if self.is_pynative_mode: + self.prev_state = _pynative_executor.enable_grad() + _pynative_executor.set_enable_grad(False) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.is_pynative_mode: + _pynative_executor.set_enable_grad(self.prev_state) + return False + + +# Adapted from mindone.trainers.zero.prepare_train_network +def prepare_train_network( + network: nn.Cell, + optimizer: nn.Optimizer, + scale_sense: Union[float, nn.Cell] = 1.0, + updates: int = 0, + drop_overflow_update: bool = True, + gradient_accumulation_steps: int = 1, + clip_grad: bool = False, + clip_norm: float = 1.0, + verbose: bool = False, + zero_stage: int = 0, + optimizer_offload: bool = False, + optimizer_parallel_group: str = None, + dp_group: str = None, + comm_fusion: dict = None, + parallel_modules=None, +): + """ + Prepare network and optimizer for distributed training. + + Args: + network (`nn.Cell`): train network, not include grad function, + grad function must be built after rewrite train network. + optimizer (`nn.Optimizer`): Must be the subclass of MindSpore Optimizer. + scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called + to update loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`, + the shape should be :math:`()` or :math:`(1,)`. + zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0. + optimizer_offload (`bool`, *optional*): Only take effect when optimizer is AdamWeightDecay, default is False. + optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. + dp_group (`str`, *optional*): The name of the data parallel communication group, default is None. + comm_fusion (`dict`, *optional*): A dict contains the types and configurations + for setting the communication fusion, default is None, turn off the communication fusion. If set a dict, + turn on the communication fusion. + Examples: {"allreduce": {"openstate": True, "bucket_size": 5e8}, + "reduce_scatter": {"openstate": True, "bucket_size": 5e8}, + "allgather": {"openstate": False, "bucket_size": 5e8},} + parallel_modules (`dict`, *optional*): A dict of Cells could split parameters in zero3, default is None. + If None, use `PARALLEL_MODULES` from `mindone.models.modules.parallel`. 
+ """ + if zero_stage not in [0, 1, 2, 3]: + raise ValueError("Not support zero_stage {zero_stage}") + if optimizer_parallel_group is None: + logger.warning("Not set zero group, set it WORLD_COMM_GROUP.") + optimizer_parallel_group = GlobalComm.WORLD_COMM_GROUP + if optimizer_parallel_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None: + raise ValueError( + "optimizer_parallel_group {optimizer_parallel_group} and dp_group {dp_group} not full network hccl group coverage" + ) + + is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL + if not is_parallel and zero_stage == 0: + logger.info("No need prepare train_network with zero.") + zero_helper = None + else: + network = prepare_network(network, zero_stage, optimizer_parallel_group, parallel_modules=parallel_modules) + zero_helper = ZeroHelper( + optimizer, zero_stage, optimizer_parallel_group, dp_group, optimizer_offload, comm_fusion + ) + + if isinstance(scale_sense, float): + scale_sense = ms.Tensor(scale_sense, ms.float32) + train_network = DiffusersTrainOneStepWrapper( + network, + optimizer, + scale_sense=scale_sense, + updates=updates, + drop_overflow_update=drop_overflow_update, + gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + clip_norm=clip_norm, + verbose=verbose, + zero_helper=zero_helper, + ) + return train_network + + +class DiffusersTrainOneStepWrapper(TrainOneStepWrapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def use_zero(self): + return self.zero_helper is not None and self.zero_stage != 0 + + def need_save_optimizer(self, args): + # TODO: Now we save optimizer in every process, try to save depend on self.zero_helper.optimizer_parallel_group + return True if self.use_zero else is_local_master(args) + + def save_state(self, args, output_dir, optimizer_state_filter=lambda x: True): + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving current state to {output_dir}") + + # Optimizer states + if self.use_zero: + os.makedirs(os.path.join(output_dir, "mindspore_model"), exist_ok=True) + optimizer_file = os.path.join(output_dir, "mindspore_model", f"zero_pp_{args.local_rank}_optim_states.ckpt") + elif self.need_save_optimizer(args): + optimizer_file = os.path.join(output_dir, "optimizer.ckpt") + ms.save_checkpoint(self.optimizer, optimizer_file, choice_func=optimizer_state_filter) + + # Loss Scaler states + loss_scaler_file = os.path.join(output_dir, "loss_scaler.ckpt") + loss_scaler_states = {"scale_sense": self.scale_sense} + if self.loss_scaling_manager: + loss_scaler_states.update( + { + "cur_iter": self.loss_scaling_manager.cur_iter, + "last_overflow_iter": self.loss_scaling_manager.last_overflow_iter, + } + ) + ms.save_checkpoint(loss_scaler_states, loss_scaler_file) + + def load_state(self, args, input_dir, optimizer_state_filter=lambda x: True): + # Optimizer states + optimizer_file = ( + os.path.join(input_dir, "mindspore_model", f"zero_pp_{args.local_rank}_optim_states.ckpt") + if self.use_zero + else os.path.join(input_dir, "optimizer.ckpt") + ) + optimizer_state_dict = ms.load_checkpoint(optimizer_file) + + with silence_mindspore_logger(): + param_not_load, ckpt_not_load = ms.load_param_into_net( + self.optimizer, optimizer_state_dict, strict_load=True + ) + + param_not_load = list(filter(lambda x: optimizer_state_filter(x), param_not_load)) + if param_not_load or ckpt_not_load: + logger.warning( + f"Loading checkpoint into optimizer returns param_not_load:{param_not_load} \nand 
ckpt_not_load:{ckpt_not_load}" + ) + + # Loss Scaler states + loss_scaler_file = os.path.join(input_dir, "loss_scaler.ckpt") + loss_scaler_state_dict = ms.load_checkpoint(loss_scaler_file) + + scale_sense = loss_scaler_state_dict.get("scale_sense", ms.Tensor(1.0, dtype=mstype.float32)) + cur_iter = loss_scaler_state_dict.get("cur_iter", None) + last_overflow_iter = loss_scaler_state_dict.get("last_overflow_iter", None) + + ops.assign(self.scale_sense, scale_sense) + if cur_iter is not None and last_overflow_iter is not None: + ops.assign(self.loss_scaling_manager.cur_iter, cur_iter) + ops.assign(self.loss_scaling_manager.last_overflow_iter, last_overflow_iter) diff --git a/mindone/models/modules/parallel/__init__.py b/mindone/models/modules/parallel/__init__.py index e3f77a9537..5240aeb9c6 100644 --- a/mindone/models/modules/parallel/__init__.py +++ b/mindone/models/modules/parallel/__init__.py @@ -1,7 +1,7 @@ -from mindspore import nn +from mindspore import mint, nn -from .conv import Conv1d, Conv2d, Conv3d -from .dense import Dense +from .conv import Conv1d, Conv2d, Conv3d, Mint_Conv2d, Mint_Conv3d +from .dense import Dense, Linear # {Original MindSpore Cell: New Cell in ZeRO3} PARALLEL_MODULES = { @@ -9,5 +9,9 @@ nn.Conv2d: Conv2d, nn.Conv3d: Conv3d, nn.Dense: Dense, + mint.nn.Conv2d: Mint_Conv2d, + mint.nn.Conv3d: Mint_Conv3d, + mint.nn.Linear: Linear, } -__all__ = ["Conv1d", "Conv2d", "Conv3d", "Dense"] + +__all__ = ["Conv1d", "Conv2d", "Conv3d", "Mint_Conv2d", "Mint_Conv3d", "Dense", "Linear"] diff --git a/mindone/models/modules/parallel/conv.py b/mindone/models/modules/parallel/conv.py index e47f958456..b8bd00be84 100644 --- a/mindone/models/modules/parallel/conv.py +++ b/mindone/models/modules/parallel/conv.py @@ -1,4 +1,4 @@ -from mindspore import nn, ops +from mindspore import mint, nn, ops from mindspore.communication import get_group_size, get_rank from mindspore.communication.management import GlobalComm from mindspore.context import ParallelMode @@ -8,25 +8,35 @@ class _Conv(nn.Cell): - def __init__(self, net, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None): + def __init__( + self, net, zero_stage: int = 0, optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None + ): super(_Conv, self).__init__(auto_prefix=False) self.net = net - self.set_param_wrapper(zero_stage, op_group, cell_type) + self.set_param_wrapper(zero_stage, optimizer_parallel_group, cell_type) - def set_param_wrapper(self, zero_stage, op_group, cell_type=None): + @property + def weight(self): + return self.net.weight + + @property + def bias(self): + return self.net.bias + + def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None): self.param_wrapper_w = nn.Identity() self.param_wrapper_b = nn.Identity() if zero_stage == 3: # Init parallel settings is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL - op_group_size = get_group_size(op_group) if is_parallel else 1 - op_rank_id = get_rank(op_group) if is_parallel else 0 - self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, op_group, cell_type) + op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1 + op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0 + self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, optimizer_parallel_group, cell_type) split_op = ops.Split(0, op_group_size) if self.param_wrapper_w.need_rewrite: self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id]) - if 
self.net.has_bias: - self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type) + if self.net.bias: + self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, optimizer_parallel_group, cell_type) if self.param_wrapper_b.need_rewrite: self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id]) @@ -71,3 +81,45 @@ def construct(self, x): new_shape[1] = self.net.out_channels out = out + bias.reshape(new_shape) return out + + +class Mint_Conv2d(_Conv): + def construct(self, x): + weight = self.param_wrapper_w(self.net.weight) + bias = self.param_wrapper_b(self.net.bias) + if self.net.padding_mode != "zeros": + output = self.net.conv2d( + mint.pad(input, self.net._reversed_padding, mode=self.net.padding_mode), + weight, + bias, + self.net.stride, + (0, 0), + self.net.dilation, + self.net.groups, + ) + else: + output = self.net.conv2d( + input, weight, bias, self.net.stride, self.net.padding, self.net.dilation, self.net.groups + ) + return output + + +class Mint_Conv3d(_Conv): + def construct(self, x): + weight = self.param_wrapper_w(self.net.weight) + bias = self.param_wrapper_b(self.net.bias) + if self.net.padding_mode != "zeros": + output = self.net.conv3d( + mint.pad(input, self.net._reversed_padding, mode=self.net.padding_mode), + weight, + bias, + self.net.stride, + (0, 0, 0), + self.net.dilation, + self.net.groups, + ) + else: + output = self.net.conv3d( + input, weight, bias, self.net.stride, self.net.padding, self.net.dilation, self.net.groups + ) + return output diff --git a/mindone/models/modules/parallel/dense.py b/mindone/models/modules/parallel/dense.py index 66ef7fef71..dc24ef0bba 100644 --- a/mindone/models/modules/parallel/dense.py +++ b/mindone/models/modules/parallel/dense.py @@ -1,4 +1,8 @@ -from mindspore import nn, ops +from typing import Literal, Optional, Union + +from mindspore import Tensor +from mindspore import dtype as mstype +from mindspore import mint, nn, ops from mindspore.communication import get_group_size, get_rank from mindspore.communication.management import GlobalComm from mindspore.context import ParallelMode @@ -8,25 +12,39 @@ class Dense(nn.Cell): - def __init__(self, net, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None): - super(Dense, self).__init__(auto_prefix=False) + def __init__( + self, + net: Union[nn.Dense, mint.nn.Linear], + zero_stage: Literal[0, 1, 2, 3] = 0, + optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP, + cell_type: Optional[mstype.Type] = None, + ): + super().__init__(auto_prefix=False) self.net = net - self.set_param_wrapper(zero_stage, op_group, cell_type) + self.set_param_wrapper(zero_stage, optimizer_parallel_group, cell_type) + + @property + def weight(self): + return self.net.weight + + @property + def bias(self): + return self.net.bias - def set_param_wrapper(self, zero_stage, op_group, cell_type=None): + def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None): self.param_wrapper_w = nn.Identity() self.param_wrapper_b = nn.Identity() if zero_stage == 3: # Init parallel settings is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL - op_group_size = get_group_size(op_group) if is_parallel else 1 - op_rank_id = get_rank(op_group) if is_parallel else 0 - self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, op_group, cell_type) + op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1 + op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0 + 
self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, optimizer_parallel_group, cell_type) split_op = ops.Split(0, op_group_size) if self.param_wrapper_w.need_rewrite: self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id]) if self.net.has_bias: - self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type) + self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, optimizer_parallel_group, cell_type) if self.param_wrapper_b.need_rewrite: self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id]) @@ -43,3 +61,8 @@ def construct(self, x): out_shape = x_shape[:-1] + (x.shape[-1],) x = x.reshape(out_shape) return x + + +class Linear(Dense): + def construct(self, x: Tensor) -> Tensor: + return self.net.dense(x, self.param_wrapper_w(self.net.weight), self.param_wrapper_b(self.net.bias)) diff --git a/mindone/models/modules/parallel/param_wrapper.py b/mindone/models/modules/parallel/param_wrapper.py index 1ca8d753b7..b007b97ab7 100644 --- a/mindone/models/modules/parallel/param_wrapper.py +++ b/mindone/models/modules/parallel/param_wrapper.py @@ -12,10 +12,14 @@ class ZeroParamWrapper(nn.Cell): """ def __init__( - self, param: ms.Parameter, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None + self, + param: ms.Parameter, + zero_stage: int = 0, + optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP, + cell_type=None, ): super().__init__(auto_prefix=False) - self.op_group = op_group + self.optimizer_parallel_group = optimizer_parallel_group self.zero_stage = zero_stage self.cell_type = cell_type if zero_stage != 3: @@ -23,16 +27,16 @@ def __init__( # Init parallel settings self.is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL - self.op_group_size = get_group_size(self.op_group) if self.is_parallel else 1 + self.op_group_size = get_group_size(self.optimizer_parallel_group) if self.is_parallel else 1 self.allgather = ops.Identity() self.reduce_scatter = None self.dtype = param.dtype - self.allreduce = ops.AllReduce(group=self.op_group, op=ops.ReduceOp.SUM) + self.allreduce = ops.AllReduce(group=self.optimizer_parallel_group, op=ops.ReduceOp.SUM) self.need_rewrite = self.check_rewrite(param) if self.need_rewrite: - self.op_allgather = ops.AllGather(group=self.op_group) - self.op_reduce_scatter = ops.ReduceScatter(group=self.op_group, op=ops.ReduceOp.SUM) + self.op_allgather = ops.AllGather(group=self.optimizer_parallel_group) + self.op_reduce_scatter = ops.ReduceScatter(group=self.optimizer_parallel_group, op=ops.ReduceOp.SUM) def check_rewrite(self, param): """Check the parameter need to split or not.""" @@ -40,6 +44,7 @@ def check_rewrite(self, param): B = param.shape[0] if not param.parallel_optimizer or B < self.op_group_size or B % self.op_group_size != 0: need_rewrite = False + param.parallel_optimizer = need_rewrite return need_rewrite def construct(self, param): diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py index a5fdee0bcb..a553342317 100755 --- a/mindone/trainers/callback.py +++ b/mindone/trainers/callback.py @@ -1,13 +1,16 @@ import logging import os import time -from typing import List +from typing import List, Literal, Optional, Tuple, Union import mindspore as ms +from mindspore import Profiler, Tensor, nn, ops, save_checkpoint from mindspore.communication import get_rank +from mindspore.communication.management import GlobalComm from mindspore.train.callback._callback import Callback, _handle_loss from .checkpoint import CheckpointManager 
+from .ema import EMA from .recorder import PerfRecorder _logger = logging.getLogger("") @@ -23,7 +26,7 @@ def get_real_rank(): return int(os.getenv("RANK_ID", "0")) -class OverflowMonitor(ms.Callback): +class OverflowMonitor(Callback): def on_train_step_end(self, run_context): cb_params = run_context.original_args() cur_epoch_num = cb_params.get("cur_epoch_num", 1) @@ -38,40 +41,66 @@ def on_train_step_end(self, run_context): class EvalSaveCallback(Callback): def __init__( self, - network, - use_lora=False, - rank_id=0, - ckpt_save_dir="./", - output_dir=None, - ema=None, - save_ema_only=True, - ckpt_save_policy="lastest_k", - ckpt_max_keep=10, - step_mode=False, - ckpt_save_interval=1, - use_step_unit=False, - data_sink_mode=True, - lora_rank=None, - log_interval=1, - start_epoch=0, - record_lr=True, - model_name="sd", + network: nn.Cell, + use_lora: bool = False, + rank_id: int = 0, + ckpt_save_dir: str = "./", + output_dir: str = None, + ema: EMA = None, + save_ema_only: bool = True, + ckpt_save_policy: Literal["top_k", "latest_k", None] = "latest_k", + monitor_metric: Optional[str] = None, + ckpt_max_keep: int = 10, + step_mode: bool = False, + ckpt_save_interval: int = 1, + use_step_unit: bool = False, + data_sink_mode: bool = True, + lora_rank: Optional[int] = None, + log_interval: int = 1, + start_epoch: int = 0, + record_lr: bool = True, + model_name: str = "sd", save_trainable_only: bool = False, param_save_filter: List[str] = None, - resume_prefix_blacklist: List[str] = None, - integrated_save=False, - save_training_resume=True, - train_steps=-1, + resume_prefix_blacklist: Optional[Union[str, Tuple[str, ...]]] = None, + integrated_save: bool = False, + save_training_resume: bool = True, + train_steps: int = -1, + prefer_low_perf: bool = False, + zero_stage: int = 0, + optimizer_parallel_group: str = None, + ckpt_combine_online: bool = False, ): """ Args: step_mode: if True, ckpt_save_interval is counted in steps. otherwise, in epochs. param_save_filter: indicates what parameters to save in checkpoint. If None, save all parameters in network. \ Otherwise, only params that contain one of the keyword in param_save_filter list will be saved. - resume_prefix_blacklist: exclude parameters with one of these prefixes to be saved in resume checkpoint. e.g. ['swap.', 'vae.']. + resume_prefix_blacklist: exclude parameters with one of these prefixes to be saved in resume checkpoint, + e.g. ('swap.', 'vae.'). + zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0. + optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. + ckpt_combine_online (`bool`, *optional*): combining trainable parameters for saving checkpoint when zero_stage=3, \ + using allgather ops to combile the checkpoint online if `ckpt_combine_online=True`, \ + saving all device parameters if `ckpt_combine_online=False`, \ + and need to use `convert_checkpoints` to combile the checkpoint offline. default is False. 
""" self.rank_id = rank_id self.is_main_device = rank_id in [0, None] + self.use_zero = zero_stage in [1, 2, 3] + self.ckpt_combine_online = (zero_stage == 3) and ckpt_combine_online + if self.ckpt_combine_online and self.ema is not None: + _logger.warning("Can not enable ckpt_combine_online when use ema, set `ckpt_combine_online=False`.") + self.ckpt_combine_online = False + + self.need_save_network = self.is_main_device or (zero_stage == 3 and not self.ckpt_combine_online) + self.need_save_optimizer = self.is_main_device or self.use_zero + if self.use_zero: + if optimizer_parallel_group is None: + _logger.warning("EvalSaveCallback not set zero group, set it WORLD_COMM_GROUP.") + optimizer_parallel_group = GlobalComm.WORLD_COMM_GROUP + self.optimizer_parallel_group = optimizer_parallel_group + self.op_rank_id = get_rank(optimizer_parallel_group) self.ema = ema if output_dir is not None: self.output_dir = output_dir @@ -93,13 +122,15 @@ def __init__( self.record_lr = record_lr self.save_ema_only = save_ema_only - if self.is_main_device: + if self.need_save_network: self.ckpt_save_policy = ckpt_save_policy + self.monitor_metric = monitor_metric self.ckpt_manager = CheckpointManager( ckpt_save_dir, ckpt_save_policy, k=ckpt_max_keep, integrated_save=integrated_save, + prefer_low_perf=prefer_low_perf, ) if self.start_epoch == 0: if self.record_lr: @@ -131,17 +162,22 @@ def __init__( self.use_step_unit = use_step_unit self.train_steps = train_steps self.save_training_resume = save_training_resume - if resume_prefix_blacklist is not None: - - def choice_func(x): - for prefix in resume_prefix_blacklist: - if x.startswith("vae."): - return False - return True - - self.choice_func = choice_func - else: - self.choice_func = None + self.choice_func = None + if resume_prefix_blacklist: + if isinstance(resume_prefix_blacklist, str): + resume_prefix_blacklist = (resume_prefix_blacklist,) + self.choice_func = lambda x: not x.startswith(resume_prefix_blacklist) + + def _do_ckpt_combine_online(self): + new_net_to_save = [] + all_gather_op = ops.AllGather(self.optimizer_parallel_group) + for param in self.net_to_save: + if param.parallel_optimizer: + new_data = ms.Tensor(all_gather_op(param).asnumpy()) + else: + new_data = ms.Tensor(param.asnumpy()) + new_net_to_save.append({"name": param.name, "data": new_data}) + return new_net_to_save def on_train_step_end(self, run_context): cb_params = run_context.original_args() @@ -158,7 +194,31 @@ def on_train_step_end(self, run_context): else: cur_epoch = cb_params.cur_epoch_num - 1 - if self.is_main_device: + if self.save_training_resume and self.need_save_optimizer: + # TODO: resume training for step. 
+ ckpt_name = f"train_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "train_resume.ckpt" + save_checkpoint( + cb_params.train_network, + os.path.join(self.ckpt_save_dir, ckpt_name), + choice_func=self.choice_func, + append_dict={ + "epoch_num": cur_epoch, + "cur_step": cur_step, + "loss_scale": self._get_scaling_value_from_cbp(cb_params), + }, + ) + if self.ema is not None: + ckpt_name = f"ema_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "ema_resume.ckpt" + save_checkpoint( + self.ema, + os.path.join(self.ckpt_save_dir, ckpt_name), + choice_func=self.choice_func, + ) + + if self.ckpt_combine_online: + new_net_to_save = self._do_ckpt_combine_online() + + if self.need_save_network: # if data sink, train step callback will not be invokded if self.step_mode and (cur_step % self.ckpt_save_interval == 0 or cur_step == step_num): ckpt_name = ( @@ -166,34 +226,29 @@ def on_train_step_end(self, run_context): if self.use_step_unit else f"{self.model_name}-e{cur_epoch}.ckpt" ) + if self.use_zero and not self.ckpt_combine_online: + file_extension = os.path.splitext(ckpt_name) + ckpt_name = f"{file_extension[0]}_op_rank_{self.op_rank_id}{file_extension[1]}" append_dict = {"lora_rank": self.lora_rank} if self.use_lora else None - if self.ema is not None: - if not self.save_ema_only: - self.ckpt_manager.save( - self.net_to_save, - None, - ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"), - append_dict=append_dict, - ) - # swap ema weight and network weight - self.ema.swap_before_eval() - - # save history checkpoints - self.ckpt_manager.save(self.net_to_save, None, ckpt_name=ckpt_name, append_dict=append_dict) - - if self.save_training_resume: - # TODO: resume training for step. - ms.save_checkpoint( - cb_params.train_network, - os.path.join(self.ckpt_save_dir, "train_resume.ckpt"), - choice_func=self.choice_func, - append_dict={ - "epoch_num": cur_epoch, - "cur_step": cur_step, - "loss_scale": self._get_scaling_value_from_cbp(cb_params), - }, - ) + perf = cb_params.get("eval_results") + net_to_save = new_net_to_save if self.ckpt_combine_online else self.net_to_save + if perf or self.ckpt_save_policy != "top_k": + if perf: + perf = perf[self.monitor_metric] + if self.ema is not None: + if not self.save_ema_only: + self.ckpt_manager.save( + self.net_to_save, + perf, + ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"), + append_dict=append_dict, + ) + # swap ema weight and network weight + self.ema.swap_before_eval() + + # save history checkpoints + self.ckpt_manager.save(net_to_save, perf, ckpt_name=ckpt_name, append_dict=append_dict) # swap back network weight and ema weight. MUST execute after model saving and before next-step training if self.ema is not None: @@ -256,15 +311,42 @@ def on_train_epoch_end(self, run_context): opt = self._get_optimizer_from_cbp(cb_params) cur_step = int(opt.global_step.asnumpy().item()) - if self.is_main_device and (not self.step_mode): + if self.save_training_resume and self.need_save_optimizer: + # TODO: resume training for step. 
+ ckpt_name = f"train_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "train_resume.ckpt" + save_checkpoint( + cb_params.train_network, + os.path.join(self.ckpt_save_dir, ckpt_name), + choice_func=self.choice_func, + append_dict={ + "epoch_num": cur_epoch, + "loss_scale": self._get_scaling_value_from_cbp(cb_params), + }, + ) + if self.ema is not None: + ckpt_name = f"ema_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "ema_resume.ckpt" + save_checkpoint( + self.ema, + os.path.join(self.ckpt_save_dir, ckpt_name), + choice_func=self.choice_func, + ) + + if self.ckpt_combine_online: + new_net_to_save = self._do_ckpt_combine_online() + + if self.need_save_network and (not self.step_mode): if (cur_epoch % self.ckpt_save_interval == 0) or (cur_epoch == epoch_num): ckpt_name = ( f"{self.model_name}-s{cur_step}.ckpt" if self.use_step_unit else f"{self.model_name}-e{cur_epoch}.ckpt" ) + if self.use_zero and not self.ckpt_combine_online: + file_extension = os.path.splitext(ckpt_name) + ckpt_name = f"{file_extension[0]}_op_rank_{self.op_rank_id}{file_extension[1]}" append_dict = {"lora_rank": self.lora_rank} if self.use_lora else None + net_to_save = new_net_to_save if self.ckpt_combine_online else self.net_to_save if self.ema is not None: if not self.save_ema_only: self.ckpt_manager.save( @@ -277,18 +359,9 @@ def on_train_epoch_end(self, run_context): self.ema.swap_before_eval() # save history checkpoints - self.ckpt_manager.save(self.net_to_save, None, ckpt_name=ckpt_name, append_dict=append_dict) - - if self.save_training_resume: - ms.save_checkpoint( - cb_params.train_network, - os.path.join(self.ckpt_save_dir, "train_resume.ckpt"), - choice_func=self.choice_func, - append_dict={ - "epoch_num": cur_epoch, - "loss_scale": self._get_scaling_value_from_cbp(cb_params), - }, - ) + self.ckpt_manager.save( + net_to_save, perf=cb_params["net_outputs"], ckpt_name=ckpt_name, append_dict=append_dict + ) # swap back network weight and ema weight. 
MUST execute after model saving and before next-step training if self.ema is not None: @@ -299,16 +372,16 @@ def on_train_epoch_end(self, run_context): def on_train_end(self, run_context): if self.is_main_device: if self.ckpt_save_policy == "top_k": - log_str = f"Top K checkpoints:\n{self.main_indicator}\tcheckpoint\n" + log_str = f"Top K checkpoints: \n{self.main_indicator}\tcheckpoint\n" for p, ckpt_name in self.ckpt_manager.get_ckpt_queue(): - log_str += f"{p:.4f}\t{os.path.join(self.ckpt_save_dir, ckpt_name)}\n" + log_str += f"{p: .4f}\t{os.path.join(self.ckpt_save_dir, ckpt_name)}\n" def on_eval_end(self, run_context): if self.is_main_device: cb_params = run_context.original_args() metrics = cb_params.get("metrics") if metrics is not None: - metrics = {k: f"{v:.4f}" for k, v in metrics.items()} + metrics = {k: f"{v: .4f}" for k, v in metrics.items()} _logger.info(f"Eval result epoch {cb_params.cur_epoch_num}: {metrics}") def _get_optimizer_from_cbp(self, cb_params): @@ -326,7 +399,7 @@ def _get_scaling_value_from_cbp(self, cb_params): else: return cb_params.train_network.scale_sense.asnumpy().item() - def _fetch_optimizer_lr(self, cb_params) -> ms.Tensor: + def _fetch_optimizer_lr(self, cb_params) -> Tensor: opt = self._get_optimizer_from_cbp(cb_params) lr = opt.learning_rate if opt.dynamic_lr: @@ -334,7 +407,7 @@ def _fetch_optimizer_lr(self, cb_params) -> ms.Tensor: return lr -class StopAtStepCallback(ms.Callback): +class StopAtStepCallback(Callback): # stop the training process when reach train_steps def __init__(self, train_steps, global_step=0): self.global_step = global_step @@ -346,7 +419,7 @@ def on_train_step_end(self, run_context): run_context.request_stop() -class ProfilerCallback(ms.Callback): +class ProfilerCallback(Callback): def __init__(self, start_step=1, end_step=2, exit_after_analyze=True, out_dir="./profiler_data"): self.start_step = start_step self.end_step = end_step @@ -355,7 +428,7 @@ def __init__(self, start_step=1, end_step=2, exit_after_analyze=True, out_dir=". out_dir = os.path.join(out_dir, f"rank_{rank_id}") # If value of profile_framework is not None, a subdirectory named host_info will be generated under the # specified profiler directory to store the collected memory and time files on the Host side. - self.profiler = ms.Profiler( + self.profiler = Profiler( start_profile=False, output_path=out_dir, profile_framework="all", data_simplication=False ) @@ -377,12 +450,12 @@ def on_train_step_end(self, run_context): run_context.request_stop() -class ProfilerCallbackEpoch(ms.Callback): +class ProfilerCallbackEpoch(Callback): def __init__(self, start_epoch, stop_epoch, output_dir="./profiler_data"): super().__init__() self.start_epoch = start_epoch self.stop_epoch = stop_epoch - self.profiler = ms.Profiler(start_profile=False, output_path=output_dir) + self.profiler = Profiler(start_profile=False, output_path=output_dir) def on_train_epoch_begin(self, run_context): cb_params = run_context.original_args() diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py index 8dd6f32c85..312a3b53cf 100644 --- a/mindone/trainers/zero.py +++ b/mindone/trainers/zero.py @@ -84,7 +84,7 @@ class ZeroHelper: Args: optimizer (`nn.Optimizer`): Must be the subclass of MindSpore Optimizer. zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0. - op_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. 
+ optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. dp_group (`str`, *optional*): The name of the data parallel communication group, default is None. optimizer_offload (`bool`, *optional*): Only take effect when optimizer is AdamWeightDecay, default is False. comm_fusion (`dict`, *optional*): A dict contains the types and configurations @@ -101,7 +101,7 @@ def __init__( self, optimizer: nn.Optimizer, zero_stage: int = 0, - op_group: str = None, + optimizer_parallel_group: str = None, dp_group: str = None, optimizer_offload: bool = False, comm_fusion: dict = None, @@ -109,7 +109,7 @@ def __init__( ): self.optimizer = optimizer self.zero_stage = zero_stage - self.op_group = op_group + self.optimizer_parallel_group = optimizer_parallel_group if isinstance(optimizer, ms.experimental.optim.optimizer.Optimizer): self.optimizer._parameters = self.optimizer.parameters self.ori_parameters = self.optimizer._parameters @@ -123,8 +123,8 @@ def __init__( self.op_reduce_scatter = ops.Identity() self.op_allreduce = ops.Identity() self.dp_allreduce = ops.Identity() - self.op_group_size = get_group_size(self.op_group) if self.is_parallel else 1 - self.op_rank_id = get_rank(self.op_group) if self.is_parallel else 0 + self.op_group_size = get_group_size(self.optimizer_parallel_group) if self.is_parallel else 1 + self.op_rank_id = get_rank(self.optimizer_parallel_group) if self.is_parallel else 0 self.need_dp = False self.dp_group = dp_group self.last_assign = False @@ -172,14 +172,14 @@ def __init__( def set_comm_ops( self, ): - self.op_allreduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.op_group) - self.op_reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=self.op_group) + self.op_allreduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.optimizer_parallel_group) + self.op_reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=self.optimizer_parallel_group) # AllGather the parameters after optimizer calculate to update the parameters in train network. - self.op_allgather = ops.AllGather(group=self.op_group) + self.op_allgather = ops.AllGather(group=self.optimizer_parallel_group) self.need_dp = self.dp_group is not None if self.need_dp: - # Set it when op_group is not the WORLD_COMM_GROUP. + # Set it when optimizer_parallel_group is not the WORLD_COMM_GROUP. 
self.dp_allreduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.dp_group) self.dp_group_size = ms.Tensor(get_group_size(group=self.dp_group), ms.float32) @@ -200,7 +200,7 @@ def set_zero1_allreduce_fusion_comm_list(self, comm_fusion): param_size = param.itemsize * param.size param_name = param.name self.update_comm_op_info(allreduce_info, comm_fusion["allreduce"]["bucket_size"], param_size, param_name) - comm_op = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.op_group) + comm_op = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.optimizer_parallel_group) comm_op.add_prim_attr("fusion", allreduce_info[-1]["fusion_id"]) self.zero1_allreduce_list.append(comm_op) _logger.info(f"zero1_allreduce_fusion: {allreduce_info}") @@ -223,11 +223,11 @@ def set_zero2_reduce_scatter_fusion_comm_list(self, comm_fusion): self.update_comm_op_info( allreduce_info, comm_fusion["allreduce"]["bucket_size"], param_size, param_name ) - comm_op = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=self.op_group) + comm_op = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=self.optimizer_parallel_group) comm_op.add_prim_attr("fusion", reduce_scatter_info[-1]["fusion_id"]) self.zero2_reduce_scatter_list.append(comm_op) - comm_op = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.op_group) + comm_op = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.optimizer_parallel_group) comm_op.add_prim_attr("fusion", allreduce_info[-1]["fusion_id"]) self.zero2_allreduce_list.append(comm_op) _logger.info(f"zero2_reduce_scatter_fusion: {reduce_scatter_info}") @@ -244,7 +244,7 @@ def set_optimizer_allgather_fusion_comm_list(self, comm_fusion): self.update_comm_op_info( allgather_info, comm_fusion["allgather"]["bucket_size"], param_size, param_name ) - comm_op = ops.AllGather(group=self.op_group) + comm_op = ops.AllGather(group=self.optimizer_parallel_group) comm_op.add_prim_attr("fusion", allgather_info[-1]["fusion_id"]) self.optimizer_allgather_list.append(comm_op) _logger.info(f"optimizer_allgather_fusion: {allgather_info}") @@ -260,7 +260,7 @@ def set_dp_allreduce_comm_list(self, comm_fusion): self.update_comm_op_info( dp_allreduce_info, comm_fusion["allreduce"]["bucket_size"], param_size, param_name ) - comm_op = ops.AllGather(group=self.op_group) + comm_op = ops.AllGather(group=self.optimizer_parallel_group) comm_op.add_prim_attr("fusion", dp_allreduce_info[-1]["fusion_id"]) self.dp_allreduce_list.append(comm_op) _logger.info(f"dp_allreduce_fusion: {dp_allreduce_info}") @@ -301,22 +301,20 @@ def dump_params_split_info(self, params_split_info): def get_need_parameter_split(self): self.need_parameter_split = [False] * len(self.optimizer._parameters) - param_tuples = self.get_optimizer_param_tuples() for i, param in enumerate(self.optimizer._parameters): if self.zero_stage == 3: - if param_tuples: - B = param_tuples[0][i].shape[0] - else: - continue + self.need_parameter_split[i] = param.parallel_optimizer else: B = param.shape[0] - if param.parallel_optimizer and B >= self.op_group_size and B % self.op_group_size == 0: - if self.zero_stage in [1, 2]: - self.need_parameter_split[i] = True + if param.parallel_optimizer and B >= self.op_group_size and B % self.op_group_size == 0: + if self.zero_stage in [1, 2]: + self.need_parameter_split[i] = True self.need_parameter_split = tuple(self.need_parameter_split) def split_params(self): - if self.zero_stage in [1, 2] and self.is_parallel: + if not (self.zero_stage in [1, 2, 3] and self.is_parallel): + return + if self.zero_stage in [1, 2]: _logger.info("Clone optimizer.parameters, will increase memory.") 
# Because the first input of MindSpore optimizer must be ms.Parameter, # copy optimizer.parameters for optimizer parameters update. @@ -330,15 +328,14 @@ def split_params(self): _logger.debug(f"Split optimizer param {param.name} {param.shape}") # If zero_stage is 3, the parameters in train network have been split, # use parameter in param_tuples to get batch size. - if self.zero_stage == 3: - if param_tuples: - B = param_tuples[0][i].shape[0] - else: - continue - else: - B = param.shape[0] _logger.debug(f"Do split with zero_stage {self.zero_stage}") - if param.parallel_optimizer and B >= self.op_group_size and B % self.op_group_size == 0: + if self.zero_stage in [1, 2]: + B = param.shape[0] + if self.ori_parameters[i] and B >= self.op_group_size and B % self.op_group_size == 0: + param.parallel_optimizer = True + else: + param.parallel_optimizer = False + if param.parallel_optimizer: if self.zero_stage in [1, 2]: ori_shape = param.shape param.assign_value(self.split_param(param)) @@ -346,7 +343,7 @@ def split_params(self): for param_tuple in param_tuples: ori_shape = param_tuple[i].shape param_tuple[i].assign_value(self.split_param(param_tuple[i])) - _logger.debug(f"Optimizer {param_tuple[i].name} " f"from {ori_shape} to {param_tuple[i].shape}") + _logger.debug(f"Optimizer {param_tuple[i].name} from {ori_shape} to {param_tuple[i].shape}") def reduce_scatter_gradients(self, gradients): dtype = gradients[0].dtype @@ -469,11 +466,11 @@ def get_cell_dtype(cell): return None -def _init_parallel_settings(net, op_group, parallel_modules=None): +def _init_parallel_settings(net, optimizer_parallel_group, parallel_modules=None): for module, parallel_module in parallel_modules.items(): if isinstance(net, module): cell_type = get_cell_dtype(net) - new_net = parallel_module(net, 3, op_group) + new_net = parallel_module(net, 3, optimizer_parallel_group) if cell_type is not None: new_net.to_float(cell_type) return new_net @@ -487,14 +484,14 @@ def get_cell_params_fullname_dict(cell: nn.Cell): return fullname_dict -def _prepare_network(network: nn.Cell, op_group: str, parallel_modules=None): - new_net = _init_parallel_settings(network, op_group, parallel_modules) +def _prepare_network(network: nn.Cell, optimizer_parallel_group: str, parallel_modules=None): + new_net = _init_parallel_settings(network, optimizer_parallel_group, parallel_modules) if new_net is not None: return new_net for name, sub_net in network._cells.items(): if not sub_net: continue - new_sub_net = _init_parallel_settings(sub_net, op_group, parallel_modules) + new_sub_net = _init_parallel_settings(sub_net, optimizer_parallel_group, parallel_modules) if new_sub_net is not None: params_fullname_dict = get_cell_params_fullname_dict(sub_net) if isinstance(network, (nn.CellList, nn.SequentialCell)): @@ -513,27 +510,27 @@ def _prepare_network(network: nn.Cell, op_group: str, parallel_modules=None): param = getattr(sub_net, param_name) _logger.warning(f"Set param {param.name} parallel_optimizer False, param shape {param.shape}") param.parallel_optimizer = False - _prepare_network(sub_net, op_group, parallel_modules) + _prepare_network(sub_net, optimizer_parallel_group, parallel_modules) return network -def prepare_network(network: nn.Cell, zero_stage: int = 0, op_group: str = None, parallel_modules=None): +def prepare_network(network: nn.Cell, zero_stage: int = 0, optimizer_parallel_group: str = None, parallel_modules=None): if zero_stage != 3 or _get_parallel_mode() != ParallelMode.DATA_PARALLEL: _logger.info("No need rewrite network and 
return original network.") return network _logger.info("Rewrite the network, please wait...") if parallel_modules is None: parallel_modules = PARALLEL_MODULES - network = _prepare_network(network, op_group, parallel_modules) + network = _prepare_network(network, optimizer_parallel_group, parallel_modules) return network -def prepare_ema(ema, zero_stage: int = 0, op_group: str = None): +def prepare_ema(ema, zero_stage: int = 0, optimizer_parallel_group: str = None): is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL if not is_parallel or zero_stage != 3: return ema - op_group_size = get_group_size(op_group) - op_rank_id = get_rank(op_group) + op_group_size = get_group_size(optimizer_parallel_group) + op_rank_id = get_rank(optimizer_parallel_group) _logger.info(f"Split EMA params: rank_id {op_rank_id}, rank_size {op_group_size}.") for net_weight, ema_weight, swap_cache in zip(ema.net_weight, ema.ema_weight, ema.swap_cache): if net_weight.shape == ema_weight.shape: @@ -556,7 +553,7 @@ def prepare_train_network( verbose: bool = False, zero_stage: int = 0, optimizer_offload: bool = False, - op_group: str = None, + optimizer_parallel_group: str = None, dp_group: str = None, comm_fusion: dict = None, parallel_modules=None, @@ -573,7 +570,7 @@ def prepare_train_network( the shape should be :math:`()` or :math:`(1,)`. zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0. optimizer_offload (`bool`, *optional*): Only take effect when optimizer is AdamWeightDecay, default is False. - op_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. + optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. dp_group (`str`, *optional*): The name of the data parallel communication group, default is None. comm_fusion (`dict`, *optional*): A dict contains the types and configurations for setting the communication fusion, default is None, turn off the communication fusion. 
If set a dict, @@ -586,22 +583,26 @@ def prepare_train_network( """ if zero_stage not in [0, 1, 2, 3]: raise ValueError("Not support zero_stage {zero_stage}") - if op_group is None: + if optimizer_parallel_group is None: _logger.warning("Not set zero group, set it WORLD_COMM_GROUP.") - op_group = GlobalComm.WORLD_COMM_GROUP - if op_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None: - raise ValueError("op_group {op_group} and dp_group {dp_group} not full network hccl group coverage") + optimizer_parallel_group = GlobalComm.WORLD_COMM_GROUP + if optimizer_parallel_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None: + raise ValueError( + "optimizer_parallel_group {optimizer_parallel_group} and dp_group {dp_group} not full network hccl group coverage" + ) is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL if not is_parallel and zero_stage == 0: _logger.info("No need prepare train_network with zero.") zero_helper = None else: - network = prepare_network(network, zero_stage, op_group, parallel_modules=parallel_modules) - zero_helper = ZeroHelper(optimizer, zero_stage, op_group, dp_group, optimizer_offload, comm_fusion) + network = prepare_network(network, zero_stage, optimizer_parallel_group, parallel_modules=parallel_modules) + zero_helper = ZeroHelper( + optimizer, zero_stage, optimizer_parallel_group, dp_group, optimizer_offload, comm_fusion + ) if ema is not None: - ema = prepare_ema(ema, zero_stage, op_group) + ema = prepare_ema(ema, zero_stage, optimizer_parallel_group) if isinstance(scale_sense, float): scale_sense = ms.Tensor(scale_sense, ms.float32) train_network = TrainOneStepWrapper( @@ -620,7 +621,7 @@ def prepare_train_network( return train_network -def transform_checkpoints(src_checkpoint: str, src_param_split_info_json: str, group_size: int): +def convert_checkpoints(src_checkpoint: str, src_param_split_info_json: str, group_size: int): """ src_checkpoint (`str`): The path of checkpoints need to merge parameters. eg. "save_checkpoint_dir/ckpt_{}.ckpt", {} is placeholder of rank_id. diff --git a/tests/others/test_zero.py b/tests/others/test_zero.py index c9c99742c0..5d3a613c7f 100644 --- a/tests/others/test_zero.py +++ b/tests/others/test_zero.py @@ -78,7 +78,12 @@ def test_zero(x, y, zero_stage=0, comm_fusion=False): "allgather": {"bucket_size": 64}, } train_net = prepare_train_network( - net, opt, ema=ema, zero_stage=zero_stage, op_group=GlobalComm.WORLD_COMM_GROUP, comm_fusion=comm_fusion_dict + net, + opt, + ema=ema, + zero_stage=zero_stage, + optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP, + comm_fusion=comm_fusion_dict, ) for i in range(10): From 0f5064c7db94984e225499882c97430fc720df42 Mon Sep 17 00:00:00 2001 From: zhaoting <37293445+CaitinZhao@users.noreply.github.com> Date: Sat, 22 Feb 2025 10:37:37 +0800 Subject: [PATCH 2/2] fix a bug (#844) Co-authored-by: zhaoting --- docs/tools/zero.md | 3 +- .../opensora_pku/tools/ckpt/combine_ckpt.py | 2 +- mindone/diffusers/training_utils.py | 207 +-------------- mindone/models/modules/parallel/conv.py | 6 +- mindone/models/modules/parallel/dense.py | 4 +- .../models/modules/parallel/param_wrapper.py | 8 +- mindone/trainers/callback.py | 245 ++++++------------ mindone/trainers/zero.py | 44 ++-- 8 files changed, 126 insertions(+), 393 deletions(-) diff --git a/docs/tools/zero.md b/docs/tools/zero.md index ca96ea43cc..c25f1dac08 100644 --- a/docs/tools/zero.md +++ b/docs/tools/zero.md @@ -124,7 +124,8 @@ Inference need complete model parameters when use zero3. 
There are two ways(onli def do_ckpt_combine_online(net_to_save, optimizer_parallel_group): new_net_to_save = [] all_gather_op = ops.AllGather(optimizer_parallel_group) - for param in net_to_save: + for p in net_to_save: + param = p["data"] if param.parallel_optimizer: new_data = ms.Tensor(all_gather_op(param).asnumpy()) else: diff --git a/examples/opensora_pku/tools/ckpt/combine_ckpt.py b/examples/opensora_pku/tools/ckpt/combine_ckpt.py index b4110ea86d..57fe2d1fe4 100644 --- a/examples/opensora_pku/tools/ckpt/combine_ckpt.py +++ b/examples/opensora_pku/tools/ckpt/combine_ckpt.py @@ -25,7 +25,7 @@ def main(): else args.strategy_ckpt ) assert os.path.exists(strategy_file), f"{strategy_file} does not exist!" - ms.convert_checkpoints(args.src, args.dest, "full_", strategy_file, None) + ms.transform_checkpoints(args.src, args.dest, "full_", strategy_file, None) output_path = os.path.join(args.dest, "rank_0", "full_0.ckpt") assert os.path.isfile(output_path) diff --git a/mindone/diffusers/training_utils.py b/mindone/diffusers/training_utils.py index b14642955a..77d803f6f4 100644 --- a/mindone/diffusers/training_utils.py +++ b/mindone/diffusers/training_utils.py @@ -1,4 +1,3 @@ -import contextlib import copy import logging import math @@ -15,17 +14,9 @@ import mindspore as ms from mindspore import context, nn, ops from mindspore.amp import DynamicLossScaler, StaticLossScaler, all_finite -from mindspore.common import dtype as mstype -from mindspore.common.api import _pynative_executor from mindspore.communication import get_group_size, get_local_rank, get_rank, init -from mindspore.communication.management import GlobalComm -from mindspore.context import ParallelMode -from mindspore.parallel._utils import _get_parallel_mode from mindone.diffusers._peft import set_peft_model_state_dict -from mindone.diffusers.models.model_loading_utils import silence_mindspore_logger -from mindone.trainers.train_step import TrainOneStepWrapper -from mindone.trainers.zero import ZeroHelper, prepare_network from .models import UNet2DConditionModel from .schedulers import SchedulerMixin @@ -113,12 +104,6 @@ def compute_dream_and_update_latents( def cast_training_params(model: Union[nn.Cell, List[nn.Cell]], dtype=ms.float32): - """ - Casts the training parameters of the model to the specified data type. - Args: - model: The PyTorch model whose parameters will be cast. - dtype: The data type to which the model parameters will be cast. - """ if not isinstance(model, list): model = [model] for m in model: @@ -148,8 +133,7 @@ def _set_state_dict_into_text_encoder(lora_state_dict: Dict[str, ms.Tensor], pre def compute_density_for_timestep_sampling( weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None ): - """ - Compute the density for sampling the timesteps when doing SD3 training. + """Compute the density for sampling the timesteps when doing SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. @@ -168,8 +152,7 @@ def compute_density_for_timestep_sampling( def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """ - Computes loss weighting scheme for SD3 training. + """Computes loss weighting scheme for SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. 
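[Editor's note] For context on the `docs/tools/zero.md` fix earlier in this patch (iterating checkpoint entries as `{"name", "data"}` dicts rather than bare parameters), the corrected online-combine helper ends up looking roughly like the sketch below. It assumes a distributed environment has already been initialized with a valid `optimizer_parallel_group`; the trailing `ms.save_checkpoint` usage line is illustrative and not part of the patch.

```python
import mindspore as ms
from mindspore import ops


def do_ckpt_combine_online(net_to_save, optimizer_parallel_group):
    """Gather ZeRO-3-split parameters so a full checkpoint can be saved online."""
    new_net_to_save = []
    all_gather_op = ops.AllGather(optimizer_parallel_group)
    for p in net_to_save:  # each entry is {"name": str, "data": ms.Parameter}
        param = p["data"]
        if param.parallel_optimizer:
            # Parameter is sharded along axis 0: gather the full tensor from the group.
            new_data = ms.Tensor(all_gather_op(param).asnumpy())
        else:
            new_data = ms.Tensor(param.asnumpy())
        new_net_to_save.append({"name": param.name, "data": new_data})
    return new_net_to_save


# Illustrative usage (assumed, not from the patch): save the gathered weights on rank 0.
# ms.save_checkpoint(do_ckpt_combine_online(net_to_save, group), "full_model.ckpt")
```

The offline alternative in `combine_ckpt.py` instead relies on `ms.transform_checkpoints` (the MindSpore API name restored by this patch) together with the saved parallel strategy file.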
@@ -327,9 +310,7 @@ def pin_memory(self) -> None: raise NotImplementedError("Not Implemeneted for `pin_memory`.") def to(self, dtype=None, non_blocking=False) -> None: - r""" - Move internal buffers of the ExponentialMovingAverage to `device`. - """ + r"""Move internal buffers of the ExponentialMovingAverage to `device`.""" # .to() on the tensors handles None correctly raise NotImplementedError("Not Implemeneted for `to`.") @@ -354,10 +335,9 @@ def state_dict(self) -> dict: def load_state_dict(self, state_dict: dict) -> None: r""" + Args: Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the ema state dict. - - Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`. """ @@ -405,10 +385,6 @@ def is_master(args): return args.rank == 0 -def is_local_master(args): - return args.local_rank == 0 - - def init_distributed_device(args): # Distributed training = training on more than one GPU. # Works in both single and multi-node scenarios. @@ -830,178 +806,3 @@ def construct(self, *inputs): loss = self.unscale_loss(outputs[0]) outputs = (loss,) + outputs[1:] return outputs - - -@ms.jit_class -class pynative_no_grad(contextlib.ContextDecorator): - """ - Context Manager to disable gradient calculation. When enter this context, we will disable calculate - gradient. When exit this context, we will resume its prev state. - Currently, it can use both in Pynative and Graph mode. It also can be used as decorator. - - For mindone.diffusers, it is used in PyNative training to decorate the part of calculation that - does not require gradients, e.g. vae.encode_images or text_encoder.encode_prompts where does not - need to train VAE or text-encoders. - """ - - def __init__(self): - self.is_pynative_mode = context.get_context("mode") == context.PYNATIVE_MODE or os.getenv("MS_JIT") == "0" - self.prev_state = False - - def __enter__(self): - if self.is_pynative_mode: - self.prev_state = _pynative_executor.enable_grad() - _pynative_executor.set_enable_grad(False) - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.is_pynative_mode: - _pynative_executor.set_enable_grad(self.prev_state) - return False - - -# Adapted from mindone.trainers.zero.prepare_train_network -def prepare_train_network( - network: nn.Cell, - optimizer: nn.Optimizer, - scale_sense: Union[float, nn.Cell] = 1.0, - updates: int = 0, - drop_overflow_update: bool = True, - gradient_accumulation_steps: int = 1, - clip_grad: bool = False, - clip_norm: float = 1.0, - verbose: bool = False, - zero_stage: int = 0, - optimizer_offload: bool = False, - optimizer_parallel_group: str = None, - dp_group: str = None, - comm_fusion: dict = None, - parallel_modules=None, -): - """ - Prepare network and optimizer for distributed training. - - Args: - network (`nn.Cell`): train network, not include grad function, - grad function must be built after rewrite train network. - optimizer (`nn.Optimizer`): Must be the subclass of MindSpore Optimizer. - scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called - to update loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`, - the shape should be :math:`()` or :math:`(1,)`. - zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0. - optimizer_offload (`bool`, *optional*): Only take effect when optimizer is AdamWeightDecay, default is False. 
- optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. - dp_group (`str`, *optional*): The name of the data parallel communication group, default is None. - comm_fusion (`dict`, *optional*): A dict contains the types and configurations - for setting the communication fusion, default is None, turn off the communication fusion. If set a dict, - turn on the communication fusion. - Examples: {"allreduce": {"openstate": True, "bucket_size": 5e8}, - "reduce_scatter": {"openstate": True, "bucket_size": 5e8}, - "allgather": {"openstate": False, "bucket_size": 5e8},} - parallel_modules (`dict`, *optional*): A dict of Cells could split parameters in zero3, default is None. - If None, use `PARALLEL_MODULES` from `mindone.models.modules.parallel`. - """ - if zero_stage not in [0, 1, 2, 3]: - raise ValueError("Not support zero_stage {zero_stage}") - if optimizer_parallel_group is None: - logger.warning("Not set zero group, set it WORLD_COMM_GROUP.") - optimizer_parallel_group = GlobalComm.WORLD_COMM_GROUP - if optimizer_parallel_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None: - raise ValueError( - "optimizer_parallel_group {optimizer_parallel_group} and dp_group {dp_group} not full network hccl group coverage" - ) - - is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL - if not is_parallel and zero_stage == 0: - logger.info("No need prepare train_network with zero.") - zero_helper = None - else: - network = prepare_network(network, zero_stage, optimizer_parallel_group, parallel_modules=parallel_modules) - zero_helper = ZeroHelper( - optimizer, zero_stage, optimizer_parallel_group, dp_group, optimizer_offload, comm_fusion - ) - - if isinstance(scale_sense, float): - scale_sense = ms.Tensor(scale_sense, ms.float32) - train_network = DiffusersTrainOneStepWrapper( - network, - optimizer, - scale_sense=scale_sense, - updates=updates, - drop_overflow_update=drop_overflow_update, - gradient_accumulation_steps=gradient_accumulation_steps, - clip_grad=clip_grad, - clip_norm=clip_norm, - verbose=verbose, - zero_helper=zero_helper, - ) - return train_network - - -class DiffusersTrainOneStepWrapper(TrainOneStepWrapper): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - @property - def use_zero(self): - return self.zero_helper is not None and self.zero_stage != 0 - - def need_save_optimizer(self, args): - # TODO: Now we save optimizer in every process, try to save depend on self.zero_helper.optimizer_parallel_group - return True if self.use_zero else is_local_master(args) - - def save_state(self, args, output_dir, optimizer_state_filter=lambda x: True): - os.makedirs(output_dir, exist_ok=True) - logger.info(f"Saving current state to {output_dir}") - - # Optimizer states - if self.use_zero: - os.makedirs(os.path.join(output_dir, "mindspore_model"), exist_ok=True) - optimizer_file = os.path.join(output_dir, "mindspore_model", f"zero_pp_{args.local_rank}_optim_states.ckpt") - elif self.need_save_optimizer(args): - optimizer_file = os.path.join(output_dir, "optimizer.ckpt") - ms.save_checkpoint(self.optimizer, optimizer_file, choice_func=optimizer_state_filter) - - # Loss Scaler states - loss_scaler_file = os.path.join(output_dir, "loss_scaler.ckpt") - loss_scaler_states = {"scale_sense": self.scale_sense} - if self.loss_scaling_manager: - loss_scaler_states.update( - { - "cur_iter": self.loss_scaling_manager.cur_iter, - "last_overflow_iter": self.loss_scaling_manager.last_overflow_iter, - } - ) - 
ms.save_checkpoint(loss_scaler_states, loss_scaler_file) - - def load_state(self, args, input_dir, optimizer_state_filter=lambda x: True): - # Optimizer states - optimizer_file = ( - os.path.join(input_dir, "mindspore_model", f"zero_pp_{args.local_rank}_optim_states.ckpt") - if self.use_zero - else os.path.join(input_dir, "optimizer.ckpt") - ) - optimizer_state_dict = ms.load_checkpoint(optimizer_file) - - with silence_mindspore_logger(): - param_not_load, ckpt_not_load = ms.load_param_into_net( - self.optimizer, optimizer_state_dict, strict_load=True - ) - - param_not_load = list(filter(lambda x: optimizer_state_filter(x), param_not_load)) - if param_not_load or ckpt_not_load: - logger.warning( - f"Loading checkpoint into optimizer returns param_not_load:{param_not_load} \nand ckpt_not_load:{ckpt_not_load}" - ) - - # Loss Scaler states - loss_scaler_file = os.path.join(input_dir, "loss_scaler.ckpt") - loss_scaler_state_dict = ms.load_checkpoint(loss_scaler_file) - - scale_sense = loss_scaler_state_dict.get("scale_sense", ms.Tensor(1.0, dtype=mstype.float32)) - cur_iter = loss_scaler_state_dict.get("cur_iter", None) - last_overflow_iter = loss_scaler_state_dict.get("last_overflow_iter", None) - - ops.assign(self.scale_sense, scale_sense) - if cur_iter is not None and last_overflow_iter is not None: - ops.assign(self.loss_scaling_manager.cur_iter, cur_iter) - ops.assign(self.loss_scaling_manager.last_overflow_iter, last_overflow_iter) diff --git a/mindone/models/modules/parallel/conv.py b/mindone/models/modules/parallel/conv.py index b8bd00be84..6a1f6a2396 100644 --- a/mindone/models/modules/parallel/conv.py +++ b/mindone/models/modules/parallel/conv.py @@ -29,13 +29,13 @@ def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None if zero_stage == 3: # Init parallel settings is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL - op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1 + optimizer_parallel_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1 op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0 self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, optimizer_parallel_group, cell_type) - split_op = ops.Split(0, op_group_size) + split_op = ops.Split(0, optimizer_parallel_group_size) if self.param_wrapper_w.need_rewrite: self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id]) - if self.net.bias: + if self.net.bias is not None: self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, optimizer_parallel_group, cell_type) if self.param_wrapper_b.need_rewrite: self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id]) diff --git a/mindone/models/modules/parallel/dense.py b/mindone/models/modules/parallel/dense.py index dc24ef0bba..9182364d6d 100644 --- a/mindone/models/modules/parallel/dense.py +++ b/mindone/models/modules/parallel/dense.py @@ -37,10 +37,10 @@ def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None if zero_stage == 3: # Init parallel settings is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL - op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1 + optimizer_parallel_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1 op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0 self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, optimizer_parallel_group, cell_type) - split_op = ops.Split(0, op_group_size) + split_op = 
ops.Split(0, optimizer_parallel_group_size) if self.param_wrapper_w.need_rewrite: self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id]) if self.net.has_bias: diff --git a/mindone/models/modules/parallel/param_wrapper.py b/mindone/models/modules/parallel/param_wrapper.py index b007b97ab7..f5f31f723b 100644 --- a/mindone/models/modules/parallel/param_wrapper.py +++ b/mindone/models/modules/parallel/param_wrapper.py @@ -27,7 +27,7 @@ def __init__( # Init parallel settings self.is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL - self.op_group_size = get_group_size(self.optimizer_parallel_group) if self.is_parallel else 1 + self.optimizer_parallel_group_size = get_group_size(self.optimizer_parallel_group) if self.is_parallel else 1 self.allgather = ops.Identity() self.reduce_scatter = None self.dtype = param.dtype @@ -42,7 +42,7 @@ def check_rewrite(self, param): """Check the parameter need to split or not.""" need_rewrite = self.is_parallel B = param.shape[0] - if not param.parallel_optimizer or B < self.op_group_size or B % self.op_group_size != 0: + if not param.parallel_optimizer or B < self.optimizer_parallel_group_size or B % self.optimizer_parallel_group_size != 0: need_rewrite = False param.parallel_optimizer = need_rewrite return need_rewrite @@ -56,7 +56,7 @@ def construct(self, param): def bprop(self, param, out, dout): if self.need_rewrite: - r = self.op_reduce_scatter(dout.to(self.dtype)) / self.op_group_size + r = self.op_reduce_scatter(dout.to(self.dtype)) / self.optimizer_parallel_group_size return (r,) - dout = self.allreduce(dout.to(self.dtype)) / self.op_group_size + dout = self.allreduce(dout.to(self.dtype)) / self.optimizer_parallel_group_size return (dout,) diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py index a553342317..a5fdee0bcb 100755 --- a/mindone/trainers/callback.py +++ b/mindone/trainers/callback.py @@ -1,16 +1,13 @@ import logging import os import time -from typing import List, Literal, Optional, Tuple, Union +from typing import List import mindspore as ms -from mindspore import Profiler, Tensor, nn, ops, save_checkpoint from mindspore.communication import get_rank -from mindspore.communication.management import GlobalComm from mindspore.train.callback._callback import Callback, _handle_loss from .checkpoint import CheckpointManager -from .ema import EMA from .recorder import PerfRecorder _logger = logging.getLogger("") @@ -26,7 +23,7 @@ def get_real_rank(): return int(os.getenv("RANK_ID", "0")) -class OverflowMonitor(Callback): +class OverflowMonitor(ms.Callback): def on_train_step_end(self, run_context): cb_params = run_context.original_args() cur_epoch_num = cb_params.get("cur_epoch_num", 1) @@ -41,66 +38,40 @@ def on_train_step_end(self, run_context): class EvalSaveCallback(Callback): def __init__( self, - network: nn.Cell, - use_lora: bool = False, - rank_id: int = 0, - ckpt_save_dir: str = "./", - output_dir: str = None, - ema: EMA = None, - save_ema_only: bool = True, - ckpt_save_policy: Literal["top_k", "latest_k", None] = "latest_k", - monitor_metric: Optional[str] = None, - ckpt_max_keep: int = 10, - step_mode: bool = False, - ckpt_save_interval: int = 1, - use_step_unit: bool = False, - data_sink_mode: bool = True, - lora_rank: Optional[int] = None, - log_interval: int = 1, - start_epoch: int = 0, - record_lr: bool = True, - model_name: str = "sd", + network, + use_lora=False, + rank_id=0, + ckpt_save_dir="./", + output_dir=None, + ema=None, + save_ema_only=True, + ckpt_save_policy="lastest_k", + 
ckpt_max_keep=10, + step_mode=False, + ckpt_save_interval=1, + use_step_unit=False, + data_sink_mode=True, + lora_rank=None, + log_interval=1, + start_epoch=0, + record_lr=True, + model_name="sd", save_trainable_only: bool = False, param_save_filter: List[str] = None, - resume_prefix_blacklist: Optional[Union[str, Tuple[str, ...]]] = None, - integrated_save: bool = False, - save_training_resume: bool = True, - train_steps: int = -1, - prefer_low_perf: bool = False, - zero_stage: int = 0, - optimizer_parallel_group: str = None, - ckpt_combine_online: bool = False, + resume_prefix_blacklist: List[str] = None, + integrated_save=False, + save_training_resume=True, + train_steps=-1, ): """ Args: step_mode: if True, ckpt_save_interval is counted in steps. otherwise, in epochs. param_save_filter: indicates what parameters to save in checkpoint. If None, save all parameters in network. \ Otherwise, only params that contain one of the keyword in param_save_filter list will be saved. - resume_prefix_blacklist: exclude parameters with one of these prefixes to be saved in resume checkpoint, - e.g. ('swap.', 'vae.'). - zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0. - optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. - ckpt_combine_online (`bool`, *optional*): combining trainable parameters for saving checkpoint when zero_stage=3, \ - using allgather ops to combile the checkpoint online if `ckpt_combine_online=True`, \ - saving all device parameters if `ckpt_combine_online=False`, \ - and need to use `convert_checkpoints` to combile the checkpoint offline. default is False. + resume_prefix_blacklist: exclude parameters with one of these prefixes to be saved in resume checkpoint. e.g. ['swap.', 'vae.']. 
""" self.rank_id = rank_id self.is_main_device = rank_id in [0, None] - self.use_zero = zero_stage in [1, 2, 3] - self.ckpt_combine_online = (zero_stage == 3) and ckpt_combine_online - if self.ckpt_combine_online and self.ema is not None: - _logger.warning("Can not enable ckpt_combine_online when use ema, set `ckpt_combine_online=False`.") - self.ckpt_combine_online = False - - self.need_save_network = self.is_main_device or (zero_stage == 3 and not self.ckpt_combine_online) - self.need_save_optimizer = self.is_main_device or self.use_zero - if self.use_zero: - if optimizer_parallel_group is None: - _logger.warning("EvalSaveCallback not set zero group, set it WORLD_COMM_GROUP.") - optimizer_parallel_group = GlobalComm.WORLD_COMM_GROUP - self.optimizer_parallel_group = optimizer_parallel_group - self.op_rank_id = get_rank(optimizer_parallel_group) self.ema = ema if output_dir is not None: self.output_dir = output_dir @@ -122,15 +93,13 @@ def __init__( self.record_lr = record_lr self.save_ema_only = save_ema_only - if self.need_save_network: + if self.is_main_device: self.ckpt_save_policy = ckpt_save_policy - self.monitor_metric = monitor_metric self.ckpt_manager = CheckpointManager( ckpt_save_dir, ckpt_save_policy, k=ckpt_max_keep, integrated_save=integrated_save, - prefer_low_perf=prefer_low_perf, ) if self.start_epoch == 0: if self.record_lr: @@ -162,22 +131,17 @@ def __init__( self.use_step_unit = use_step_unit self.train_steps = train_steps self.save_training_resume = save_training_resume - self.choice_func = None - if resume_prefix_blacklist: - if isinstance(resume_prefix_blacklist, str): - resume_prefix_blacklist = (resume_prefix_blacklist,) - self.choice_func = lambda x: not x.startswith(resume_prefix_blacklist) - - def _do_ckpt_combine_online(self): - new_net_to_save = [] - all_gather_op = ops.AllGather(self.optimizer_parallel_group) - for param in self.net_to_save: - if param.parallel_optimizer: - new_data = ms.Tensor(all_gather_op(param).asnumpy()) - else: - new_data = ms.Tensor(param.asnumpy()) - new_net_to_save.append({"name": param.name, "data": new_data}) - return new_net_to_save + if resume_prefix_blacklist is not None: + + def choice_func(x): + for prefix in resume_prefix_blacklist: + if x.startswith("vae."): + return False + return True + + self.choice_func = choice_func + else: + self.choice_func = None def on_train_step_end(self, run_context): cb_params = run_context.original_args() @@ -194,31 +158,7 @@ def on_train_step_end(self, run_context): else: cur_epoch = cb_params.cur_epoch_num - 1 - if self.save_training_resume and self.need_save_optimizer: - # TODO: resume training for step. 
- ckpt_name = f"train_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "train_resume.ckpt" - save_checkpoint( - cb_params.train_network, - os.path.join(self.ckpt_save_dir, ckpt_name), - choice_func=self.choice_func, - append_dict={ - "epoch_num": cur_epoch, - "cur_step": cur_step, - "loss_scale": self._get_scaling_value_from_cbp(cb_params), - }, - ) - if self.ema is not None: - ckpt_name = f"ema_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "ema_resume.ckpt" - save_checkpoint( - self.ema, - os.path.join(self.ckpt_save_dir, ckpt_name), - choice_func=self.choice_func, - ) - - if self.ckpt_combine_online: - new_net_to_save = self._do_ckpt_combine_online() - - if self.need_save_network: + if self.is_main_device: # if data sink, train step callback will not be invokded if self.step_mode and (cur_step % self.ckpt_save_interval == 0 or cur_step == step_num): ckpt_name = ( @@ -226,29 +166,34 @@ def on_train_step_end(self, run_context): if self.use_step_unit else f"{self.model_name}-e{cur_epoch}.ckpt" ) - if self.use_zero and not self.ckpt_combine_online: - file_extension = os.path.splitext(ckpt_name) - ckpt_name = f"{file_extension[0]}_op_rank_{self.op_rank_id}{file_extension[1]}" append_dict = {"lora_rank": self.lora_rank} if self.use_lora else None - perf = cb_params.get("eval_results") - net_to_save = new_net_to_save if self.ckpt_combine_online else self.net_to_save - if perf or self.ckpt_save_policy != "top_k": - if perf: - perf = perf[self.monitor_metric] - if self.ema is not None: - if not self.save_ema_only: - self.ckpt_manager.save( - self.net_to_save, - perf, - ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"), - append_dict=append_dict, - ) - # swap ema weight and network weight - self.ema.swap_before_eval() - - # save history checkpoints - self.ckpt_manager.save(net_to_save, perf, ckpt_name=ckpt_name, append_dict=append_dict) + if self.ema is not None: + if not self.save_ema_only: + self.ckpt_manager.save( + self.net_to_save, + None, + ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"), + append_dict=append_dict, + ) + # swap ema weight and network weight + self.ema.swap_before_eval() + + # save history checkpoints + self.ckpt_manager.save(self.net_to_save, None, ckpt_name=ckpt_name, append_dict=append_dict) + + if self.save_training_resume: + # TODO: resume training for step. + ms.save_checkpoint( + cb_params.train_network, + os.path.join(self.ckpt_save_dir, "train_resume.ckpt"), + choice_func=self.choice_func, + append_dict={ + "epoch_num": cur_epoch, + "cur_step": cur_step, + "loss_scale": self._get_scaling_value_from_cbp(cb_params), + }, + ) # swap back network weight and ema weight. MUST execute after model saving and before next-step training if self.ema is not None: @@ -311,42 +256,15 @@ def on_train_epoch_end(self, run_context): opt = self._get_optimizer_from_cbp(cb_params) cur_step = int(opt.global_step.asnumpy().item()) - if self.save_training_resume and self.need_save_optimizer: - # TODO: resume training for step. 
- ckpt_name = f"train_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "train_resume.ckpt" - save_checkpoint( - cb_params.train_network, - os.path.join(self.ckpt_save_dir, ckpt_name), - choice_func=self.choice_func, - append_dict={ - "epoch_num": cur_epoch, - "loss_scale": self._get_scaling_value_from_cbp(cb_params), - }, - ) - if self.ema is not None: - ckpt_name = f"ema_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "ema_resume.ckpt" - save_checkpoint( - self.ema, - os.path.join(self.ckpt_save_dir, ckpt_name), - choice_func=self.choice_func, - ) - - if self.ckpt_combine_online: - new_net_to_save = self._do_ckpt_combine_online() - - if self.need_save_network and (not self.step_mode): + if self.is_main_device and (not self.step_mode): if (cur_epoch % self.ckpt_save_interval == 0) or (cur_epoch == epoch_num): ckpt_name = ( f"{self.model_name}-s{cur_step}.ckpt" if self.use_step_unit else f"{self.model_name}-e{cur_epoch}.ckpt" ) - if self.use_zero and not self.ckpt_combine_online: - file_extension = os.path.splitext(ckpt_name) - ckpt_name = f"{file_extension[0]}_op_rank_{self.op_rank_id}{file_extension[1]}" append_dict = {"lora_rank": self.lora_rank} if self.use_lora else None - net_to_save = new_net_to_save if self.ckpt_combine_online else self.net_to_save if self.ema is not None: if not self.save_ema_only: self.ckpt_manager.save( @@ -359,9 +277,18 @@ def on_train_epoch_end(self, run_context): self.ema.swap_before_eval() # save history checkpoints - self.ckpt_manager.save( - net_to_save, perf=cb_params["net_outputs"], ckpt_name=ckpt_name, append_dict=append_dict - ) + self.ckpt_manager.save(self.net_to_save, None, ckpt_name=ckpt_name, append_dict=append_dict) + + if self.save_training_resume: + ms.save_checkpoint( + cb_params.train_network, + os.path.join(self.ckpt_save_dir, "train_resume.ckpt"), + choice_func=self.choice_func, + append_dict={ + "epoch_num": cur_epoch, + "loss_scale": self._get_scaling_value_from_cbp(cb_params), + }, + ) # swap back network weight and ema weight. 
MUST execute after model saving and before next-step training if self.ema is not None: @@ -372,16 +299,16 @@ def on_train_epoch_end(self, run_context): def on_train_end(self, run_context): if self.is_main_device: if self.ckpt_save_policy == "top_k": - log_str = f"Top K checkpoints: \n{self.main_indicator}\tcheckpoint\n" + log_str = f"Top K checkpoints:\n{self.main_indicator}\tcheckpoint\n" for p, ckpt_name in self.ckpt_manager.get_ckpt_queue(): - log_str += f"{p: .4f}\t{os.path.join(self.ckpt_save_dir, ckpt_name)}\n" + log_str += f"{p:.4f}\t{os.path.join(self.ckpt_save_dir, ckpt_name)}\n" def on_eval_end(self, run_context): if self.is_main_device: cb_params = run_context.original_args() metrics = cb_params.get("metrics") if metrics is not None: - metrics = {k: f"{v: .4f}" for k, v in metrics.items()} + metrics = {k: f"{v:.4f}" for k, v in metrics.items()} _logger.info(f"Eval result epoch {cb_params.cur_epoch_num}: {metrics}") def _get_optimizer_from_cbp(self, cb_params): @@ -399,7 +326,7 @@ def _get_scaling_value_from_cbp(self, cb_params): else: return cb_params.train_network.scale_sense.asnumpy().item() - def _fetch_optimizer_lr(self, cb_params) -> Tensor: + def _fetch_optimizer_lr(self, cb_params) -> ms.Tensor: opt = self._get_optimizer_from_cbp(cb_params) lr = opt.learning_rate if opt.dynamic_lr: @@ -407,7 +334,7 @@ def _fetch_optimizer_lr(self, cb_params) -> Tensor: return lr -class StopAtStepCallback(Callback): +class StopAtStepCallback(ms.Callback): # stop the training process when reach train_steps def __init__(self, train_steps, global_step=0): self.global_step = global_step @@ -419,7 +346,7 @@ def on_train_step_end(self, run_context): run_context.request_stop() -class ProfilerCallback(Callback): +class ProfilerCallback(ms.Callback): def __init__(self, start_step=1, end_step=2, exit_after_analyze=True, out_dir="./profiler_data"): self.start_step = start_step self.end_step = end_step @@ -428,7 +355,7 @@ def __init__(self, start_step=1, end_step=2, exit_after_analyze=True, out_dir=". out_dir = os.path.join(out_dir, f"rank_{rank_id}") # If value of profile_framework is not None, a subdirectory named host_info will be generated under the # specified profiler directory to store the collected memory and time files on the Host side. 
- self.profiler = Profiler( + self.profiler = ms.Profiler( start_profile=False, output_path=out_dir, profile_framework="all", data_simplication=False ) @@ -450,12 +377,12 @@ def on_train_step_end(self, run_context): run_context.request_stop() -class ProfilerCallbackEpoch(Callback): +class ProfilerCallbackEpoch(ms.Callback): def __init__(self, start_epoch, stop_epoch, output_dir="./profiler_data"): super().__init__() self.start_epoch = start_epoch self.stop_epoch = stop_epoch - self.profiler = Profiler(start_profile=False, output_path=output_dir) + self.profiler = ms.Profiler(start_profile=False, output_path=output_dir) def on_train_epoch_begin(self, run_context): cb_params = run_context.original_args() diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py index 312a3b53cf..3ad782cd11 100644 --- a/mindone/trainers/zero.py +++ b/mindone/trainers/zero.py @@ -43,11 +43,11 @@ def _run_dp_allreduce(dp_group_size, dp_allreduce, gradient): @_stage2_reduce_scatter.register("Tensor", "Function", "Function", "Tensor", "Bool") -def _run_stage2_reduce_scatter(op_group_size, reduce_scatter, allreduce, gradient, need_reduce_scatter): +def _run_stage2_reduce_scatter(optimizer_parallel_group_size, reduce_scatter, allreduce, gradient, need_reduce_scatter): if need_reduce_scatter: - gradient = reduce_scatter(gradient) / op_group_size + gradient = reduce_scatter(gradient) / optimizer_parallel_group_size else: - gradient = allreduce(gradient) / op_group_size + gradient = allreduce(gradient) / optimizer_parallel_group_size return gradient @@ -55,8 +55,8 @@ def _run_stage2_reduce_scatter(op_group_size, reduce_scatter, allreduce, gradien @_stage1_split_grad.register("Function", "Int", "Int", "Function", "Tensor", "Bool") -def _run_stage1_split_grad(split, op_group_size, op_rank_id, allreduce, gradient, need_split): - gradient = allreduce(gradient) / op_group_size +def _run_stage1_split_grad(split, optimizer_parallel_group_size, op_rank_id, allreduce, gradient, need_split): + gradient = allreduce(gradient) / optimizer_parallel_group_size if need_split: gradient = split(gradient)[op_rank_id] return gradient @@ -123,7 +123,7 @@ def __init__( self.op_reduce_scatter = ops.Identity() self.op_allreduce = ops.Identity() self.dp_allreduce = ops.Identity() - self.op_group_size = get_group_size(self.optimizer_parallel_group) if self.is_parallel else 1 + self.optimizer_parallel_group_size = get_group_size(self.optimizer_parallel_group) if self.is_parallel else 1 self.op_rank_id = get_rank(self.optimizer_parallel_group) if self.is_parallel else 0 self.need_dp = False self.dp_group = dp_group @@ -132,7 +132,7 @@ def __init__( self.need_parameter_split = tuple([False] * len(self.optimizer._parameters)) self.use_comm_fusion = False if self.zero_stage in [1, 2, 3] and self.is_parallel: - self.split_op = ops.Split(0, self.op_group_size) # optimizer parallel split + self.split_op = ops.Split(0, self.optimizer_parallel_group_size) # optimizer parallel split self.get_need_parameter_split() if comm_fusion is None: self.set_comm_ops() @@ -164,7 +164,7 @@ def __init__( _logger.info( f"Build TrainOneStepWrapper with ZeRO stage: {self.zero_stage}, " f"optimizer_offload: {optimizer_offload}, " - f"op_group_size: {self.op_group_size}, " + f"optimizer_parallel_group_size: {self.optimizer_parallel_group_size}, " f"op_rank_id: {self.op_rank_id}, " f"dp_group_size: {self.dp_group_size}." 
) @@ -266,7 +266,7 @@ def set_dp_allreduce_comm_list(self, comm_fusion): _logger.info(f"dp_allreduce_fusion: {dp_allreduce_info}") def split_param(self, param): - return split_np(param, self.op_group_size, self.op_rank_id) + return split_np(param, self.optimizer_parallel_group_size, self.op_rank_id) def get_optimizer_param_tuples(self): param_tuples = [] @@ -291,7 +291,7 @@ def dump_params_split_info(self, params_split_info): for i, param in enumerate(self.optimizer._parameters): param_split_info = { "split": self.need_parameter_split[i], - "group_size": self.op_group_size, + "group_size": self.optimizer_parallel_group_size, "rank_id": self.op_rank_id, } params_split_info_dict[param.name] = param_split_info @@ -306,7 +306,7 @@ def get_need_parameter_split(self): self.need_parameter_split[i] = param.parallel_optimizer else: B = param.shape[0] - if param.parallel_optimizer and B >= self.op_group_size and B % self.op_group_size == 0: + if param.parallel_optimizer and B >= self.optimizer_parallel_group_size and B % self.optimizer_parallel_group_size == 0: if self.zero_stage in [1, 2]: self.need_parameter_split[i] = True self.need_parameter_split = tuple(self.need_parameter_split) @@ -331,7 +331,11 @@ def split_params(self): _logger.debug(f"Do split with zero_stage {self.zero_stage}") if self.zero_stage in [1, 2]: B = param.shape[0] - if self.ori_parameters[i] and B >= self.op_group_size and B % self.op_group_size == 0: + if ( + self.ori_parameters[i].parallel_optimizer + and B >= self.optimizer_parallel_group_size + and B % self.optimizer_parallel_group_size == 0 + ): param.parallel_optimizer = True else: param.parallel_optimizer = False @@ -351,7 +355,7 @@ def reduce_scatter_gradients(self, gradients): gradients = self.hyper_map( ops.partial( _stage2_reduce_scatter, - ms.Tensor(self.op_group_size, dtype), + ms.Tensor(self.optimizer_parallel_group_size, dtype), ), self.zero2_reduce_scatter_list, self.zero2_allreduce_list, @@ -362,7 +366,7 @@ def reduce_scatter_gradients(self, gradients): gradients = self.hyper_map( ops.partial( _stage2_reduce_scatter, - ms.Tensor(self.op_group_size, dtype), + ms.Tensor(self.optimizer_parallel_group_size, dtype), self.op_reduce_scatter, self.op_allreduce, ), @@ -399,7 +403,7 @@ def split_gradients(self, gradients): ops.partial( _stage1_split_grad, self.split_op, - self.op_group_size, + self.optimizer_parallel_group_size, self.op_rank_id, ), self.zero1_allreduce_list, @@ -411,7 +415,7 @@ def split_gradients(self, gradients): ops.partial( _stage1_split_grad, self.split_op, - self.op_group_size, + self.optimizer_parallel_group_size, self.op_rank_id, self.op_allreduce, ), @@ -529,14 +533,14 @@ def prepare_ema(ema, zero_stage: int = 0, optimizer_parallel_group: str = None): is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL if not is_parallel or zero_stage != 3: return ema - op_group_size = get_group_size(optimizer_parallel_group) + optimizer_parallel_group_size = get_group_size(optimizer_parallel_group) op_rank_id = get_rank(optimizer_parallel_group) - _logger.info(f"Split EMA params: rank_id {op_rank_id}, rank_size {op_group_size}.") + _logger.info(f"Split EMA params: rank_id {op_rank_id}, rank_size {optimizer_parallel_group_size}.") for net_weight, ema_weight, swap_cache in zip(ema.net_weight, ema.ema_weight, ema.swap_cache): if net_weight.shape == ema_weight.shape: continue - ema_weight.set_data(split_np(ema_weight, op_group_size, op_rank_id), slice_shape=True) - swap_cache.set_data(split_np(swap_cache, op_group_size, op_rank_id), slice_shape=True) 
+ ema_weight.set_data(split_np(ema_weight, optimizer_parallel_group_size, op_rank_id), slice_shape=True) + swap_cache.set_data(split_np(swap_cache, optimizer_parallel_group_size, op_rank_id), slice_shape=True) return ema
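[Editor's note] The `prepare_ema` hunk above only renames `op_group_size` to `optimizer_parallel_group_size`; the splitting itself still shards each EMA weight along axis 0 via `split_np`. Below is a rough numpy stand-in, written under the assumption that `split_np` performs an even axis-0 split and returns the current rank's shard; it is a sketch, not the library implementation, and `split_along_axis0` is a hypothetical name.

```python
import numpy as np


def split_along_axis0(value: np.ndarray, group_size: int, rank_id: int) -> np.ndarray:
    """Return this rank's 1/group_size shard of `value` along axis 0 (sketch of split_np)."""
    assert value.shape[0] % group_size == 0, "axis-0 size must divide evenly by the group size"
    shard = value.shape[0] // group_size
    return value[rank_id * shard : (rank_id + 1) * shard]


# Example: an EMA weight of shape (8, 4) sharded across a 4-rank optimizer-parallel group.
w = np.arange(32, dtype=np.float32).reshape(8, 4)
print(split_along_axis0(w, group_size=4, rank_id=1).shape)  # (2, 4)
```

Each rank keeps only its `1/group_size` slice of the EMA and swap-cache tensors, mirroring how the model parameters themselves are partitioned under ZeRO stage 3.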