
Commit 6feab42

fix bug (#123)
* Update word_embedding.py
* Update transformer_encoder.py
1 parent f5ce4e5 commit 6feab42

2 files changed: +6 −3 lines changed

tencentpretrain/embeddings/word_embedding.py

+1 −1

@@ -9,7 +9,7 @@ class WordEmbedding(nn.Module):
 
     def __init__(self, args, vocab_size):
         super(WordEmbedding, self).__init__()
-        if args.tensor_model_parallel_size > 1:
+        if hasattr(args, "tensor_model_parallel_size") and args.tensor_model_parallel_size > 1:
             self.embedding = mpu.VocabParallelEmbedding(vocab_size, args.emb_size)
         else:
             self.embedding = nn.Embedding(vocab_size, args.emb_size)
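
The guard matters when a run never configures tensor-model parallelism: the args object then has no tensor_model_parallel_size attribute at all, and the old condition would raise AttributeError. Below is a minimal sketch of the patched check, not TencentPretrain's actual config loader; the SimpleNamespace args and the sizes (emb_size=768, vocab_size=21128) are made up for illustration, and the parallel branch is stubbed out because mpu.VocabParallelEmbedding needs an initialized model-parallel group.

from types import SimpleNamespace

import torch.nn as nn

# Hypothetical single-process config: no tensor_model_parallel_size field,
# as happens when Megatron-style model parallelism is not requested.
args = SimpleNamespace(emb_size=768)
vocab_size = 21128

# hasattr() short-circuits the comparison, so the missing attribute no
# longer raises AttributeError and we fall back to a plain embedding table.
if hasattr(args, "tensor_model_parallel_size") and args.tensor_model_parallel_size > 1:
    embedding = None  # would be mpu.VocabParallelEmbedding(vocab_size, args.emb_size)
else:
    embedding = nn.Embedding(vocab_size, args.emb_size)

print(embedding)  # Embedding(21128, 768)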

tencentpretrain/encoders/transformer_encoder.py

+5 −2

@@ -20,15 +20,18 @@ def __init__(self, args):
         self.relative_position_embedding = args.relative_position_embedding
         self.rotary_position_embedding = args.rotary_position_embedding
         self.has_residual_attention = args.has_residual_attention
-        self.tensor_model_parallel_size = args.tensor_model_parallel_size
+        if hasattr(args, "tensor_model_parallel_size"):
+            self.tensor_model_parallel_size = args.tensor_model_parallel_size
+        else:
+            self.tensor_model_parallel_size = 1
 
         if self.relative_position_embedding:
             args.relative_pos_emb = RelativePositionEmbedding(bidirectional=True, heads_num=args.heads_num,
                                                               num_buckets=args.relative_attention_buckets_num)
         elif self.rotary_position_embedding:
             args.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)
 
-        if "deepspeed_checkpoint_activations" in args:
+        if hasattr(args, "deepspeed_checkpoint_activations"):
            self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
            self.deepspeed_checkpoint_layers_num = args.deepspeed_checkpoint_layers_num
         else:
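
Both hunks follow the same pattern: treat the tensor-parallel and DeepSpeed settings as optional and fall back to safe defaults. The second hunk also swaps the membership test ("deepspeed_checkpoint_activations" in args) for hasattr(), which works on any object, whereas the in operator requires args to support containment. A minimal sketch of that difference, again using a SimpleNamespace purely for illustration rather than TencentPretrain's actual args class:

from types import SimpleNamespace

# Hypothetical args object without __contains__ support and without the
# optional DeepSpeed / tensor-parallel fields.
args = SimpleNamespace(hidden_size=768)

# Old-style membership test: TypeError on objects that are not containers.
try:
    "deepspeed_checkpoint_activations" in args
except TypeError as exc:
    print("membership test fails:", exc)

# Patched checks: hasattr() is safe on any object and yields the defaults.
tensor_model_parallel_size = (
    args.tensor_model_parallel_size
    if hasattr(args, "tensor_model_parallel_size")
    else 1
)
deepspeed_checkpoint_activations = hasattr(args, "deepspeed_checkpoint_activations")
print(tensor_model_parallel_size, deepspeed_checkpoint_activations)  # 1 False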
