@@ -17,7 +17,7 @@ def merge_dataset(dataset_path, workers_num):
     for i in range(workers_num):
         tmp_dataset_reader = open("dataset-tmp-" + str(i) + ".pt", "rb")
         while True:
-            tmp_data = tmp_dataset_reader.read(2 ** 20)
+            tmp_data = tmp_dataset_reader.read(2 ** 20)
             if tmp_data:
                 dataset_writer.write(tmp_data)
             else:
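
This hunk only touches whitespace, but it sits in the routine that stitches the final dataset together: each worker writes its shard to dataset-tmp-<i>.pt, and merge_dataset streams those shards into the target file in 1 MiB chunks so the merge never holds a whole shard in memory. A minimal standalone sketch of that pattern follows (the signature mirrors the original function; the os.remove cleanup at the end is an assumption about the surrounding code, which this hunk does not show):

```python
import os

def merge_dataset(dataset_path, workers_num):
    # Concatenate the per-worker shards into a single dataset file.
    with open(dataset_path, "wb") as dataset_writer:
        for i in range(workers_num):
            tmp_path = "dataset-tmp-" + str(i) + ".pt"
            with open(tmp_path, "rb") as tmp_dataset_reader:
                while True:
                    # Stream in 1 MiB chunks to keep memory usage flat.
                    tmp_data = tmp_dataset_reader.read(2 ** 20)
                    if not tmp_data:
                        break
                    dataset_writer.write(tmp_data)
            os.remove(tmp_path)  # assumed cleanup of the temporary shard
```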
@@ -69,13 +69,21 @@ def build_and_save(self, workers_num):
         if workers_num == 1:
             self.worker(0, 0, lines_num)
         else:
+            async_results = []
             pool = Pool(workers_num)
             for i in range(workers_num):
                 start = i * lines_num // workers_num
                 end = (i + 1) * lines_num // workers_num
-                pool.apply_async(func=self.worker, args=[i, start, end])
+                # pool.apply_async(func=self.worker, args=[i, start, end])
+                async_results.append(pool.apply_async(func=self.worker, args=[i, start, end]))
             pool.close()
             pool.join()
+            async_results = [res.get() for res in async_results]
+            if async_results[0] is not None:
+                samples_num = sum([res[0] for res in async_results])
+                tokens_num = sum([res[1] for res in async_results])
+                print("Number of samples:", samples_num)
+                print("Total number of tokens:", tokens_num)

         # Merge datasets.
         merge_dataset(self.dataset_path, workers_num)
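
The new code in build_and_save keeps the AsyncResult handles that pool.apply_async returns and calls .get() on them after close()/join(). That is the standard multiprocessing idiom for collecting worker return values, and it also re-raises any exception that happened inside a worker, which a bare apply_async would otherwise swallow silently. The check on async_results[0] is not None exists because only some dataset workers return (samples_num, tokens_num); the others return None. A self-contained sketch of the same pattern, with a toy worker standing in for self.worker:

```python
from multiprocessing import Pool


def worker(proc_id, start, end):
    # Toy stand-in for Dataset.worker: pretend every line yields 10 tokens.
    samples_num = end - start
    tokens_num = samples_num * 10
    return samples_num, tokens_num


if __name__ == "__main__":
    workers_num, lines_num = 4, 1000
    async_results = []
    pool = Pool(workers_num)
    for i in range(workers_num):
        start = i * lines_num // workers_num
        end = (i + 1) * lines_num // workers_num
        # Keep the handle; .get() later returns the worker's result
        # and propagates any worker-side exception.
        async_results.append(pool.apply_async(func=worker, args=[i, start, end]))
    pool.close()
    pool.join()
    results = [res.get() for res in async_results]
    print("Number of samples:", sum(r[0] for r in results))
    print("Total number of tokens:", sum(r[1] for r in results))
```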
@@ -211,7 +219,8 @@ def create_ins_from_doc(self, all_documents, document_index):
                     pad_num = self.seq_length - len(src)

                     if not self.dynamic_masking:
-                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
+                                                self.span_geo_prob, self.span_max_length)
                         src = (src, pad_num)
                         instance = (src, tgt_mlm, is_random_next, seg_pos)
                     else:
@@ -245,7 +254,8 @@ def worker(self, proc_id, start, end):
                 line = f.readline()
                 pos += 1

-                document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]
+                document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(
+                    self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]

                 if self.full_sentences:
                     if len(document) > 0:
@@ -293,7 +303,8 @@ def build_instances(self, all_documents):
             seg_pos = [len(src)]

             if not self.dynamic_masking:
-                src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob,
+                                    self.span_max_length)
                 instance = ((src, 0), tgt, seg_pos)
             else:
                 instance = ((src, 0), seg_pos)
@@ -308,9 +319,10 @@ def build_instances(self, all_documents):
         seg_pos = [len(src)]

         pad_num = self.seq_length - len(src)
-
+
         if not self.dynamic_masking:
-            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob,
+                                self.span_max_length)
             instance = ((src, pad_num), tgt, seg_pos)
         else:
             instance = ((src, pad_num), seg_pos)
@@ -417,7 +429,8 @@ def create_ins_from_doc(self, document):
                     pad_num = self.seq_length - len(src)

                     if not self.dynamic_masking:
-                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
+                                                self.span_geo_prob, self.span_max_length)
                         src = (src, pad_num)
                         instance = (src, tgt_mlm, is_wrong_order, seg_pos)
                     else:
@@ -464,7 +477,7 @@ def worker(self, proc_id, start, end):
                         seg_pos = [self.seq_length]
                         src = (src, 0)
                         pickle.dump((src, seg_pos), dataset_writer)
-                    buffer = buffer[instances_num * (self.seq_length + 1): ]
+                    buffer = buffer[instances_num * (self.seq_length + 1):]

                 else:
                     instances_num = len(document) // (self.seq_length + 1)
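
For context on the slice this hunk reformats: in the full_sentences branch, tokenized documents are appended to a rolling buffer, complete chunks of seq_length + 1 tokens are pickled as instances (presumably one extra token so that inputs and shifted language-model targets can both be cut from the same chunk), and only the leftover tail is carried over to the next document. A simplified sketch of that buffering, using a hypothetical emit_instances helper over plain lists of token ids:

```python
def emit_instances(buffer, new_tokens, seq_length):
    # Append the new document, cut as many (seq_length + 1)-token chunks
    # as possible, and return them together with the leftover tail.
    buffer = buffer + new_tokens
    instances = []
    instances_num = len(buffer) // (seq_length + 1)
    for i in range(instances_num):
        instances.append(buffer[i * (seq_length + 1): (i + 1) * (seq_length + 1)])
    buffer = buffer[instances_num * (seq_length + 1):]
    return instances, buffer


# With seq_length=4, 11 tokens yield two 5-token instances and a 1-token tail.
chunks, tail = emit_instances([], list(range(11)), seq_length=4)
assert chunks == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] and tail == [10]
```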
@@ -486,13 +499,17 @@ def worker(self, proc_id, start, end):

         dataset_writer.close()

+
 class LlmPretrainDataset(Dataset):
     def __init__(self, args, vocab, tokenizer):
         super(LlmPretrainDataset, self).__init__(args, vocab, tokenizer)
         self.full_sentences = args.full_sentences

     def worker(self, proc_id, start, end):
         print("Worker %d is building dataset ... " % proc_id)
+        samples_num = 0
+        tokens_num = 0
+
         set_seed(self.seed)
         dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
         pos = 0
@@ -517,7 +534,7 @@ def worker(self, proc_id, start, end):
                         seg_pos = [self.seq_length]
                         src = (src, 0)
                         pickle.dump((src, seg_pos), dataset_writer)
-                    buffer = buffer[instances_num * (self.seq_length + 1): ]
+                    buffer = buffer[instances_num * (self.seq_length + 1):]

                 else:
                     instances_num = len(document) // (self.seq_length + 1)
@@ -533,7 +550,8 @@ def worker(self, proc_id, start, end):
                         pad_num = self.seq_length + 1 - len(src)
                         src = (src, pad_num)
                         pickle.dump((src, seg_pos), dataset_writer)
-
+                        tokens_num += len(src[0])
+                        samples_num += 1
                 if pos >= end:
                     break

@@ -675,7 +693,8 @@ def create_ins_from_doc(self, all_documents, document_index):

         while i < len(document):
             segment = document[i]
-            if i in mask_seq_list and len(tgt) + len(segment) < target_tgt_seq_length and len(src) + 1 < target_seq_length:
+            if i in mask_seq_list and len(tgt) + len(segment) < target_tgt_seq_length and len(
+                    src) + 1 < target_seq_length:
                 tgt = tgt + segment
                 src = src + [self.vocab.get(MASK_TOKEN)]
             elif i not in mask_seq_list and len(src) + len(segment) < target_seq_length:
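
The condition that gets wrapped here drives what appears to be a PEGASUS-style gap-sentence selection loop: segments whose index is in mask_seq_list contribute their tokens to the target and are replaced by a single MASK token in the source, while the remaining segments stay in the source verbatim, subject to the two length budgets. A simplified illustration with plain lists, a placeholder mask id, and a hypothetical build_gap_sentence_pair helper (the real code looks MASK_TOKEN up in the vocabulary and iterates with an explicit index):

```python
MASK_ID = 0  # placeholder for self.vocab.get(MASK_TOKEN)


def build_gap_sentence_pair(document, mask_seq_list, target_seq_length, target_tgt_seq_length):
    # document: list of segments, each a list of token ids.
    src, tgt = [], []
    for i, segment in enumerate(document):
        if i in mask_seq_list and len(tgt) + len(segment) < target_tgt_seq_length and len(src) + 1 < target_seq_length:
            # Masked segment: its tokens become target text, one MASK marks its place in the source.
            tgt = tgt + segment
            src = src + [MASK_ID]
        elif i not in mask_seq_list and len(src) + len(segment) < target_seq_length:
            # Unmasked segment: kept verbatim in the source.
            src = src + segment
    return src, tgt


doc = [[11, 12], [21, 22, 23], [31]]
src, tgt = build_gap_sentence_pair(doc, mask_seq_list={1}, target_seq_length=16, target_tgt_seq_length=16)
# src == [11, 12, 0, 31], tgt == [21, 22, 23]
```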
@@ -884,7 +903,8 @@ def worker(self, proc_id, start, end):
                 if len(line) == 2:
                     label = int(line[0])
                     text = line[1]
-                    src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
+                    src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(
+                        self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
                     tgt_cls = label
                     seg_pos = [len(src)]
                 elif len(line) == 3:  # For sentence pair input.
@@ -920,7 +940,8 @@ def worker(self, proc_id, start, end):

                 if not self.dynamic_masking:
                     src_single, pad_num = src
-                    src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
+                    src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking,
+                                                   self.span_masking, self.span_geo_prob, self.span_max_length)
                     src = (src_single, pad_num)
                     instance = (src, tgt_mlm, tgt_cls, seg_pos)
                 else:
@@ -1046,6 +1067,8 @@ class DalleDataset(FileWithTextDataset):
 class LlmSftDataset(Dataset):
     def worker(self, proc_id, start, end):
         print("Worker %d is building dataset ... " % proc_id)
+        samples_num = 0
+        tokens_num = 0
         set_seed(self.seed)
         dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
         pos = 0
@@ -1079,7 +1102,10 @@ def worker(self, proc_id, start, end):
                 pad_num = self.seq_length - len(src)

                 pickle.dump(((src, pad_num), seg_pos), dataset_writer)
+                tokens_num += len(src)
+                samples_num += 1
                 if pos >= end:
                     break

         dataset_writer.close()
+        return samples_num, tokens_num