Commit af6aa9e

Authored by flybird11111, pre-commit-ci[bot], and duanjunwen
[plugin] hybrid support zero bubble pipeline (#6060)
Squashed commit message:

* hybrid support zbv
* [zerobubble] Support ZeroBubble Pipeline (#6034)
  * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;
  * [feat] add dw test;
  * [fix] fix weight not close;
  * [update] update text;
  * [feat] add test run_fwd_bwd automatic scheduling;
  * [feat] split communication and calculation; fix pop empty send_bwd_buffer error;
  * [feat] add test for p & p.grad;
  * [feat] add comments for ZBV func;
  * [fix] rm useless assign and comments;
  * [fix] fix ci test; add pytest;
  * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p & p.grad assert-close test & all pass;
  * [feat] add apply v_schedule graph; p & p.grad assert err exist;
  * [feat] fix ci; add assert;
  * [feat] fix poc format
  * [feat] fix func name & ci; add comments;
  * [fix] fix poc test; add comments in poc;
  * [feat] add optim backward_b_by_grad
  * [feat] fix optimizer bwd b & w; support return accum loss & output
  * [feat] add fwd_bwd_step, run_fwd_only;
  * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while" --> "for"; add communication dict;
  * [fix] fix communication_map;
  * [feat] update test; rm comments;
  * [fix] rm zbv in hybridplugin
  * [fix] fix optim bwd;
  * [fix] rm output.data after send fwd;
  * [fix] fix bwd step if condition; remove useless comments and format info;
  * [fix] fix detach output & release output;
  * [fix] rm require_grad for output;
  * [fix] fix require_grad position, detach position, and input & output local buffer append position;
  * [feat] add memory assertion;
  * [fix] fix mem check;
  * [fix] fix mem assertion;
  * [fix] fix mem; use a new model shape; only assert mem less than or equal to theoretical;
  * [fix] fix model zoo import;
  * [fix] fix redundant detach & clone; add buffer assertion at the end;
  * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with tree_map;
  * [fix] update optim state dict assert (include param group & state); fix mem assert after adding optim;
  * [fix] add testcase with microbatch 4;
* Update zero_bubble_pp.py
* assorted "fix" and "fix-ci" commits, plus [pre-commit.ci] auto fixes from pre-commit.com hooks (https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <[email protected]>
1 parent b804fdc commit af6aa9e

File tree

15 files changed: +140 -53 lines

.github/workflows/build_on_pr.yml  (+1 -1)

@@ -140,7 +140,7 @@ jobs:
 
       - name: Install Colossal-AI
         run: |
-          BUILD_EXT=1 pip install -v -e .
+          BUILD_EXT=1 pip install -v .
           pip install --no-cache-dir -r requirements/requirements-test.txt
 
       - name: Store Colossal-AI Cache

.github/workflows/build_on_schedule.yml  (+1 -1)

@@ -55,7 +55,7 @@ jobs:
         if: steps.check-avai.outputs.avai == 'true'
         run: |
           [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
-          BUILD_EXT=1 pip install -v -e .
+          BUILD_EXT=1 pip install -v .
           cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
           pip install --no-cache-dir -r requirements/requirements-test.txt

colossalai/amp/naive_amp/mixed_precision_mixin/base.py  (+1 -1)

@@ -43,7 +43,7 @@ def zero_grad(self):
     dtype: torch.dtype
 
     @abstractmethod
-    def pre_backward(self, loss: Tensor) -> Tensor:
+    def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
         """Called before backward.
 
         Args:

colossalai/amp/naive_amp/mixed_precision_optimizer.py  (+9 -4)

@@ -85,13 +85,18 @@ def __init__(
             master_params.append(master_p)
             group["params"] = master_params
 
-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         loss = self.mixed_precision.pre_backward(loss)
-        loss.backward(*args, **kwargs)
+        loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
 
-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
-        tensor.backward(grad)
+        torch.autograd.backward(
+            tensors=tensor,
+            grad_tensors=grad,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )
 
     def zero_grad(self, *args, **kwargs):
         for p in self.working_to_master_map.keys():

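The inputs/retain_graph plumbing added above is what lets a zero bubble scheduler split one backward pass into a B step (activation gradients) and a W step (weight gradients) over the same graph. A minimal sketch of that mechanism with plain torch.autograd.backward; the tensors below are stand-ins for illustration, not objects from this commit:

import torch

x = torch.randn(4, 8, requires_grad=True)   # stand-in for the stage input activation
w = torch.nn.Parameter(torch.randn(8, 8))   # stand-in for a stage weight
out = x @ w
grad_out = torch.ones_like(out)             # gradient received from the next stage

# B step: accumulate gradients only into the input and keep the graph for the W step.
torch.autograd.backward(tensors=out, grad_tensors=grad_out, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# W step: now accumulate the weight gradient; the graph may be freed afterwards.
torch.autograd.backward(tensors=out, grad_tensors=grad_out, inputs=[w], retain_graph=False)
assert w.grad is not None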
colossalai/booster/mixed_precision/fp16_torch.py  (+2 -2)

@@ -46,9 +46,9 @@ def __init__(
             growth_interval=growth_interval,
         )
 
-    def backward(self, loss: Tensor, *args, **kwargs) -> None:
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
         scaled_loss = self.scale_loss(loss)
-        scaled_loss.backward(*args, **kwargs)
+        scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
 
     def step(self, *args, **kwargs) -> Optional[float]:
         out = self.scaler.step(self.optim, *args, **kwargs)

colossalai/booster/plugin/hybrid_parallel_plugin.py  (+42 -21)

@@ -28,7 +28,7 @@
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
-from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -288,7 +288,7 @@ def __init__(
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         super().__init__(optim)
 
-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         r"""
         Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
 
@@ -306,7 +306,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
         """
 
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
 
         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -315,7 +315,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
             # If gradient synchronization is is not required, return.
             return
 
-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
         Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
 
@@ -332,7 +332,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
         """
 
         # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
 
         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -512,7 +512,7 @@ def __init__(
             max_norm=max_norm,
         )
 
-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         r"""
         Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
 
@@ -529,7 +529,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
            None
        """
        # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
 
        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -538,7 +538,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
            # If gradient synchronization is is not required, return.
            return
 
-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
 
@@ -554,7 +554,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
            None
        """
        # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
 
        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -768,7 +768,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
            else:
                return
 
-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, inputs=None, retain_graph=False):
        """
        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
 
@@ -784,7 +784,7 @@ def backward(self, loss, retain_graph=False):
            None
        """
        # Call the superclass backward method to compute gradients.
-        super().backward(loss, retain_graph)
+        super().backward(loss, inputs=inputs, retain_graph=retain_graph)
 
        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -793,7 +793,7 @@ def backward(self, loss, retain_graph=False):
            # If gradient synchronization is is not required, return.
            return
 
-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
 
@@ -809,7 +809,7 @@ def backward_by_grad(self, tensor, grad):
            None
        """
        # Call the superclass backward_by_grad method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
 
        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -1013,6 +1013,7 @@ def __init__(
        custom_policy: Policy = None,
        pp_style: str = "1f1b",
        num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
        num_layers_per_stage: Optional[List[int]] = None,
        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
        enable_metadata_cache: bool = True,
@@ -1029,6 +1030,9 @@ def __init__(
            dist.get_world_size() % (tp_size * pp_size) == 0
        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
 
+        assert (
+            not pp_style == "zbv" or scheduler_nodes is not None
+        ), f"scheduler_nodes must not be None when using zero bubble pipeline."
        if enable_sequence_parallelism:
            self.sequence_parallelism_mode = (
                sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
@@ -1088,29 +1092,39 @@ def __init__(
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
 
        self.stage_manager = None
-        self.schedule = None
+        self.scheduler = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+            ), "num_model_chunks must be 1 when using 1f1b"
+            assert (
+                pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+            ), "num_model_chunks must be 2 when using zero bubble pipeline"
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
            assert (
                self.zero_stage <= 1
            ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
+            if pp_style == "zbv":
+                self.logger.warning(
+                    """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
+                )
            self.stage_manager = PipelineStageManager(
                self.pg_mesh,
                pipeline_axis=self.pp_axis,
-                enable_interleave=(pp_style == "interleaved"),
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
+                use_zbv=(pp_style == "zbv"),
                num_model_chunks=num_model_chunks,
                num_layers_per_stage=num_layers_per_stage,
            )
 
            if pp_style == "interleaved":
                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
-                self.schedule = InterleavedSchedule(
+                self.scheduler = InterleavedSchedule(
                    stage_manager=self.stage_manager,
                    num_model_chunks=num_model_chunks,
                    num_microbatch=num_microbatches,
@@ -1119,12 +1133,20 @@ def __init__(
                    overlap_p2p=overlap_p2p,
                )
            elif pp_style == "1f1b":
-                self.schedule = OneForwardOneBackwardSchedule(
+                self.scheduler = OneForwardOneBackwardSchedule(
                    stage_manager=self.stage_manager,
                    num_microbatches=num_microbatches,
                    microbatch_size=microbatch_size,
                    enable_metadata_cache=enable_metadata_cache,
                )
+            elif pp_style == "zbv":
+                self.scheduler = ZeroBubbleVPipeScheduler(
+                    stage_manager=self.stage_manager,
+                    schedule=scheduler_nodes,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    microbatch_size=microbatch_size,
+                )
            else:
                raise NotImplementedError()
            if sequence_parallelism_mode == "ring_attn":
@@ -1236,7 +1258,6 @@ def configure(
 
        # Replace with distributed implementation if exists
        optimizer = cast_to_distributed(optimizer)
-
        if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
            self.logger.warning(
                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
@@ -1352,7 +1373,7 @@ def execute_pipeline(
        ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
 
        with ctx, model._wait_all_gather():
-            outputs = self.schedule.forward_backward_step(
+            outputs = self.scheduler.forward_backward_step(
                model, data_iter, criterion, optimizer, return_loss, return_outputs
            )

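Taken together, the plugin changes above mean a zero bubble run is configured roughly as sketched below. This is a hedged example, not code from the commit: how scheduler_nodes is produced (the v-schedule machinery from #6034) is outside this diff, so it appears only as a placeholder, and model/dataset setup is elided.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

scheduler_nodes = build_zbv_scheduler_nodes()  # placeholder: precomputed zero bubble schedule nodes (e.g. from the #6034 v-schedule graph)

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,
    pp_style="zbv",                   # new style added by this commit
    num_model_chunks=2,               # zbv requires exactly 2 chunks (the V shape)
    scheduler_nodes=scheduler_nodes,  # must not be None when pp_style == "zbv"
    num_microbatches=8,
    zero_stage=1,                     # pipeline parallelism still requires zero_stage <= 1
    precision="bf16",
)
booster = Booster(plugin=plugin)

# Per the warning added for zbv, reentrant activation checkpointing must be disabled:
# model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})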
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py  (+3 -3)

@@ -280,7 +280,7 @@ def __init__(
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
 
        self.stage_manager = None
-        self.schedule = None
+        self.scheduler = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
@@ -304,7 +304,7 @@
 
            if pp_style == "interleaved":
                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
-                self.schedule = InterleavedSchedule(
+                self.scheduler = InterleavedSchedule(
                    stage_manager=self.stage_manager,
                    num_model_chunks=num_model_chunks,
                    num_microbatch=num_microbatches,
@@ -313,7 +313,7 @@
                    overlap_p2p=overlap_p2p,
                )
            elif pp_style == "1f1b":
-                self.schedule = OneForwardOneBackwardSchedule(
+                self.scheduler = OneForwardOneBackwardSchedule(
                    stage_manager=self.stage_manager,
                    num_microbatches=num_microbatches,
                    microbatch_size=microbatch_size,

colossalai/interface/optimizer.py  (+2 -2)

@@ -49,11 +49,11 @@ def zero_grad(self, *args, **kwargs):
        """
        self.optim.zero_grad(*args, **kwargs)
 
-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        """
        Performs a backward pass on the loss.
        """
-        loss.backward(*args, **kwargs)
+        loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
 
    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """

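At the wrapper level, the change is that backward now forwards inputs and retain_graph explicitly instead of hiding them in *args. A hedged sketch of how a split-backward scheduler could drive this interface; the names output_obj, output_obj_grad, input_obj, and model_chunk are illustrative and not taken from ZeroBubbleVPipeScheduler:

# B step: gradients w.r.t. the stage input only; keep the graph alive for the W step.
optimizer.backward_by_grad(tensor=output_obj, grad=output_obj_grad, inputs=input_obj, retain_graph=True)

# W step: gradients w.r.t. this chunk's parameters; the graph can now be released.
optimizer.backward_by_grad(
    tensor=output_obj, grad=output_obj_grad, inputs=list(model_chunk.parameters()), retain_graph=False
)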
colossalai/pipeline/stage_manager.py  (+5 -1)

@@ -136,7 +136,11 @@ def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        if not self.is_interleave or ignore_chunk:
            return self.stage == self.num_stages - 1
        else:
-            return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
+            # use zero bubble pipeline
+            if self.use_zbv:
+                return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1
+            else:
+                return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
 
    @property
    def num_stages(self) -> int:

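In the zbv layout, model chunk 0 runs down the pipeline ranks and chunk 1 folds back up, so the model's final layers end up on rank 0's last chunk; that is why is_last_stage is now true at stage 0 when use_zbv is set. A standalone restatement of the updated predicate, for illustration only:

def is_last_virtual_stage(stage: int, chunk_id: int, num_stages: int, num_chunks: int, use_zbv: bool) -> bool:
    # Mirrors the interleaved branch of PipelineStageManager.is_last_stage after this commit.
    if use_zbv:
        # V schedule: chunk 0 goes stage 0 -> num_stages - 1, chunk 1 returns to stage 0.
        return stage == 0 and chunk_id == num_chunks - 1
    return stage == num_stages - 1 and chunk_id == num_chunks - 1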
colossalai/shardformer/policies/llama.py  (+9 -3)

@@ -261,7 +261,9 @@ def get_held_layers(self) -> List[Module]:
                held_layers.append(module.embed_tokens)
            for start_idx, end_idx in stage_indices:
                held_layers.extend(module.layers[start_idx:end_idx])
-            if stage_manager.is_last_stage(ignore_chunk=True):
+            if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+                held_layers.append(module.norm)
+            elif stage_manager.is_last_stage(ignore_chunk=True):
                held_layers.append(module.norm)
 
        else:
@@ -351,7 +353,9 @@ def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage(ignore_chunk=True):
+        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
+        elif stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.lm_head)
        return held_layers
 
@@ -404,7 +408,9 @@ def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage(ignore_chunk=True):
+        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        elif stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.score)
        return held_layers

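The policy change is mechanical: modules that normally sit on the last pipeline stage (final norm, lm_head, score head) are held by the first rank's last chunk under zbv. As a purely illustrative placement, assuming an even split of a 32-layer LlamaForCausalLM over pp_size=4 with num_model_chunks=2 (the real layout depends on num_layers_per_stage):

# Hypothetical (pipeline_rank, chunk_id) -> held modules mapping under zbv; not produced by any code in this commit.
zbv_layout = {
    (0, 0): ["embed_tokens", "layers[0:4]"],
    (1, 0): ["layers[4:8]"],
    (2, 0): ["layers[8:12]"],
    (3, 0): ["layers[12:16]"],
    (3, 1): ["layers[16:20]"],
    (2, 1): ["layers[20:24]"],
    (1, 1): ["layers[24:28]"],
    (0, 1): ["layers[28:32]", "norm", "lm_head"],
}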
colossalai/zero/gemini/gemini_ddp.py  (+1 -1)

@@ -373,7 +373,7 @@ def backward(self, loss: torch.Tensor):
        loss.backward()
        self._post_backward()
 
-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False):
        raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
 
    @staticmethod

colossalai/zero/gemini/gemini_optimizer.py  (+4 -2)

@@ -298,12 +298,14 @@ def backward(self, loss: torch.Tensor):
        loss = self.mix_precision_mixin.pre_backward(loss)
        self.module.backward(loss)
 
-    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
+    def backward_by_grad(
+        self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False
+    ):
        # This function is called except the last stage of pipeline parallel
        # It receives the scaled grad from the previous rank
        # No need to scale the grad again
        # Need to unscale when optimizing
-        grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
+        grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
        self.module.backward_by_grad(tensor, grad)
 
    def _maybe_move_fp32_params(self):
