Predictor.py
#!/usr/bin/python3
import os
from math import ceil
import csv
import tensorflow as tf
from BERT import BERTClassifier

# NOTE: only member functions without an underscore prefix are meant for users.
class Predictor(object):
    def __init__(self, max_seq_len = 128):
        # create the BERT classifier and its tokenizer
        self.classifier, self.tokenizer = BERTClassifier(max_seq_len = max_seq_len)
        # load finetuned parameters if they exist
        if os.path.exists('bert.h5'):
            print("loading trained parameters...")
            self.classifier.load_weights('bert.h5')
        # save max sequence length
        self.max_seq_len = max_seq_len
    def _read_tsv(self, inputfile):
        # read a tab-separated file into a list of rows
        with tf.io.gfile.GFile(inputfile, "r") as f:
            reader = csv.reader(f, delimiter = '\t')
            lines = []
            for line in reader:
                lines.append(line)
            return lines
    def _create_classifier_examples(self, examples):
        # each row is expected to be: label <TAB> question <TAB> answer
        dataset = []
        for line in examples:
            label = int(line[0])
            question = line[1]
            answer = line[2]
            dataset.append((question, answer, label))
        return dataset
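    # Illustrative (made-up) row of the TSV format consumed above, matching the
    # label/question/answer columns:
    #   1<TAB>how tall is mount everest<TAB>8848 meters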
    def _create_classifier_datasets(self, data_dir = None):
        assert type(data_dir) is str
        train_examples = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        test_examples = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        trainset = self._create_classifier_examples(train_examples)
        testset = self._create_classifier_examples(test_examples)
        def write_tfrecord(dataset, output_file):
            # serialize every example into a tfrecord file
            writer = tf.io.TFRecordWriter(output_file)
            for example in dataset:
                input_ids, input_mask, segment_ids = self._preprocess(example[0], example[1])
                tf_example = tf.train.Example(features = tf.train.Features(
                    feature = {
                        "input_ids": tf.train.Feature(int64_list = tf.train.Int64List(value = list(input_ids))),
                        "input_mask": tf.train.Feature(int64_list = tf.train.Int64List(value = list(input_mask))),
                        "segment_ids": tf.train.Feature(int64_list = tf.train.Int64List(value = list(segment_ids))),
                        "label_ids": tf.train.Feature(int64_list = tf.train.Int64List(value = [example[2]]))
                    }
                ))
                writer.write(tf_example.SerializeToString())
            writer.close()
        write_tfrecord(trainset, "trainset.tfrecord")
        write_tfrecord(testset, "testset.tfrecord")
        self.trainset_size = len(trainset)
        self.testset_size = len(testset)
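    # Note: _create_classifier_datasets leaves trainset.tfrecord and
    # testset.tfrecord in the working directory; finetune() below relies on
    # trainset_size to derive steps_per_epoch.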
    def _preprocess(self, question, answer):
        # tokenize question and answer
        tokens_question = self.tokenizer.tokenize(question)
        tokens_answer = self.tokenizer.tokenize(answer)
        # truncate the pair so that, together with [CLS] and two [SEP] tokens,
        # it fits within max_seq_len; drop from the longer side first
        while True:
            total_length = len(tokens_question) + len(tokens_answer)
            if total_length <= self.max_seq_len - 3: break
            if len(tokens_question) > len(tokens_answer): tokens_question.pop()
            else: tokens_answer.pop()
        tokens = []
        segment_ids = []
        # question segment: [CLS] question [SEP], all with segment id 0
        tokens.append('[CLS]')
        segment_ids.append(0)
        for token in tokens_question:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append('[SEP]')
        segment_ids.append(0)
        # answer segment: answer [SEP], all with segment id 1
        for token in tokens_answer:
            tokens.append(token)
            segment_ids.append(1)
        tokens.append('[SEP]')
        segment_ids.append(1)
        # convert tokens into vocabulary ids
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # mark the valid (non-padding) positions
        input_mask = [1] * len(input_ids)
        # pad all three sequences with zeros up to max_seq_len
        while len(input_ids) < self.max_seq_len:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
        assert len(input_ids) == self.max_seq_len
        assert len(input_mask) == self.max_seq_len
        assert len(segment_ids) == self.max_seq_len
        return input_ids, input_mask, segment_ids
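    # Illustrative packing produced by _preprocess for a short pair
    # (question tokens q1 q2, answer tokens a1 a2 a3):
    #   tokens:      [CLS] q1 q2 [SEP] a1 a2 a3 [SEP] pad ...
    #   segment_ids:  0    0  0   0    1  1  1   1    0   ...
    #   input_mask:   1    1  1   1    1  1  1   1    0   ...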
    @tf.function
    def _classifier_input_fn(self, serialized_example):
        # parse one serialized tf.train.Example back into fixed-length tensors
        feature = tf.io.parse_single_example(
            serialized_example,
            features = {
                'input_ids': tf.io.FixedLenFeature((self.max_seq_len), dtype = tf.int64),
                'input_mask': tf.io.FixedLenFeature((self.max_seq_len), dtype = tf.int64),
                'segment_ids': tf.io.FixedLenFeature((self.max_seq_len), dtype = tf.int64),
                'label_ids': tf.io.FixedLenFeature((), dtype = tf.int64)
            }
        )
        # tfrecord stores int64; cast down to int32 for the model
        for name in list(feature.keys()):
            feature[name] = tf.cast(feature[name], dtype = tf.int32)
        return (feature['input_ids'], feature['segment_ids']), feature['label_ids']
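    # Each parsed element is ((input_ids, segment_ids), label_ids): the two id
    # tensors have shape (max_seq_len,) and label_ids is a scalar; input_mask is
    # stored in the tfrecord but not consumed by the classifier here.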
    def finetune(self, data_dir = None, batch = 32, epochs = 3):
        assert type(data_dir) is str
        # create the datasets in tfrecord format
        self._create_classifier_datasets(data_dir)
        # load from the tfrecord file; shuffle individual examples before
        # batching so that batch composition varies across epochs
        trainset = tf.data.TFRecordDataset('trainset.tfrecord').map(self._classifier_input_fn).shuffle(buffer_size = 100).batch(batch).repeat()
        # finetune the bert model
        steps_per_epoch = ceil(self.trainset_size / batch)
        self.classifier.fit(trainset, epochs = epochs, steps_per_epoch = steps_per_epoch)
        # save the finetuned weights
        self.classifier.save_weights('bert.h5')
    def predict(self, question, answer):
        input_ids, input_mask, segment_ids = self._preprocess(question, answer)
        # add the batch dimension
        input_ids = tf.expand_dims(tf.constant(input_ids, dtype = tf.int32), 0)
        segment_ids = tf.expand_dims(tf.constant(segment_ids, dtype = tf.int32), 0)
        logits = self.classifier.predict([input_ids, segment_ids])
        # return the predicted class and the logit of the positive class
        out = tf.math.argmax(logits, axis = -1)[0]
        return out, logits[0][1]
if __name__ == "__main__":
    assert tf.executing_eagerly()
    predictor = Predictor()
    predictor.finetune('dataset')
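    # A minimal usage sketch after finetuning; the question/answer strings are
    # made-up placeholders for illustration.
    label, score = predictor.predict("how tall is mount everest", "everest is 8848 meters tall")
    print("predicted label:", int(label), "positive-class logit:", float(score))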