-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy pathclassifier.py
199 lines (179 loc) · 8.71 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
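"""Train a simple fastText-style text classifier.

Inputs:
  words  - text to classify
  ngrams - character n-grams for each word in `words`
  labels - output classes to predict
Model:
  word embedding, averaged over the words of each example
  optional character n-gram embedding (enabled with --use_ngrams)
  logistic-regression (linear layer + softmax) classifier from the
  embeddings to the labels
"""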
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inputs
import sys
import tensorflow as tf
from tensorflow.contrib.layers import feature_column
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
# Command-line flags. Data / vocabulary inputs:
tf.flags.DEFINE_string("train_records", None,
                       "Training file pattern for TFRecords, can use wildcards")
tf.flags.DEFINE_string("eval_records", None,
                       "Evaluation file pattern for TFRecords, can use wildcards")
tf.flags.DEFINE_string("predict_records", None,
                       "File pattern for TFRecords to predict, can use wildcards")
tf.flags.DEFINE_string("label_file", None, "File containing output labels")
tf.flags.DEFINE_integer("num_labels", None, "Number of output labels")
tf.flags.DEFINE_string("vocab_file", None, "Vocabulary file, one word per line")
tf.flags.DEFINE_integer("vocab_size", None, "Number of words in vocabulary")
tf.flags.DEFINE_integer("num_oov_vocab_buckets", 20,
                        "Number of hash buckets to use for OOV words")
# Output locations:
tf.flags.DEFINE_string("model_dir", ".",
                       "Output directory for checkpoints and summaries")
tf.flags.DEFINE_string("export_dir", None, "Directory to store savedmodel")
# Model shape:
tf.flags.DEFINE_integer("embedding_dimension", 10, "Dimension of word embedding")
tf.flags.DEFINE_boolean("use_ngrams", False, "Use character ngrams in embedding")
tf.flags.DEFINE_integer("num_ngram_buckets", 1000000,
                        "Number of hash buckets for ngrams")
# Help text previously duplicated the word-embedding description.
tf.flags.DEFINE_integer("ngram_embedding_dimension", 10,
                        "Dimension of ngram embedding")
# Optimisation / training schedule:
tf.flags.DEFINE_float("learning_rate", 0.001, "Learning rate for training")
# NOTE(review): nothing in this file reads FLAGS.clip_gradient, so gradients
# are currently not clipped -- either wire it into the optimizer or drop it.
tf.flags.DEFINE_float("clip_gradient", 5.0, "Clip gradient norm to this ratio")
tf.flags.DEFINE_integer("batch_size", 128, "Training minibatch size")
tf.flags.DEFINE_integer("train_steps", 1000,
                        "Number of train steps, None for continuous")
tf.flags.DEFINE_integer("eval_steps", 100, "Number of eval steps")
tf.flags.DEFINE_integer("num_epochs", None, "Number of training data epochs")
tf.flags.DEFINE_integer("checkpoint_steps", 1000,
                        "Steps between saving checkpoints")
# Runtime:
tf.flags.DEFINE_integer("num_threads", 1, "Number of reader threads")
tf.flags.DEFINE_boolean("log_device_placement", False, "log where ops are located")
tf.flags.DEFINE_boolean("horovod", False,
                        "Run across multiple GPUs using Horovod MPI. https://github.com/uber/horovod")
tf.flags.DEFINE_boolean("debug", False, "Debug")
FLAGS = tf.flags.FLAGS
if FLAGS.horovod:
    # Horovod is optional: import it only when requested, and exit with a
    # pointer to the install instructions when it is missing.
    try:
        import horovod.tensorflow as hvd
    except ImportError as e:  # `except X, e` is Python 2-only syntax.
        print(e, file=sys.stderr)
        print("Make sure Horovod is installed: https://github.com/uber/horovod",
              file=sys.stderr)
        sys.exit(1)
    # Initialise Horovod before any later hvd.size()/rank()/local_rank() call.
    hvd.init()
def InputFn(mode, input_file):
    """Build an Estimator input_fn for `mode` reading records from `input_file`.

    Thin adapter: forwards `mode`, `input_file` and the relevant command-line
    flags, in the positional order inputs.InputFn expects.

    Args:
      mode: A tf.estimator.ModeKeys value.
      input_file: TFRecord file pattern to read.

    Returns:
      Whatever inputs.InputFn returns (passed as `input_fn=` to the Estimator).
    """
    forwarded = (
        FLAGS.use_ngrams,
        input_file,
        FLAGS.vocab_file,
        FLAGS.vocab_size,
        FLAGS.embedding_dimension,
        FLAGS.num_oov_vocab_buckets,
        FLAGS.label_file,
        FLAGS.num_labels,
        FLAGS.ngram_embedding_dimension,
        FLAGS.num_ngram_buckets,
        FLAGS.batch_size,
        FLAGS.num_epochs,
        FLAGS.num_threads,
    )
    return inputs.InputFn(mode, *forwarded)
def Exports(probs, embedding):
    """Assemble the `export_outputs` mapping used for SavedModel signatures.

    Args:
      probs: Per-class probability tensor (softmax output).
      embedding: Embedding tensor exposed as a regression-style output.

    Returns:
      Dict with the "proba" and "embedding" signatures, plus the default
      serving signature, which also serves the class probabilities.
    """
    classification_output = tf.estimator.export.ClassificationOutput
    serving_key = (
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

    outputs = dict()
    outputs["proba"] = classification_output(scores=probs)
    outputs["embedding"] = tf.estimator.export.RegressionOutput(value=embedding)
    outputs[serving_key] = classification_output(scores=probs)
    return outputs
def FastTextEstimator(model_dir, config=None):
    """Create the tf.estimator.Estimator for the fastText-style classifier.

    Args:
      model_dir: Directory for checkpoints and summaries.
      config: Optional RunConfig. When None (the default), one is built from
        --checkpoint_steps and --log_device_placement and, under --horovod,
        restricted to this process's local GPU. (Earlier versions ignored
        this argument and always used the flag-derived config.)

    Returns:
      A tf.estimator.Estimator wrapping `model_fn` below.
    """
    params = {
        "learning_rate": FLAGS.learning_rate,
    }

    def model_fn(features, labels, mode, params):
        """Averaged word (+ optional n-gram) embeddings -> linear -> softmax."""
        # Densify the sparse token tensors; ragged rows are padded with " ".
        # NOTE(review): the padding token is looked up and averaged like a
        # real word below, so it influences the mean embedding.
        features["text"] = tf.sparse_tensor_to_dense(features["text"],
                                                     default_value=" ")
        if FLAGS.use_ngrams:
            features["ngrams"] = tf.sparse_tensor_to_dense(features["ngrams"],
                                                           default_value=" ")

        # Words -> ids; out-of-vocabulary words hash into
        # num_oov_vocab_buckets extra ids after the vocabulary ids.
        text_lookup_table = tf.contrib.lookup.index_table_from_file(
            FLAGS.vocab_file, FLAGS.num_oov_vocab_buckets, FLAGS.vocab_size)
        text_ids = text_lookup_table.lookup(features["text"])
        # One row per in-vocabulary id plus one per OOV bucket.
        text_embedding_w = tf.Variable(tf.random_uniform(
            [FLAGS.vocab_size + FLAGS.num_oov_vocab_buckets,
             FLAGS.embedding_dimension],
            -0.1, 0.1))
        # Mean over the second-to-last axis of the looked-up embeddings.
        text_embedding = tf.reduce_mean(tf.nn.embedding_lookup(
            text_embedding_w, text_ids), axis=-2)
        input_layer = text_embedding

        if FLAGS.use_ngrams:
            ngram_hash = tf.string_to_hash_bucket(features["ngrams"],
                                                  FLAGS.num_ngram_buckets)
            ngram_embedding_w = tf.Variable(tf.random_uniform(
                [FLAGS.num_ngram_buckets, FLAGS.ngram_embedding_dimension],
                -0.1, 0.1))
            ngram_embedding = tf.reduce_mean(tf.nn.embedding_lookup(
                ngram_embedding_w, ngram_hash), axis=-2)
            # NOTE(review): this rank adjustment depends on the ngram layout
            # produced by inputs.InputFn -- confirm the result actually
            # concatenates with text_embedding along the last axis.
            ngram_embedding = tf.expand_dims(ngram_embedding, -2)
            input_layer = tf.concat([text_embedding, ngram_embedding], -1)

        num_classes = FLAGS.num_labels
        logits = tf.contrib.layers.fully_connected(
            inputs=input_layer, num_outputs=num_classes,
            activation_fn=None)
        predictions = tf.argmax(logits, axis=-1)
        probs = tf.nn.softmax(logits)

        loss, train_op = None, None
        metrics = {}
        if mode != tf.estimator.ModeKeys.PREDICT:
            # Label strings -> class ids, by line position in --label_file.
            label_lookup_table = tf.contrib.lookup.index_table_from_file(
                FLAGS.label_file, vocab_size=FLAGS.num_labels)
            labels = label_lookup_table.lookup(labels)
            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=logits))
            metrics = {
                "accuracy": tf.metrics.accuracy(labels, predictions)
            }
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Build the optimizer (and Horovod allreduce ops) only when
            # training; EVAL graphs have no use for a train_op.
            opt = tf.train.AdamOptimizer(params["learning_rate"])
            if FLAGS.horovod:
                opt = hvd.DistributedOptimizer(opt)
            train_op = opt.minimize(loss, global_step=tf.train.get_global_step())

        exports = {}
        if FLAGS.export_dir:
            exports = Exports(probs, text_embedding)
        return tf.estimator.EstimatorSpec(
            mode, predictions=predictions, loss=loss, train_op=train_op,
            eval_metric_ops=metrics, export_outputs=exports)

    if config is None:
        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)
        if FLAGS.horovod:
            # Each Horovod process only sees its own local GPU.
            session_config.gpu_options.visible_device_list = str(hvd.local_rank())
        config = tf.contrib.learn.RunConfig(
            save_checkpoints_secs=None,
            save_checkpoints_steps=FLAGS.checkpoint_steps,
            session_config=session_config)
    return tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
                                  params=params, config=config)
def FastTrain():
    """Train the classifier, then evaluate and optionally export it.

    Training runs in every process; with --horovod, rank 0's initial variables
    are broadcast to all ranks. Evaluation and SavedModel export run once
    only: in the single process, or on Horovod rank 0, so that ranks do not
    write --export_dir concurrently.

    Raises:
      ValueError: if --train_records is not set.
    """
    if not FLAGS.train_records:
        raise ValueError("--train_records is required")
    print("FastTrain", FLAGS.train_steps)
    estimator = FastTextEstimator(FLAGS.model_dir)
    print("Train records:", FLAGS.train_records)
    train_input = InputFn(tf.estimator.ModeKeys.TRAIN, FLAGS.train_records)
    print("STARTING TRAIN")
    hooks = None
    if FLAGS.horovod:
        # Start every rank from rank 0's initial variable values.
        hooks = [hvd.BroadcastGlobalVariablesHook(0)]
    estimator.train(input_fn=train_input, steps=FLAGS.train_steps, hooks=hooks)
    print("TRAIN COMPLETE")

    # Everything below runs on a single process only.
    if FLAGS.horovod and hvd.rank() != 0:
        return

    if FLAGS.eval_records:
        print("EVALUATE")
        eval_input = InputFn(tf.estimator.ModeKeys.EVAL, FLAGS.eval_records)
        result = estimator.evaluate(input_fn=eval_input, steps=FLAGS.eval_steps)
        print(result)
    else:
        print("No --eval_records given; skipping evaluation")
    print("DONE")

    if FLAGS.export_dir:
        print("EXPORTING")
        estimator.export_savedmodel(FLAGS.export_dir,
                                    inputs.ServingInputFn(FLAGS.use_ngrams))
def _count_lines(path):
    """Return the number of lines in the text file at `path`."""
    with open(path) as f:
        return sum(1 for _ in f)


def main(_):
    """Fill in derived flags, split steps across Horovod ranks, and train.

    Args:
      _: Unused; receives the leftover argv from tf.app.run.
    """
    # Default the vocabulary / label counts to the line counts of their files.
    if not FLAGS.vocab_size:
        FLAGS.vocab_size = _count_lines(FLAGS.vocab_file)
    if not FLAGS.num_labels:
        FLAGS.num_labels = _count_lines(FLAGS.label_file)
    if FLAGS.horovod:
        nproc = hvd.size()
        total = FLAGS.train_steps
        if total is None:
            # --train_steps=None means "train continuously"; nothing to split.
            print("Running continuously on each of %d processes" % nproc)
        else:
            # Floor division keeps the step count an integer despite
            # `from __future__ import division`; at least one step per rank,
            # because the Estimator rejects steps <= 0.
            FLAGS.train_steps = max(1, total // nproc)
            print("Running %d steps on each of %d processes for %d total" % (
                FLAGS.train_steps, nproc, total))
    FastTrain()
if __name__ == '__main__':
    # --debug only raises TensorFlow's log verbosity; control then passes to
    # tf.app.run, which invokes main().
    requested_verbosity = tf.logging.DEBUG if FLAGS.debug else None
    if requested_verbosity is not None:
        tf.logging.set_verbosity(requested_verbosity)
    tf.app.run()