2 files changed, +6 -3 lines changed

@@ -9,7 +9,7 @@ class WordEmbedding(nn.Module):

     def __init__(self, args, vocab_size):
         super(WordEmbedding, self).__init__()
-        if args.tensor_model_parallel_size > 1:
+        if hasattr(args, "tensor_model_parallel_size") and args.tensor_model_parallel_size > 1:
             self.embedding = mpu.VocabParallelEmbedding(vocab_size, args.emb_size)
         else:
             self.embedding = nn.Embedding(vocab_size, args.emb_size)
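The hasattr guard matters because older configs never define tensor_model_parallel_size; with the previous code, constructing WordEmbedding from such a config raised AttributeError before the fallback branch was ever reached. A minimal sketch of the guarded fallback, assuming a SimpleNamespace config (the stubbed parallel branch and the sample sizes are illustrative, not from the repository):

from types import SimpleNamespace

import torch.nn as nn

class WordEmbedding(nn.Module):
    def __init__(self, args, vocab_size):
        super(WordEmbedding, self).__init__()
        # Fall back to a plain nn.Embedding whenever the config carries no
        # tensor_model_parallel_size attribute at all, or when it is 1.
        if hasattr(args, "tensor_model_parallel_size") and args.tensor_model_parallel_size > 1:
            raise NotImplementedError("mpu.VocabParallelEmbedding lives in the real code")
        else:
            self.embedding = nn.Embedding(vocab_size, args.emb_size)

legacy_args = SimpleNamespace(emb_size=768)  # a config that predates the tensor-parallel option
emb = WordEmbedding(legacy_args, vocab_size=21128)
print(type(emb.embedding).__name__)  # Embedding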
@@ -20,15 +20,18 @@ def __init__(self, args):
         self.relative_position_embedding = args.relative_position_embedding
         self.rotary_position_embedding = args.rotary_position_embedding
         self.has_residual_attention = args.has_residual_attention
-        self.tensor_model_parallel_size = args.tensor_model_parallel_size
+        if hasattr(args, "tensor_model_parallel_size"):
+            self.tensor_model_parallel_size = args.tensor_model_parallel_size
+        else:
+            self.tensor_model_parallel_size = 1

         if self.relative_position_embedding:
             args.relative_pos_emb = RelativePositionEmbedding(bidirectional=True, heads_num=args.heads_num,
                                                               num_buckets=args.relative_attention_buckets_num)
         elif self.rotary_position_embedding:
             args.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)

-        if "deepspeed_checkpoint_activations" in args:
+        if hasattr(args, "deepspeed_checkpoint_activations"):
             self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
             self.deepspeed_checkpoint_layers_num = args.deepspeed_checkpoint_layers_num
         else:
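Two side notes on the new guards, offered as suggestions rather than what the PR ships. The membership test "deepspeed_checkpoint_activations" in args only works when the config type implements __contains__ (argparse.Namespace does, a plain object or types.SimpleNamespace does not), so hasattr is the more portable check. And the four-line hasattr/else block for tensor_model_parallel_size collapses to a single getattr with a default:

from types import SimpleNamespace

def tensor_parallel_size(args):
    # Same behavior as the hasattr/else block above: read the configured
    # value when present, otherwise assume a single model partition.
    return getattr(args, "tensor_model_parallel_size", 1)

print(tensor_parallel_size(SimpleNamespace()))                               # 1
print(tensor_parallel_size(SimpleNamespace(tensor_model_parallel_size=8)))  # 8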