from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
-from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
@@ -296,7 +296,7 @@ def __init__(
        self._current_grad_norm: Optional[float] = None
        super().__init__(optim)

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        r"""
        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -315,7 +315,7 @@ def backward(self, loss: Tensor, *args, **kwargs):

        # Call the superclass backward method to compute gradients.
        with self.model._hook_context():
-            super().backward(loss, *args, **kwargs)
+            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -324,7 +324,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
            # If gradient synchronization is not required, return.
            return

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -341,7 +341,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        """

        # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.
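Across these hunks the pattern is the same: the variadic `*args` is replaced by explicit `inputs` and `retain_graph` keywords that are forwarded to the underlying backward call, presumably so a pipeline schedule can restrict which tensors accumulate gradients and keep the graph alive between backward passes. A minimal, hypothetical sketch of that forwarding pattern (not the actual ColossalAI wrapper; the class name and the direct `torch.autograd.backward` call are assumptions for illustration):

```python
# Hypothetical minimal sketch -- not the real OptimizerWrapper -- showing how the
# new explicit keywords map onto torch.autograd.backward.
from typing import Optional, Sequence

import torch
from torch import Tensor


class BackwardForwardingSketch:
    def backward(self, loss: Tensor, inputs: Optional[Sequence[Tensor]] = None, retain_graph: bool = False, **kwargs):
        # `inputs` limits gradient accumulation to the given tensors;
        # `retain_graph` keeps the autograd graph alive so a later backward
        # pass can reuse it.
        torch.autograd.backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Optional[Sequence[Tensor]] = None, retain_graph: bool = False):
        # Same idea, but starting from a precomputed upstream gradient.
        torch.autograd.backward(tensor, grad_tensors=grad, inputs=inputs, retain_graph=retain_graph)
```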
@@ -525,7 +525,7 @@ def __init__(
            max_norm=max_norm,
        )

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        r"""
        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -543,7 +543,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
        """
        # Call the superclass backward method to compute gradients.
        with self.model._hook_context():
-            super().backward(loss, *args, **kwargs)
+            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -552,7 +552,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
            # If gradient synchronization is not required, return.
            return

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -568,7 +568,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
            None
        """
        # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.model.require_grad_sync:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -785,7 +785,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
        else:
            return

-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, inputs=None, retain_graph=False):
        """
        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -801,7 +801,7 @@ def backward(self, loss, retain_graph=False):
            None
        """
        # Call the superclass backward method to compute gradients.
-        super().backward(loss, retain_graph)
+        super().backward(loss, inputs=inputs, retain_graph=retain_graph)

        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -810,7 +810,7 @@ def backward(self, loss, retain_graph=False):
            # If gradient synchronization is not required, return.
            return

-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -826,7 +826,7 @@ def backward_by_grad(self, tensor, grad):
            None
        """
        # Call the superclass backward_by_grad method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
            # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -1030,6 +1030,7 @@ def __init__(
        custom_policy: Policy = None,
        pp_style: str = "1f1b",
        num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
        num_layers_per_stage: Optional[List[int]] = None,
        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
        enable_metadata_cache: bool = True,
@@ -1048,6 +1049,9 @@ def __init__(
            dist.get_world_size() % (tp_size * pp_size) == 0
        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

+        assert (
+            not pp_style == "zbv" or scheduler_nodes is not None
+        ), "scheduler_nodes must not be None when using the zero bubble pipeline."
        if enable_sequence_parallelism:
            self.sequence_parallelism_mode = (
                sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
@@ -1109,29 +1113,39 @@ def __init__(
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

        self.stage_manager = None
-        self.schedule = None
+        self.scheduler = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+            ), "num_model_chunks must be 1 when using 1f1b"
+            assert (
+                pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+            ), "num_model_chunks must be 2 when using the zero bubble pipeline"
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
            assert (
                self.zero_stage <= 1
            ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
+            if pp_style == "zbv":
+                self.logger.warning(
+                    """When enabling gradient checkpointing with the zbv pipeline style, use_reentrant must be set to False, e.g. model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})"""
+                )
            self.stage_manager = PipelineStageManager(
                self.pg_mesh,
                pipeline_axis=self.pp_axis,
-                enable_interleave=pp_style == "interleaved",
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
+                use_zbv=(pp_style == "zbv"),
                num_model_chunks=num_model_chunks,
                num_layers_per_stage=num_layers_per_stage,
            )

            if pp_style == "interleaved":
                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
-                self.schedule = InterleavedSchedule(
+                self.scheduler = InterleavedSchedule(
                    stage_manager=self.stage_manager,
                    num_model_chunks=num_model_chunks,
                    num_microbatch=num_microbatches,
@@ -1141,13 +1155,21 @@ def __init__(
                    fp8_communication=fp8_communication,
                )
            elif pp_style == "1f1b":
-                self.schedule = OneForwardOneBackwardSchedule(
+                self.scheduler = OneForwardOneBackwardSchedule(
                    stage_manager=self.stage_manager,
                    num_microbatches=num_microbatches,
                    microbatch_size=microbatch_size,
                    enable_metadata_cache=enable_metadata_cache,
                    fp8_communication=fp8_communication,
                )
+            elif pp_style == "zbv":
+                self.scheduler = ZeroBubbleVPipeScheduler(
+                    stage_manager=self.stage_manager,
+                    schedule=scheduler_nodes,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    microbatch_size=microbatch_size,
+                )
            else:
                raise NotImplementedError()
        if sequence_parallelism_mode == "ring_attn":
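A hedged usage sketch of the new `pp_style="zbv"` path under the constraints asserted above (two model chunks per stage, `scheduler_nodes` required). This diff does not show how the schedule nodes are produced, so `build_zbv_schedule_nodes` below is a placeholder for whatever planner output `ZeroBubbleVPipeScheduler` expects:

```python
# Hypothetical usage sketch of the API added in this diff; the schedule-node
# planner is a placeholder, not a real ColossalAI call.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

scheduler_nodes = build_zbv_schedule_nodes()  # placeholder: ZBV schedule produced ahead of time

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    pp_style="zbv",                   # new style accepted after this change
    num_model_chunks=2,               # zbv asserts exactly two chunks per stage
    num_microbatches=4,
    scheduler_nodes=scheduler_nodes,  # must not be None for zbv
    zero_stage=1,
)
booster = Booster(plugin=plugin)
```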
@@ -1263,7 +1285,6 @@ def configure(

        # Replace with distributed implementation if exists
        optimizer = cast_to_distributed(optimizer)
-
        if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
            self.logger.warning(
                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
@@ -1278,6 +1299,7 @@ def configure(
                self.dp_size == 1 and self.pp_size == 1
            )
            # sync gradients across DP * SP ranks
+            # sync gradients across DP * SP ranks
            # Apply Hybrid ZeRO across DP * SP ranks
            if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
@@ -1380,7 +1402,7 @@ def execute_pipeline(
        ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

        with ctx, model._hook_context():
-            outputs = self.schedule.forward_backward_step(
+            outputs = self.scheduler.forward_backward_step(
                model, data_iter, criterion, optimizer, return_loss, return_outputs
            )

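Since `self.schedule` is renamed to `self.scheduler` throughout the plugin, any external code that reached into the plugin for the pipeline schedule has to follow the rename. A small illustrative sketch (only the attribute rename comes from this diff; the plugin arguments are assumed):

```python
# Illustrative only: the attribute rename is the point; the constructor
# arguments here are assumptions, not taken from this diff.
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=1, pp_size=2, pp_style="1f1b", num_microbatches=4)

# Before this change:  pipeline_schedule = plugin.schedule
# After this change:
pipeline_schedule = plugin.scheduler  # OneForwardOneBackwardSchedule, InterleavedSchedule, or ZeroBubbleVPipeScheduler
```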