forked from breadbread1984/OCR-tf2
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsave_model.py
41 lines (34 loc) · 1.23 KB
/
save_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#!/usr/bin/python3
import sys
from os import makedirs, mkdir
from os.path import exists, join

import tensorflow as tf

from create_dataset import SampleGenerator
from models import CTPN, CRNN
def save_ctpn():
    """Restore the latest CTPN checkpoint and export the model to ``model/ctpn.h5``.

    Raises:
        FileNotFoundError: if ``tf.train.latest_checkpoint`` finds nothing in
            ``checkpoints``. Without this check ``Checkpoint.restore(None)``
            would leave the freshly initialised weights in place and an
            untrained model would be exported silently.
    """
    ctpn = CTPN()
    # The optimizer is only handed to the Checkpoint below; presumably it has to
    # mirror the training script's optimizer so the checkpoint objects match —
    # TODO confirm against the training code.
    optimizer = tf.keras.optimizers.Adam(
        tf.keras.optimizers.schedules.ExponentialDecay(
            1e-5, decay_steps=30000, decay_rate=0.1
        )
    )
    checkpoint = tf.train.Checkpoint(model=ctpn, optimizer=optimizer)
    # NOTE(review): save_ocr reads the same 'checkpoints' directory; make sure
    # it holds the CTPN checkpoint when running this export.
    latest = tf.train.latest_checkpoint("checkpoints")
    if latest is None:
        raise FileNotFoundError("no checkpoint found in 'checkpoints'")
    checkpoint.restore(latest)
    # exist_ok avoids the check-then-create race of exists() followed by mkdir().
    makedirs("model", exist_ok=True)
    ctpn.save(join("model", "ctpn.h5"))
def save_ocr():
    """Restore the latest CRNN checkpoint and export the model to ``model/crnn.h5``.

    Raises:
        FileNotFoundError: if ``tf.train.latest_checkpoint`` finds nothing in
            ``checkpoints``. Without this check ``Checkpoint.restore(None)``
            would leave the freshly initialised weights in place and an
            untrained model would be exported silently.
    """
    # NOTE(review): the meaning of the argument 10 is not visible here — TODO confirm.
    generator = SampleGenerator(10)
    # The generator is only used to size the CRNN output; presumably the extra
    # class is the CTC blank label — TODO confirm against models.CRNN.
    crnn = CRNN(generator.vocab_size() + 1)
    # The optimizer is only handed to the Checkpoint below; presumably it has to
    # mirror the training script's optimizer — TODO confirm.
    optimizer = tf.keras.optimizers.Adam(1e-4)
    checkpoint = tf.train.Checkpoint(model=crnn, optimizer=optimizer)
    # NOTE(review): save_ctpn reads the same 'checkpoints' directory; make sure
    # it holds the CRNN checkpoint when running this export.
    latest = tf.train.latest_checkpoint("checkpoints")
    if latest is None:
        raise FileNotFoundError("no checkpoint found in 'checkpoints'")
    checkpoint.restore(latest)
    # exist_ok avoids the check-then-create race of exists() followed by mkdir().
    makedirs("model", exist_ok=True)
    crnn.save(join("model", "crnn.h5"))
if __name__ == "__main__":
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not tf.executing_eagerly():
        sys.exit("TensorFlow must be running in eager mode")
    # Map each accepted CLI argument to its export routine.
    savers = {"ctpn": save_ctpn, "ocr": save_ocr}
    if len(sys.argv) != 2:
        print("Usage: " + sys.argv[0] + " (ctpn|ocr)", file=sys.stderr)
        sys.exit(1)
    if sys.argv[1] not in savers:
        print("only support ctpn or ocr!", file=sys.stderr)
        sys.exit(1)
    savers[sys.argv[1]]()