-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtraining_record.py
105 lines (91 loc) · 3.09 KB
/
training_record.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
import time
import re
import os
from os import path, listdir
import requests, json
class TrainingRecorder(object):
    """
    Collect per-step losses and outputs of a training run.

    Each call to :meth:`record` stores the loss values, saves the current
    output through the user supplied ``saver`` callable and posts the step
    to the local monitoring endpoint (``RECORD_URL``).  The accumulated
    history can be dumped to / restored from ``RESULT_FILE``.
    """

    # File (relative to the current working directory) used by
    # `export_file` and `get_head` to persist / restore the recording.
    RESULT_FILE = 'result.json'
    # Endpoint that receives one JSON document per recorded step.
    RECORD_URL = 'http://localhost:8000/record'

    def __init__(self, loss_set, output_dir, saver, ext=None, axis='iteration'):
        """
        :param loss_set: iterable of entries whose first item is the loss name
        :param output_dir: directory for saved outputs (created if missing)
        :param saver: callable invoked as ``saver(output, output_path)``
        :param ext: optional suffix appended to every output file name
        :param axis: ``'iteration'`` to index records by iteration number,
            anything else to index them by seconds elapsed since start
        """
        self.start_time = time.time()
        self.last_time = self.start_time
        self.loss_set = loss_set
        # 'total' always comes first, followed by the name of every entry.
        # Built with a comprehension so that an empty loss_set yields
        # ['total'] instead of raising IndexError.
        self.loss_label = ['total'] + [entry[0] for entry in loss_set]
        self.losses = {label: [] for label in self.loss_label}
        self.outputs = []
        self.output_dir = output_dir
        if not path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        self.ext = ext
        self.saver = saver
        self.axis = axis

    def get_head(self):
        """
        Get the index which we should start from.

        Reads ``RESULT_FILE``; when it holds at least one output, the stored
        outputs and losses replace the in-memory ones.

        NOTE(review): the returned value is the axis value of the last
        stored output plus one, so it is only an iteration index when the
        file was written with ``axis == 'iteration'``.

        :return: 1 when no output is stored, otherwise last axis value + 1
        :raises IOError: if ``RESULT_FILE`` cannot be opened
        """
        with open(self.RESULT_FILE, 'r') as f:
            state = json.load(f)
        if not state['output']:
            return 1
        self.outputs = state['output']
        self.losses = state['loss']
        last_axis = state['output'][-1][0]
        return last_axis + 1

    def reset(self):
        """
        Reset start_time to now (``last_time`` is left untouched).

        :return:
        """
        self.start_time = time.time()

    def record(self, iteration, losses, output):
        """
        Record one step: store losses, save the output, notify the server.

        :param iteration: current iteration
        :param losses: current loss values, in the order of ``loss_label``
        :param output: current output, handed to ``saver``
        :return:
        """
        now = time.time()
        print('Iteration %d completed in %fs' % (iteration, now - self.last_time))
        elapsed = now - self.start_time
        self.last_time = now
        axis = iteration if self.axis == 'iteration' else elapsed

        # Materialise once: `losses` is needed both for the history and for
        # the payload, and a one-shot iterable would be exhausted after the
        # first pass.
        loss_values = [float(value) for value in losses]

        # Append loss
        for label, value in zip(self.loss_label, loss_values):
            self.losses[label].append([axis, value])
            print('%s loss value: %e' % (label, value))

        # Append output
        output_path = self.get_name(iteration)
        self.saver(output, output_path)
        self.outputs.append([axis, output_path])
        print("results saved at %s \n" % output_path)

        payload = {
            'losses': loss_values,
            'loss_label': self.loss_label,
            'output': [axis, output_path],
            'iter': iteration,
            # Full history so far, as a JSON string (see `export`).
            'cache': self.export()
        }
        # NOTE(review): no timeout is set and request errors propagate to
        # the caller; an unreachable or stalled server interrupts training.
        requests.post(self.RECORD_URL, json=payload)

    def get_name(self, idx):
        """
        Get the path to file according to the index.

        :param idx: current index
        :return: ``output_dir/<idx>`` followed by ``ext`` when one is set
        """
        suffix = self.ext if self.ext else ''
        return path.join(self.output_dir, str(idx)) + suffix

    def export(self):
        """
        Export recording as json string.

        :return: JSON text with the keys ``loss`` and ``output``
        """
        return json.dumps({'loss': self.losses, 'output': self.outputs})

    def export_file(self):
        """
        Export recording to ``RESULT_FILE`` (``result.json`` in the current
        working directory), overwriting any previous content.

        :return: None
        """
        with open(self.RESULT_FILE, 'w') as f:
            return json.dump({'loss': self.losses, 'output': self.outputs}, f)