
Commit 27547a4

wmpscckaeli and kaeli authored
Fixed some bugs regarding activation checkpoints and updated the BPE vocabulary loader (#125)
* Add token counter, update BPE vocab init
* Add special token security check
* update no_decay list
* [Fix] Fixed the impact of passing parameters on activation checkpointing.
* update
* update
* update

---------

Co-authored-by: kaeli <[email protected]>
1 parent dc155e4 commit 27547a4

9 files changed, +93 -56 lines

tencentpretrain/encoders/transformer_encoder.py

+3 -3

@@ -130,14 +130,14 @@ def custom_forward(*inputs):
             mpu.reset_checkpointed_activations_memory_buffer()
             l = 0
             while l < self.layers_num:
-                inputs = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num), inputs)
+                inputs = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num), *inputs)
                 l += self.deepspeed_checkpoint_layers_num
         else:
             for i in range(self.layers_num):
                 if self.parameter_sharing:
-                    inputs = self.transformer(inputs)
+                    inputs = self.transformer(*inputs)
                 else:
-                    inputs = self.transformer[i](inputs)
+                    inputs = self.transformer[i](*inputs)

         hidden = inputs[0]

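For context, here is a minimal sketch of why the call sites above unpack `inputs`. It uses `torch.utils.checkpoint` rather than the DeepSpeed `checkpointing` wrapper, but both take the layer inputs as separate positional arguments; when the whole tuple is passed as a single argument, the checkpoint function does not see the individual tensors it needs to detach and replay, which is the activation-checkpointing issue this commit addresses.

import torch
from torch.utils.checkpoint import checkpoint

def layer(hidden, mask):
    # Stand-in for one transformer layer; any differentiable op works for the demo.
    return hidden * 2 + mask.sum(), mask

hidden = torch.randn(2, 4, requires_grad=True)
mask = torch.zeros(2, 4)

# Passing the tensors as separate positional arguments (the equivalent of *inputs)
# lets checkpoint detach and later recompute them, so gradients reach `hidden`.
out, _ = checkpoint(layer, hidden, mask, use_reentrant=False)
out.sum().backward()
assert hidden.grad is not None
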
tencentpretrain/layers/multi_headed_attn.py

+2 -1

@@ -136,7 +136,8 @@ def unshape(x):
             scores += prev_attn
         prev_attn_out = scores

-        probs = nn.Softmax(dim=-1)(scores)
+        # probs = nn.Softmax(dim=-1)(scores)
+        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
         probs = self.dropout(probs)
         output = unshape(torch.matmul(probs, value))
         output = self.final_linear(output)

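The replacement computes the softmax in float32 and casts the probabilities back to the query's dtype, a common safeguard against reduced-precision trouble when many positions are masked with large negative scores. A standalone sketch of the same pattern (the values are illustrative, not from the repository):

import torch

query = torch.randn(1, 4, dtype=torch.float16)
scores = torch.tensor([[8.0, -10000.0, 9.0, -10000.0]], dtype=torch.float16)  # -10000.0 marks masked positions

# Run the softmax reduction in float32, then cast back to the dtype used by the rest of the layer.
probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
print(probs)  # masked positions get ~0 probability; the exponentials and sum were taken in full precision
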
tencentpretrain/layers/transformer.py

+10 -9

@@ -3,6 +3,7 @@
 from tencentpretrain.layers.multi_headed_attn import MultiHeadedAttention, ParallelMultiHeadedAttention
 from tencentpretrain.layers import *

+
 class TransformerLayer(nn.Module):
     """
     Transformer layer mainly consists of two parts:

@@ -40,7 +41,7 @@ def __init__(self, args, layer_number=None):

         self.self_attn = MultiHeadedAttention(
             args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias,
-            with_scale = with_scale, lora_params=lora_params, layer_number=layer_number
+            with_scale=with_scale, lora_params=lora_params, layer_number=layer_number
         )
         self.dropout_1 = nn.Dropout(args.dropout)

@@ -53,7 +54,7 @@ def __init__(self, args, layer_number=None):
         self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
         self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)

-    def forward(self, inputs):
+    def forward(self, *inputs):

         """
         Args:

@@ -63,7 +64,7 @@ def forward(self, inputs):
         Returns:
             output: [batch_size x seq_length x hidden_size]
         """
-        if len(inputs)==2:
+        if len(inputs) == 2:
             hidden, mask = inputs
             prev_attn = None
         else:

@@ -136,7 +137,7 @@ def __init__(self, args, layer_number=None):

         self.self_attn = ParallelMultiHeadedAttention(
             args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias,
-            with_scale = with_scale, lora_params=lora_params, layer_number=layer_number
+            with_scale=with_scale, lora_params=lora_params, layer_number=layer_number
         )
         self.dropout_1 = nn.Dropout(args.dropout)

@@ -150,7 +151,7 @@ def __init__(self, args, layer_number=None):
         self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
         self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)

-    def forward(self, inputs):
+    def forward(self, *inputs):

         """
         Args:

@@ -161,7 +162,7 @@ def forward(self, inputs):
             output: [batch_size x seq_length x hidden_size]
         """

-        if len(inputs)==2:
+        if len(inputs) == 2:
             hidden, mask = inputs
             prev_attn = None
         else:

@@ -220,7 +221,7 @@ def generate_mask(self, seq_length, batch_size, device):
         mask = mask.repeat(batch_size, 1, 1, 1)
         return mask

-    def forward(self, inputs):
+    def forward(self, *inputs):

         """
         Args:

@@ -231,15 +232,15 @@ def forward(self, inputs):
             output: [batch_size x seq_length x hidden_size]
         """

-        if len(inputs)==2:
+        if len(inputs) == 2:
             hidden, seg = inputs
             prev_attn = None
         else:
             hidden, seg, prev_attn = inputs
         batch_size, seq_length, _ = hidden.size()
         mask = self.generate_mask(seq_length, batch_size, hidden.device)
         layer_inputs = hidden, mask, prev_attn
-        outputs = self.layer(layer_inputs)
+        outputs = self.layer(*layer_inputs)

         if self.has_residual_attention:
             hidden, mask, prev_attn_out = outputs

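All three `forward` methods now take `*inputs` instead of a single tuple, so a layer can be called with unpacked tensors both directly and through the checkpointing wrappers. A toy illustration of the calling convention (this is not the real `TransformerLayer`):

import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    def forward(self, *inputs):
        # Two positional arguments: no residual-attention state yet; three: carry it through.
        if len(inputs) == 2:
            hidden, mask = inputs
            prev_attn = None
        else:
            hidden, mask, prev_attn = inputs
        return hidden, mask, prev_attn

layer = TinyLayer()
hidden, mask = torch.randn(2, 3, 8), torch.zeros(2, 1, 3, 3)
outputs = layer(hidden, mask)   # called with unpacked tensors, as self.transformer[i](*inputs) now does
outputs = layer(*outputs)       # the returned 3-tuple feeds the next layer the same way
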
tencentpretrain/trainer.py

+10 -11

@@ -109,7 +109,7 @@ def init_optimizer(args, model_for_training):
             if 'lora' not in n:
                 p.requires_grad = False
     else:
-        no_decay = ["bias", "gamma", "beta"]
+        no_decay = ["bias", "gamma", "beta", "layer_norm.weight", "layer_norm_1.weight", "layer_norm_2.weight"]
         optimizer_grouped_parameters = [
             {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
             {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}

@@ -695,23 +695,22 @@ def worker(local_rank, gpu_ranks, args):
     if args.pipeline_model_parallel_size > 1:
         from deepspeed.pipe import PipelineModule, TiedLayerSpec, LayerSpec
         def get_model(model, args):
-            layers = [LayerSpec(EmbeddingPipe, args,model=model),
-                      *[LayerSpec(ParallelTransformerLayerPipe, args,model=model, layer_idx=idx) for idx in
+            layers = [LayerSpec(EmbeddingPipe, args, model=model),
+                      *[LayerSpec(ParallelTransformerLayerPipe, args, model=model, layer_idx=idx) for idx in
                         range(args.layers_num)],
-                      LayerSpec(TargetPipe, args=args,model=model)
-                      ]
+                      LayerSpec(TargetPipe, args=args, model=model)]
             return layers
-        layers = get_model(model_for_training,args)
+        layers = get_model(model_for_training, args)
         from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
         topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(),
                                              num_mp=mpu.get_tensor_model_parallel_world_size(),
                                              num_dp=mpu.get_data_parallel_world_size())

-        model_for_training=PipelineModule(layers=layers,
-                          num_stages=args.pipeline_model_parallel_size,
-                          activation_checkpoint_interval=args.deepspeed_checkpoint_layers_num,
-                          loss_fn=CrossEntropy,
-                          checkpointable_layers=['ParallelTransformerLayerPipe'], topology=topo)
+        model_for_training = PipelineModule(layers=layers,
+                                            num_stages=args.pipeline_model_parallel_size,
+                                            activation_checkpoint_interval=args.deepspeed_checkpoint_layers_num,
+                                            loss_fn=CrossEntropy,
+                                            checkpointable_layers=['ParallelTransformerLayerPipe'], topology=topo)

         # Build optimizer.
         custom_optimizer, custom_scheduler, optimizer_grouped_parameters = init_optimizer(args, model_for_training)

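The widened `no_decay` list keeps LayerNorm weights out of weight decay in addition to biases; the new substrings match parameters registered under names like `layer_norm_1`. A self-contained sketch of the grouping pattern (the toy module below is illustrative, not the project's model):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(8, 8)
        self.layer_norm_1 = nn.LayerNorm(8)

model = ToyBlock()
param_optimizer = list(model.named_parameters())

no_decay = ["bias", "gamma", "beta", "layer_norm.weight", "layer_norm_1.weight", "layer_norm_2.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
# linear_1.weight decays; linear_1.bias, layer_norm_1.weight and layer_norm_1.bias do not.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-4)
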
tencentpretrain/utils/__init__.py

+2 -1

@@ -19,7 +19,8 @@
                   "t5": T5Dataloader, "gsg": GsgDataloader, "bart": BartDataloader,
                   "cls": ClsDataloader, "prefixlm": PrefixlmDataloader, "cls_mlm": ClsMlmDataloader,
                   "vit": VitDataloader, "vilt": ViltDataloader, "clip": ClipDataloader, "s2t": S2tDataloader,
-                  "beit":BeitDataloader, "dalle": DalleDataloader, "llm_sft": LlmSftDataloader}
+                  "beit":BeitDataloader, "dalle": DalleDataloader, "llm_sft": LlmSftDataloader,
+                  "llm_pretrain": LlmPretrainDataloader}

 str2act = {"gelu": gelu, "gelu_fast": gelu_fast, "relu": relu, "silu": silu, "linear": linear}

tencentpretrain/utils/constants.py

+8 -10

@@ -4,13 +4,11 @@
 with open("models/special_tokens_map.json", mode="r", encoding="utf-8") as f:
     special_tokens_map = json.load(f)

-UNK_TOKEN = special_tokens_map["unk_token"]
-CLS_TOKEN = special_tokens_map["cls_token"]
-SEP_TOKEN = special_tokens_map["sep_token"]
-MASK_TOKEN = special_tokens_map["mask_token"]
-PAD_TOKEN = special_tokens_map["pad_token"]
-try:
-    # e.g. <extra_id_0>, <extra_id_1>, ... , should have consecutive IDs.
-    SENTINEL_TOKEN = special_tokens_map["sentinel_token"]
-except KeyError:
-    pass
+UNK_TOKEN = special_tokens_map.get("unk_token")
+CLS_TOKEN = special_tokens_map.get("cls_token")
+SEP_TOKEN = special_tokens_map.get("sep_token")
+MASK_TOKEN = special_tokens_map.get("mask_token")
+PAD_TOKEN = special_tokens_map.get("pad_token")
+
+# e.g. <extra_id_0>, <extra_id_1>, ... , should have consecutive IDs.
+SENTINEL_TOKEN = special_tokens_map.get("sentinel_token")

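With `dict.get`, a token that is missing from `models/special_tokens_map.json` now resolves to `None` instead of raising `KeyError` at import time (the old `try/except` only covered the sentinel token). In plain terms:

special_tokens_map = {"unk_token": "<unk>", "pad_token": "<pad>"}

# special_tokens_map["cls_token"]                 # would raise KeyError and abort the import
CLS_TOKEN = special_tokens_map.get("cls_token")   # None; callers must tolerate a missing token
PAD_TOKEN = special_tokens_map.get("pad_token")   # "<pad>"
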
tencentpretrain/utils/dataloader.py

+4 -1

@@ -754,7 +754,6 @@ class S2tDataloader(AudioDataloader):
     def __iter__(self):
         import torchaudio
         import torchaudio.compliance.kaldi as ta_kaldi
-
         padding_vector = torch.FloatTensor(self.audio_feature_size * [self.padding_value] if self.audio_feature_size > 1 else self.padding_value).unsqueeze(0).cuda(self.local_rank)
         while True:
             while self._empty():

@@ -949,3 +948,7 @@ def __iter__(self):
             yield torch.LongTensor(src), \
                   torch.LongTensor(tgt), \
                   torch.LongTensor(seg)
+
+
+class LlmPretrainDataloader(LmDataloader):
+    pass

tencentpretrain/utils/dataset.py

+40 -14

@@ -17,7 +17,7 @@ def merge_dataset(dataset_path, workers_num):
     for i in range(workers_num):
         tmp_dataset_reader = open("dataset-tmp-" + str(i) + ".pt", "rb")
         while True:
-            tmp_data = tmp_dataset_reader.read(2**20)
+            tmp_data = tmp_dataset_reader.read(2 ** 20)
             if tmp_data:
                 dataset_writer.write(tmp_data)
             else:

@@ -69,13 +69,21 @@ def build_and_save(self, workers_num):
         if workers_num == 1:
             self.worker(0, 0, lines_num)
         else:
+            async_results = []
             pool = Pool(workers_num)
             for i in range(workers_num):
                 start = i * lines_num // workers_num
                 end = (i + 1) * lines_num // workers_num
-                pool.apply_async(func=self.worker, args=[i, start, end])
+                # pool.apply_async(func=self.worker, args=[i, start, end])
+                async_results.append(pool.apply_async(func=self.worker, args=[i, start, end]))
             pool.close()
             pool.join()
+            async_results = [res.get() for res in async_results]
+            if async_results[0] is not None:
+                samples_num = sum([res[0] for res in async_results])
+                tokens_num = sum([res[1] for res in async_results])
+                print("Number of samples:", samples_num)
+                print("Total number of tokens:", tokens_num)

         # Merge datasets.
         merge_dataset(self.dataset_path, workers_num)

@@ -211,7 +219,8 @@ def create_ins_from_doc(self, all_documents, document_index):
             pad_num = self.seq_length - len(src)

             if not self.dynamic_masking:
-                src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
+                                        self.span_geo_prob, self.span_max_length)
                 src = (src, pad_num)
                 instance = (src, tgt_mlm, is_random_next, seg_pos)
             else:

@@ -245,7 +254,8 @@ def worker(self, proc_id, start, end):
                 line = f.readline()
                 pos += 1

-                document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]
+                document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(
+                    self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]

                 if self.full_sentences:
                     if len(document) > 0:

@@ -293,7 +303,8 @@ def build_instances(self, all_documents):
             seg_pos = [len(src)]

             if not self.dynamic_masking:
-                src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob,
+                                    self.span_max_length)
                 instance = ((src, 0), tgt, seg_pos)
             else:
                 instance = ((src, 0), seg_pos)

@@ -308,9 +319,10 @@ def build_instances(self, all_documents):
         seg_pos = [len(src)]

         pad_num = self.seq_length - len(src)
-
+
         if not self.dynamic_masking:
-            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob,
+                                self.span_max_length)
             instance = ((src, pad_num), tgt, seg_pos)
         else:
             instance = ((src, pad_num), seg_pos)

@@ -417,7 +429,8 @@ def create_ins_from_doc(self, document):
             pad_num = self.seq_length - len(src)

             if not self.dynamic_masking:
-                src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
+                                        self.span_geo_prob, self.span_max_length)
                 src = (src, pad_num)
                 instance = (src, tgt_mlm, is_wrong_order, seg_pos)
             else:

@@ -464,7 +477,7 @@ def worker(self, proc_id, start, end):
                     seg_pos = [self.seq_length]
                     src = (src, 0)
                     pickle.dump((src, seg_pos), dataset_writer)
-                buffer = buffer[instances_num * (self.seq_length + 1): ]
+                buffer = buffer[instances_num * (self.seq_length + 1):]

             else:
                 instances_num = len(document) // (self.seq_length + 1)

@@ -486,13 +499,17 @@ def worker(self, proc_id, start, end):

         dataset_writer.close()

+
 class LlmPretrainDataset(Dataset):
     def __init__(self, args, vocab, tokenizer):
         super(LlmPretrainDataset, self).__init__(args, vocab, tokenizer)
         self.full_sentences = args.full_sentences

     def worker(self, proc_id, start, end):
         print("Worker %d is building dataset ... " % proc_id)
+        samples_num = 0
+        tokens_num = 0
+
         set_seed(self.seed)
         dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
         pos = 0

@@ -517,7 +534,7 @@ def worker(self, proc_id, start, end):
                     seg_pos = [self.seq_length]
                     src = (src, 0)
                     pickle.dump((src, seg_pos), dataset_writer)
-                buffer = buffer[instances_num * (self.seq_length + 1): ]
+                buffer = buffer[instances_num * (self.seq_length + 1):]

             else:
                 instances_num = len(document) // (self.seq_length + 1)

@@ -533,7 +550,8 @@ def worker(self, proc_id, start, end):
                     pad_num = self.seq_length + 1 - len(src)
                     src = (src, pad_num)
                     pickle.dump((src, seg_pos), dataset_writer)
-
+                    tokens_num += len(src)
+                    samples_num += 1
             if pos >= end:
                 break

@@ -675,7 +693,8 @@ def create_ins_from_doc(self, all_documents, document_index):

         while i < len(document):
             segment = document[i]
-            if i in mask_seq_list and len(tgt) + len(segment) < target_tgt_seq_length and len(src) + 1 < target_seq_length:
+            if i in mask_seq_list and len(tgt) + len(segment) < target_tgt_seq_length and len(
+                    src) + 1 < target_seq_length:
                 tgt = tgt + segment
                 src = src + [self.vocab.get(MASK_TOKEN)]
             elif i not in mask_seq_list and len(src) + len(segment) < target_seq_length:

@@ -884,7 +903,8 @@ def worker(self, proc_id, start, end):
                 if len(line) == 2:
                     label = int(line[0])
                     text = line[1]
-                    src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
+                    src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(
+                        self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
                     tgt_cls = label
                     seg_pos = [len(src)]
                 elif len(line) == 3: # For sentence pair input.

@@ -920,7 +940,8 @@ def worker(self, proc_id, start, end):

                 if not self.dynamic_masking:
                     src_single, pad_num = src
-                    src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                    src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking,
+                                                   self.span_masking, self.span_geo_prob, self.span_max_length)
                     src = (src_single, pad_num)
                     instance = (src, tgt_mlm, tgt_cls, seg_pos)
                 else:

@@ -1046,6 +1067,8 @@ class DalleDataset(FileWithTextDataset):
 class LlmSftDataset(Dataset):
     def worker(self, proc_id, start, end):
         print("Worker %d is building dataset ... " % proc_id)
+        samples_num = 0
+        tokens_num = 0
         set_seed(self.seed)
         dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
         pos = 0

@@ -1079,7 +1102,10 @@ def worker(self, proc_id, start, end):
                 pad_num = self.seq_length - len(src)

                 pickle.dump(((src, pad_num), seg_pos), dataset_writer)
+                tokens_num += len(src)
+                samples_num += 1
                 if pos >= end:
                     break

         dataset_writer.close()
+        return samples_num, tokens_num

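The sample and token counters only work because `build_and_save` now keeps the `AsyncResult` handles and calls `.get()` on them; `pool.apply_async` otherwise discards worker return values (and silently swallows worker exceptions). A minimal standalone sketch of that pattern, with a dummy worker in place of `Dataset.worker`:

from multiprocessing import Pool

def worker(proc_id, start, end):
    # Dummy stand-in for Dataset.worker: treat every line in [start, end) as one 16-token sample.
    samples_num = end - start
    tokens_num = samples_num * 16
    return samples_num, tokens_num

if __name__ == "__main__":
    workers_num, lines_num = 4, 1000
    async_results = []
    pool = Pool(workers_num)
    for i in range(workers_num):
        start = i * lines_num // workers_num
        end = (i + 1) * lines_num // workers_num
        # Keep the AsyncResult; calling .get() later surfaces the return value and any exception.
        async_results.append(pool.apply_async(func=worker, args=[i, start, end]))
    pool.close()
    pool.join()
    results = [res.get() for res in async_results]
    print("Number of samples:", sum(r[0] for r in results))
    print("Total number of tokens:", sum(r[1] for r in results))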