neural_style.py
from __future__ import print_function

import os
import json
import multiprocessing
import uuid

from tornado.ioloop import IOLoop
from tornado.web import RequestHandler, Application
from tornado.escape import json_decode
from tornado.concurrent import Future
from tornado import gen

from main import start_training  # needed by TrainHandler below
# from training_record import TrainingRecorder

__author__ = 'xymeow'

# TODO: show training messages on the front end

# Global training state shared by the handlers below.
start = False        # a training process has been started
trained = False      # training is currently running (False means paused)
stop = False         # training has been stopped
lr = 1.0             # current learning rate
p = None             # the background training process
record_cache = None  # latest training history posted by the trainer

# Use long polling to push training results to the front end.
# Adapted from the Tornado chat demo (tornado/demos/chat).
class MessageBuffer(object):
    def __init__(self):
        self.waiters = set()
        self.cache = []
        self.cache_size = 200

    def wait_for_messages(self, cursor=None):
        # Construct a Future to return to our caller. This allows
        # wait_for_messages to be yielded from a coroutine even though
        # it is not a coroutine itself. We will set the result of the
        # Future when results are available.
        result_future = Future()
        if cursor:
            new_count = 0
            for msg in reversed(self.cache):
                if msg["id"] == cursor:
                    break
                new_count += 1
            if new_count:
                result_future.set_result(self.cache[-new_count:])
                return result_future
        self.waiters.add(result_future)
        return result_future

    def cancel_wait(self, future):
        self.waiters.remove(future)
        # Set an empty result to unblock any coroutines waiting.
        future.set_result([])

    def new_messages(self, messages):
        # Wake every waiting poller, then append to the bounded cache.
        for future in self.waiters:
            future.set_result(messages)
        self.waiters = set()
        self.cache.extend(messages)
        if len(self.cache) > self.cache_size:
            self.cache = self.cache[-self.cache_size:]


record_pool = MessageBuffer()   # training records (losses, intermediate results)
message_pool = MessageBuffer()  # status/log messages

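# A minimal long-polling client sketch (illustrative only, not part of the app;
# it assumes the server is running on localhost:8000 and that the `requests`
# package is installed). The client POSTs the id of the last message it has
# seen as "cursor" to /json and blocks until new records arrive.
#
#     import requests
#     cursor = None
#     while True:
#         data = {'cursor': cursor} if cursor else {}
#         resp = requests.post('http://localhost:8000/json', data=data)
#         for msg in resp.json()['messages']:
#             print(msg['json'])
#             cursor = msg['id']
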
class MainHandler(RequestHandler):
    def get(self):
        self.render('index.html', json='/json')

    def post(self):
        body = json_decode(self.request.body)
        self.render('index.html', json=body)

class RecordHandler(RequestHandler):
    def post(self):
        # The training process POSTs its latest results here; 'cache' holds the
        # full history and the remaining fields are broadcast to polling clients.
        global record_cache
        json_dict = json_decode(self.request.body)
        print(self.request.body)
        record_cache = json_dict['cache']
        json_dict.pop('cache')
        myjson = {
            'id': str(uuid.uuid4()),
            'json': json_dict
        }
        record_pool.new_messages([myjson])

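# Example request body the trainer might POST to /record (the 'cache' key is
# required by RecordHandler; the other fields are purely illustrative):
#     {"cache": "<serialized history>", "iteration": 100, "loss": 0.42}
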
class PullRecordHandler(RequestHandler):
    def get(self):
        return

    @gen.coroutine
    def post(self):
        cursor = self.get_argument("cursor", None)
        # Save the future returned by wait_for_messages so we can cancel
        # it in on_connection_close.
        self.future = record_pool.wait_for_messages(cursor=cursor)
        messages = yield self.future
        if self.request.connection.stream.closed():
            return
        self.write(dict(messages=messages))

    def on_connection_close(self):
        record_pool.cancel_wait(self.future)

class TrainHandler(RequestHandler):
    def get(self):
        # Toggle training: start a new training process on the first call,
        # then flip between paused and running on subsequent calls.
        global trained
        global start
        global p
        if not start:
            start = True
            trained = True
            p = multiprocessing.Process(target=start_training)
            p.start()
        elif trained:
            trained = False
        else:
            trained = True
        self.write(json.dumps(trained))

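# Example of toggling training from the command line (assuming the server is
# running on localhost:8000 and curl is available):
#     curl http://localhost:8000/train    # starts or resumes training, returns "true"
#     curl http://localhost:8000/train    # pauses training, returns "false"
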
class StatusHandler(RequestHandler):
    def get(self):
        # Report the current training state: 'training', 'pause', or 'stop'.
        global trained
        status = 'training'
        if not trained and not stop:
            status = 'pause'
        if stop or not start:
            status = 'stop'
        self.write(status)

class StopHandler(RequestHandler):
    def get(self):
        # Stop training and terminate the background process if one is running.
        global stop
        global start
        global p
        stop = True
        start = False
        if p is not None:
            p.terminate()

class PictureHandler(RequestHandler):
    def get(self):
        # Serve an image from the path given in the 'path' query argument.
        print(self.get_argument('path'))
        with open(self.get_argument('path'), 'rb') as f:
            self.write(f.read())

class InitJSONHandler(RequestHandler):
    def get(self):
        # Return the cached training history, falling back to result.json on disk.
        global record_cache
        if record_cache is None:
            with open('result.json', 'rb') as f:
                self.write(f.read())
        else:
            self.write(record_cache)

class LearningRateHandler(RequestHandler):
    def get(self):
        global lr
        self.write(str(lr))

    def post(self):
        global lr
        lr = float(self.get_argument('lr'))
        print('set learning rate to {}'.format(lr))

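# Example of changing the learning rate from the command line (assuming the
# server is running on localhost:8000 and curl is available):
#     curl -X POST -d "lr=0.01" http://localhost:8000/lr
#     curl http://localhost:8000/lr    # -> 0.01
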
class PullMessageHandler(RequestHandler):
    @gen.coroutine
    def post(self):
        cursor = self.get_argument("cursor", None)
        # Save the future returned by wait_for_messages so we can cancel
        # it in on_connection_close.
        self.future = message_pool.wait_for_messages(cursor=cursor)
        messages = yield self.future
        if self.request.connection.stream.closed():
            return
        self.write(dict(messages=messages))

    def on_connection_close(self):
        message_pool.cancel_wait(self.future)

class SetMessageHandler(RequestHandler):
    def post(self):
        # Broadcast an arbitrary status/log message to long-polling clients.
        json_dict = json_decode(self.request.body)
        print(self.request.body)
        myjson = {
            'id': str(uuid.uuid4()),
            'json': json_dict
        }
        message_pool.new_messages([myjson])

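# Example of pushing a message to connected clients (assuming the server is
# running on localhost:8000 and curl is available):
#     curl -X POST -d '{"text": "epoch 3 finished"}' http://localhost:8000/setmsg
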
settings = {
    "static_path": os.path.join(os.path.dirname(__file__), "static")
}

application = Application([
    (r"/", MainHandler),
    (r"/json", PullRecordHandler),
    (r"/stop", StopHandler),
    (r"/record", RecordHandler),
    (r"/status", StatusHandler),
    (r"/picture", PictureHandler),
    (r"/init", InitJSONHandler),
    (r"/lr", LearningRateHandler),
    (r"/train", TrainHandler),
    (r"/setmsg", SetMessageHandler),
    (r"/getmsg", PullMessageHandler)
], **settings)

if __name__ == "__main__":
    print('server starting at localhost:8000')
    application.listen(8000)
    IOLoop.instance().start()
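
# To run the server:
#     python neural_style.py
# then open http://localhost:8000/ in a browser to load the front end.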