Skip to content

Commit 26a3e7e

Browse files
committed
catch SIGTERM
show messages on index
1 parent 0ab12a9 commit 26a3e7e

File tree

6 files changed

+139
-76
lines changed

6 files changed

+139
-76
lines changed

Diff for: img_prep.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ def deprocess_image(x):
2828
return x
2929

3030

31-
def random_image():
32-
"""
33-
Create a random image
34-
:return: A random image
35-
"""
36-
return np.random.uniform(0, 255, (1, 3, img_width, img_height))
31+
# def random_image():
32+
# """
33+
# Create a random image
34+
# :return: A random image
35+
# """
36+
# return np.random.uniform(0, 255, (1, 3, img_width, img_height))
3737

3838

3939
def grey_image():

Diff for: index.html

+6-4
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,22 @@
4040
if (trained){
4141
console.log('pause');
4242
$('#train').get(0).innerHTML = 'train';
43-
$('#status').get(0).innerHTML = 'pause';
43+
// $('#status').get(0).innerHTML = 'pause';
4444
trained = false;
4545
}
4646

4747
else {
4848
console.log('train');
4949
$('#train').get(0).innerHTML = 'pause';
50-
$('#status').get(0).innerHTML = 'training';
50+
// $('#status').get(0).innerHTML = 'training';
5151
trained = true;
5252
}
5353
}
5454

5555
function stop() {
5656
$.ajax('/stop')
5757
$('#train').get(0).innerHTML = 'train';
58-
$('#status').get(0).innerHTML = 'stop';
58+
// $('#status').get(0).innerHTML = 'stop';
5959
}
6060

6161
$(document).ready(function() {
@@ -64,14 +64,16 @@
6464
$.post('http://localhost:8000/lr?lr='+this.value);
6565
});
6666
updater.poll();
67+
// alert(updater2);
68+
updater2.poll();
6769
});
6870

6971

7072

7173
</script>
7274
<!--<span id="test" style="padding-left: 240px;"></span>-->
7375
<ul class="side-nav fixed">
74-
<li style="text-align: center"><b >status: <span id="status">stop</span></b></li>
76+
<li style="text-align: center"><b >status:</b> <br><span id="status">stop</span></li>
7577
<li><a class="waves-effect waves-light btn cyan accent-4 white-text" href="javascript: train()" id="train">train</a></li>
7678
<li><a class="waves-effect waves-light btn cyan accent-4 white-text" href="javascript: stop()" id="stop">stop</a></li>
7779
<li><a class="waves-effect waves-light btn cyan accent-4 white-text" href="javascript: toDefault()">Default View</a></li>

Diff for: neural_style.py

+59-22
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
11
from __future__ import print_function
22
import os
3-
import sys
4-
5-
import multiprocessing
6-
from start_training import start_training
7-
8-
import json
9-
import uuid
103
from tornado.ioloop import IOLoop
114
from tornado.web import RequestHandler, Application
125
from tornado.escape import json_decode
6+
import json
7+
# from main import start_training
8+
import multiprocessing
139
from tornado.concurrent import Future
1410
from tornado import gen
11+
import uuid
12+
# from training_record import TrainingRecorder
1513

16-
__author__ = 'xymeow', 'suquark'
14+
__author__ = 'xymeow'
1715

1816
# TODO: show training messages on the front end
1917

20-
label = None
21-
2218
start = False
2319
trained = False
2420
stop = False
2521
lr = 1.0
2622

27-
23+
p = None
24+
record_cache = None
2825
# use long polling to update the training results
2926
# I copied these codes from tornado/demo/chat
3027
class MessageBuffer(object):
@@ -67,7 +64,7 @@ def new_messages(self, messages):
6764

6865

6966
record_pool = MessageBuffer()
70-
67+
message_pool = MessageBuffer()
7168

7269
class MainHandler(RequestHandler):
7370
def get(self):
@@ -79,10 +76,13 @@ def post(self):
7976

8077

8178
class RecordHandler(RequestHandler):
79+
8280
def post(self):
81+
global record_cache
8382
json_dict = json_decode(self.request.body)
8483
print(self.request.body)
85-
84+
record_cache = json_dict['cache']
85+
json_dict.pop('cache')
8686
myjson = {
8787
'id': str(uuid.uuid4()),
8888
'json': json_dict
@@ -112,11 +112,14 @@ def on_connection_close(self):
112112

113113
class TrainHandler(RequestHandler):
114114
def get(self):
115-
global trained, start, label
115+
global trained
116+
global start
117+
global p
116118
if not start:
117119
start = True
118120
trained = True
119-
multiprocessing.Process(target=start_training).start()
121+
p = multiprocessing.Process(target=start_training)
122+
p.start()
120123
elif trained:
121124
trained = False
122125
elif not trained:
@@ -139,8 +142,11 @@ class StopHandler(RequestHandler):
139142
def get(self):
140143
global stop
141144
global start
145+
global p
142146
stop = True
143147
start = False
148+
if p != None:
149+
p.terminate()
144150

145151

146152
class PictureHandler(RequestHandler):
@@ -153,12 +159,12 @@ def get(self):
153159

154160
class InitJSONHandler(RequestHandler):
155161
def get(self):
156-
if os.path.exists(label + '.json'):
157-
with open(label + '.json', 'rb') as f:
162+
global record_cache
163+
if record_cache == None:
164+
with open('result.json', 'rb') as f:
158165
self.write(f.read())
159166
else:
160-
print('Record file not found.')
161-
self.write('{}')
167+
self.write(record_cache)
162168

163169

164170
class LearningRateHandler(RequestHandler):
@@ -173,6 +179,36 @@ def post(self):
173179
print('set learnig rate to {}'.format(lr))
174180

175181

182+
class PullMessageHandler(RequestHandler):
183+
184+
@gen.coroutine
185+
def post(self):
186+
cursor = self.get_argument("cursor", None)
187+
# Save the future returned by wait_for_messages so we can cancel
188+
# it in wait_for_messages
189+
self.future = message_pool.wait_for_messages(cursor=cursor)
190+
messages = yield self.future
191+
if self.request.connection.stream.closed():
192+
return
193+
self.write(dict(messages=messages))
194+
195+
def on_connection_close(self):
196+
message_pool.cancel_wait(self.future)
197+
198+
199+
class SetMessageHandler(RequestHandler):
200+
def post(self):
201+
# global record_cache
202+
json_dict = json_decode(self.request.body)
203+
print(self.request.body)
204+
myjson = {
205+
'id': str(uuid.uuid4()),
206+
'json': json_dict
207+
}
208+
message_pool.new_messages([myjson])
209+
210+
211+
176212
settings = {
177213
"static_path": os.path.join(os.path.dirname(__file__), "static")
178214
}
@@ -186,11 +222,12 @@ def post(self):
186222
(r"/picture", PictureHandler),
187223
(r"/init", InitJSONHandler),
188224
(r"/lr", LearningRateHandler),
189-
(r"/train", TrainHandler)
225+
(r"/train", TrainHandler),
226+
(r"/setmsg", SetMessageHandler),
227+
(r"/getmsg", PullMessageHandler)
190228
], **settings)
191229

192230
if __name__ == "__main__":
193-
# assert len(sys.argv) > 1, 'You should give the name of this training'
194-
label = sys.argv[1] if len(sys.argv) > 1 else 'result' # Which record would you like to take?
231+
print('server start at localhost:8000')
195232
application.listen(8000)
196233
IOLoop.instance().start()

Diff for: start_training.py

+37-6
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,43 @@
99
from training_record import TrainingRecorder
1010
from img_prep import img_in, img_save, preprocess_image, grey_image
1111
from settings import content_path, style_path, loss_set, result_dir
12-
import requests, time
12+
import requests, time, signal, sys
1313

1414

1515
# load content and style image
16+
# record_cache = None
17+
tr = None
18+
19+
20+
def signal_term_handler(signal, frame):
21+
global tr
22+
if tr == None:
23+
print('stop training.')
24+
requests.post('http://localhost:8000/setmsg', None,
25+
{'msg': 'stop training.'})
26+
# sys.exit(0)
27+
else:
28+
print('stop training and save weights into json file.')
29+
requests.post('http://localhost:8000/setmsg', None,
30+
{'msg': 'stop training and save weights into json file.'})
31+
tr.export_file()
32+
sys.exit(0)
33+
34+
signal.signal(signal.SIGTERM, signal_term_handler)
35+
36+
1637
def start_training():
17-
content, style = img_in(content_path, style_path)
1838

19-
tr = TrainingRecorder(result_dir, loss_set, img_save, '.png')
39+
content, style = img_in(content_path, style_path)
40+
global tr
41+
tr = TrainingRecorder(loss_set, result_dir, img_save, '.png')
2042

21-
idx = tr.get_tail() - 1
22-
if idx >= 0:
43+
idx = tr.get_head() - 1
44+
if idx > 0:
2345
x = preprocess_image(tr.get_name(idx))
2446
else:
25-
x = grey_image()
47+
x = content
48+
2649
idx += 1
2750

2851
model = Model(content, style, x)
@@ -34,13 +57,16 @@ def start_training():
3457
lr = float(requests.get('http://localhost:8000/lr').text)
3558
if status == 'pause':
3659
i -= 1
60+
requests.post('http://localhost:8000/setmsg', None,
61+
{'msg': 'pause'})
3762
while True:
3863
time.sleep(1)
3964
status = requests.get('http://localhost:8000/status').text
4065
if status != 'pause':
4166
break
4267
elif status == 'training':
4368
print('Start of iteration', i)
69+
requests.post('http://localhost:8000/setmsg', None, {'msg': 'Start of iteration ' + str(i)})
4470
# msg = 'Start of iteration ' + str(i)
4571
if lst_lr != lr:
4672
# if learning rate has changes, set a new lr for the optimizer
@@ -51,9 +77,14 @@ def start_training():
5177
lst_lr = lr
5278
elif status == 'stop':
5379
print('stop training and save weights into json file.')
80+
requests.post('http://localhost:8000/setmsg', None,
81+
{'msg': 'stop training and save weights into json file.'})
5482
# msg = 'stop training and save weights into json file.'
5583
tr.export_file()
5684
break
5785

5886
except KeyboardInterrupt:
87+
print('stop training and save weights into json file.')
88+
requests.post('http://localhost:8000/setmsg', None,
89+
{'msg': 'stop training and save weights into json file.'})
5990
tr.export_file()

Diff for: training.py

+7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from settings import img_height, img_width, input_shape, learning_rate
55
from settings import loss_set, content_feature_layers, style_feature_layers
66
from vgg16_cnnonly_model import set_model_input, model_with_input
7+
import requests
78

89

910
def gram_matrix(x):
@@ -39,6 +40,8 @@ def __init__(self, content, style, x):
3940
self.model = None
4041
# evaluate features of content and style images
4142
print("Pre-evaluate features...")
43+
requests.post('http://localhost:8000/setmsg', None,
44+
{'msg': "Pre-evaluate features..."})
4245
# global msg
4346
# msg = "Pre-evaluate features..."
4447
self.gen_model(K.placeholder(input_shape))
@@ -60,6 +63,8 @@ def set_lr(self, learning_rate):
6063
# self.optimizer.lr.set_value(learning_rate)
6164
K.set_value(self.optimizer.lr, learning_rate)
6265
print('learning rate = {}'.format(learning_rate))
66+
requests.post('http://localhost:8000/setmsg', None,
67+
{'msg': 'set learning rate to {}'.format(learning_rate)})
6368
# global msg
6469
# msg = 'learning rate = {}'.format(learning_rate)
6570

@@ -146,6 +151,8 @@ def compile(self):
146151
"""
147152
# global msg
148153
print("Generate loss and grad...")
154+
requests.post('http://localhost:8000/setmsg', None,
155+
{'msg': "Generate loss and grad..."})
149156
# msg = "Generate loss and grad..."
150157
losses = [l * w for l, w in zip(*self.get_loss())]
151158
total_loss = sum(losses)

0 commit comments

Comments
 (0)