Skip to content

Commit 93f0fc8

Browse files
committed
fix typo
Signed-off-by: c8ef <[email protected]>
1 parent 29e9fd5 commit 93f0fc8

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

deepspeed/sequence/layer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -338,11 +338,11 @@ def __init__(
338338
if sp_stream is not None:
339339
self.overlap_handles = {}
340340
self.sp_overlap_comm = True
341-
self.dafult_stream = get_accelerator().default_stream()
341+
self.default_stream = get_accelerator().default_stream()
342342

343343
def layer_sync(self, layer):
344344
if self.sp_overlap_comm and hasattr(layer, 'done_event'):
345-
self.dafult_stream.wait_event(layer.done_event)
345+
self.default_stream.wait_event(layer.done_event)
346346

347347
def forward(self,
348348
query: Tensor,
@@ -374,7 +374,7 @@ def bwd_hook(layer_type):
374374
def pre_hook_fun(grad):
375375
type = 'd' + layer_type
376376
self.overlap_handles[type + '_work'].wait()
377-
self.sp_stream.wait_stream(self.dafult_stream)
377+
self.sp_stream.wait_stream(self.default_stream)
378378
all2all_output = self.overlap_handles[type + '_grad']
379379
grad = list(grad)
380380
grad[0] = self.overlap_handles[type + '_post_all2all_func'](all2all_output)
@@ -389,7 +389,7 @@ def pre_hook_fun(grad):
389389
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
390390
self.overlap_handles, 'k')
391391
if self.sp_overlap_comm:
392-
self.dafult_stream.wait_stream(self.sp_stream)
392+
self.default_stream.wait_stream(self.sp_stream)
393393

394394
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
395395
self.overlap_handles, 'v')

0 commit comments

Comments
 (0)