@@ -338,11 +338,11 @@ def __init__(
338
338
if sp_stream is not None :
339
339
self .overlap_handles = {}
340
340
self .sp_overlap_comm = True
341
- self .dafult_stream = get_accelerator ().default_stream ()
341
+ self .default_stream = get_accelerator ().default_stream ()
342
342
343
343
def layer_sync (self , layer ):
344
344
if self .sp_overlap_comm and hasattr (layer , 'done_event' ):
345
- self .dafult_stream .wait_event (layer .done_event )
345
+ self .default_stream .wait_event (layer .done_event )
346
346
347
347
def forward (self ,
348
348
query : Tensor ,
@@ -374,7 +374,7 @@ def bwd_hook(layer_type):
374
374
def pre_hook_fun (grad ):
375
375
type = 'd' + layer_type
376
376
self .overlap_handles [type + '_work' ].wait ()
377
- self .sp_stream .wait_stream (self .dafult_stream )
377
+ self .sp_stream .wait_stream (self .default_stream )
378
378
all2all_output = self .overlap_handles [type + '_grad' ]
379
379
grad = list (grad )
380
380
grad [0 ] = self .overlap_handles [type + '_post_all2all_func' ](all2all_output )
@@ -389,7 +389,7 @@ def pre_hook_fun(grad):
389
389
key_layer = _SeqAllToAll .apply (self .spg , key , self .scatter_idx , self .gather_idx , batch_dim_idx , None ,
390
390
self .overlap_handles , 'k' )
391
391
if self .sp_overlap_comm :
392
- self .dafult_stream .wait_stream (self .sp_stream )
392
+ self .default_stream .wait_stream (self .sp_stream )
393
393
394
394
value_layer = _SeqAllToAll .apply (self .spg , value , self .scatter_idx , self .gather_idx , batch_dim_idx , None ,
395
395
self .overlap_handles , 'v' )
0 commit comments