Skip to content

Commit 4b834cd

Browse files
committed
Comments and more comments.
1 parent 82ae661 commit 4b834cd

File tree

8 files changed

+126
-22
lines changed

8 files changed

+126
-22
lines changed

.gitignore

-4
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,3 @@
3737
ehthumbs.db
3838
Thumbs.db
3939
*.*~
40-
41-
# Data folders #
42-
################
43-
air/models

air/api.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,33 @@
88
LOSS_LENGTH = 0
99

1010
def make_register():
11+
""" Creates a dict registry object to hold all annotated definitions.
12+
Returns: A decorator that records each decorated function in a registry dict (keyed by function name), exposed as its `all` attribute.
13+
"""
1114
registry = {}
1215
def registrar(func):
1316
registry[func.__name__] = func
1417
return func
1518
registrar.all = registry
1619
return registrar
20+
21+
# Decorator that registers endpoint definitions in this file; the registry dict is available as endpoint.all.
1722
endpoint = make_register()
1823

1924
def handleNaN(val):
25+
""" Turns all NaN values to 0.
26+
Returns: 0 if val is NaN (as reported by math.isnan), otherwise the value unchanged.
27+
"""
2028
if math.isnan(val):
2129
return 0
2230
return val
2331

2432
@endpoint
2533
def infer_types(args, files):
34+
""" Given a model handle, returns a dictionary of column-name to type JSON.
35+
Args: Model handle
36+
Returns: a JSON holding the column-name to type map.
37+
"""
2638
try:
2739
model = get_model(args['handle'])
2840
except Exception as e:
@@ -31,6 +43,10 @@ def infer_types(args, files):
3143

3244
@endpoint
3345
def infer(args, files):
46+
""" Given a model handle and input values, this def runs the model inference graph and returns the predictions.
47+
Args: Model handle, input values.
48+
Returns: A JSON containing all the model predictions.
49+
"""
3450
#clear_session() # Clears TF graphs.
3551
try:
3652
model = get_model(args['handle'])
@@ -45,6 +61,10 @@ def infer(args, files):
4561

4662
@endpoint
4763
def train(args, files):
64+
""" Runs the training for the given model.
65+
Args: Model handle.
66+
Returns: A JSON confirming that training has been kicked-off.
67+
"""
4868
clear_session() # Clears TF graphs.
4969
clear_thread_cache() # We need to clear keras models since graph is deleted.
5070
try:
@@ -56,6 +76,10 @@ def train(args, files):
5676

5777
@endpoint
5878
def train_status(args, files):
79+
""" Grabs the metrics from disk and returns them for the given model handle.
80+
Args: Model handle.
81+
Returns: A JSON with a dictionary of keras_model_name -> metric_name -> list(metric values)
82+
"""
5983
try:
6084
model = get_model(args['handle'])
6185
except:
@@ -73,6 +97,14 @@ def train_status(args, files):
7397

7498
@endpoint
7599
def upload_csv(args, files):
100+
""" Takes in a csv and creates a Model around it.
101+
CSV needs to have a feature per column. Also needs to have at least one column marked as output by prepending
102+
'output_' to the column name (first row in file). Types will be conservatively inferred from the input (i.e. the type will be
103+
string as long as one cell contains a non-numeric character).
104+
105+
Files: Path to tmp CSV file on server (handled by the framework).
106+
Returns: A JSON with the model handle just created, and the inferred feature types.
107+
"""
76108
if 'upload' not in files:
77109
print 'Files not specified in upload: ' + files
78110
return 'No file specified'
@@ -94,8 +126,13 @@ def upload_csv(args, files):
94126

95127
return json.dumps({'status': 'OK', 'handle': model.get_handle(), 'types': model.types})
96128

97-
# Calls endpoint with a map of arguments.
98129
def resolve_endpoint(endpoint_str, args, files):
130+
""" Reroutes the request to the matching endpoint definition.
131+
See make_register for more information.
132+
Params: arguments and files needed to run the endpoint (every endpoint receives both dictionaries). Also receives the
133+
name of the endpoint.
134+
Returns: The output of the endpoint.
135+
"""
99136
if endpoint_str in endpoint.all:
100137
return endpoint.all[endpoint_str](args, files)
101138
else:

air/config.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
22

33
class Config():
4-
mc = {}
5-
ROOT_PATH = os.path.dirname(os.path.realpath(__file__))
6-
STATIC_PATH = os.path.dirname(os.path.realpath(__file__)) + '/static'
4+
""" Static global class holding important parameters.
5+
"""
6+
mc = {} # Memcached object shared across threads.
7+
ROOT_PATH = os.path.dirname(os.path.realpath(__file__)) # Path to the project's directory.
8+
STATIC_PATH = os.path.dirname(os.path.realpath(__file__)) + '/static' # Path to static resources.
79

810
def set_mc(self, mc):
911
self.mc = mc
@@ -12,4 +14,4 @@ def get_mc(self):
1214
return self.mc
1315

1416
global config
15-
config = Config()
17+
config = Config() # Singleton class object to be used across the project.

air/db.py

+44-9
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,23 @@
1212

1313
MODEL_PATH = '/models'
1414
HANDLE_LENGTH = 10
15-
keras_cache = {}
15+
keras_cache = {} # Local thread memory cache.
1616

1717
def clear_thread_cache():
1818
keras_cache = {}
1919

2020
def handle2path(handle):
21+
""" Translates a handle to a full path on the FS. """
2122
return config.ROOT_PATH + MODEL_PATH + "/" + handle
2223

2324
def path2handle(path):
25+
""" Translates a full path to a handle. """
2426
return path.split('/')[-1]
2527

2628
def new_model():
29+
""" Constructs an empty Model, assigning it a new random hex ID and persisting it to disk.
30+
Returns: The Model instance.
31+
"""
2732
filename = random_hex()
2833
while os.path.isfile(filename):
2934
filename = random_hex()
@@ -38,24 +43,34 @@ def new_model():
3843
return model
3944

4045
def save_model(model):
46+
""" Saves the given model to disk. """
4147
with open(model.model_path, 'w+') as f:
4248
f.write(model.to_json())
4349
config.get_mc().set(path2handle(model.model_path), model.to_json())
4450

4551
def get_model(handle):
52+
""" Fetches the model from memory or disk with a matching handle.
53+
Returns: The Model instance if the model is found, None otherwise.
54+
"""
4655
mem_try = config.get_mc().get(handle)
4756
if mem_try:
4857
m = Model()
4958
m.from_json(mem_try)
5059
return m
5160
model_path = config.ROOT_PATH + MODEL_PATH + "/" + handle
52-
with open(model_path, "r") as f:
53-
model = Model()
54-
model.from_json(f.read())
55-
config.get_mc().set(handle, model.to_json())
56-
return model
61+
try:
62+
with open(model_path, "r") as f:
63+
model = Model()
64+
model.from_json(f.read())
65+
config.get_mc().set(handle, model.to_json())
66+
return model
67+
except:
68+
return None
5769

5870
def parse_val(value):
71+
""" Infers the type of the value by trying to parse it to different formats.
72+
Returns: The parsed value and its type.
73+
"""
5974
if not value:
6075
return value, None
6176
tests = [
@@ -83,6 +98,8 @@ def parse_val(value):
8398
return value.decode('utf-8', 'ignore'), 'str'
8499

85100
def persist_keras_model(handle, model):
101+
""" Persists a keras model to disk.
102+
"""
86103
model_dir = config.ROOT_PATH + MODEL_PATH
87104

88105
# Clear first all previously persisted models.
@@ -94,6 +111,9 @@ def persist_keras_model(handle, model):
94111
model.save(os.path.join(model_dir, name))
95112

96113
def _load_keras_model(handle):
114+
""" Loads a keras model from disk.
115+
Returns: The keras model instance if found, None otherwise.
116+
"""
97117
name = handle + '_keras'
98118
print 'load ' + name
99119
model_dir = config.ROOT_PATH + MODEL_PATH
@@ -106,6 +126,9 @@ def _load_keras_model(handle):
106126
return model
107127

108128
def load_keras_model(handle):
129+
""" Loads a keras model from cache or disk.
130+
Returns: The keras model instance.
131+
"""
109132
if handle in keras_cache:
110133
print 'From thread cache'
111134
return keras_cache[handle]
@@ -114,12 +137,20 @@ def load_keras_model(handle):
114137
return model
115138

116139
def delete_model(handle):
140+
""" Deletes all models with the given handle if found. """
117141
model_dir = config.ROOT_PATH + MODEL_PATH
118142
for f in os.listdir(model_dir):
119143
if re.search(handle + ".*", f):
120144
os.remove(os.path.join(model_dir, f))
145+
config.get_mc().delete(handle)
146+
121147

122148
def load_csvs(file_list):
149+
""" Loads csv from files and returns the parsed value dictionary.
150+
Params: The list of files.
151+
Returns: Three dictionaries. The first is feature-name -> value_list, the second one feature_name -> type and the
152+
third one feature_name -> (min_value, max_value) for non-string features.
153+
"""
123154
print 'File of csvs to load ' + unicode(file_list)
124155
data = {}
125156
types = {}
@@ -129,7 +160,7 @@ def load_csvs(file_list):
129160
reader = csv.reader(read_f)
130161
headers = []
131162
for row in reader:
132-
if not headers:
163+
if not headers: # If first row, load the headers assuming they are contained in the first row.
133164
headers = row
134165
output_headers = 0
135166
for h in headers:
@@ -138,7 +169,7 @@ def load_csvs(file_list):
138169
output_headers += 1 if h.startswith('output_') else 0
139170
if not output_headers:
140171
return 'No outputs defined in CSV. Please define columns as outputs by preppending \'output_\'.', ''
141-
else:
172+
else: # If not first row, parse values assuming the headers dictionary has been already filled.
142173
for idx, value in enumerate(row):
143174
val, typ = parse_val(value)
144175
data[headers[idx]].append(val)
@@ -147,22 +178,25 @@ def load_csvs(file_list):
147178
types[headers[idx]] = typ
148179
else:
149180
print 'WARN: CSV %s not found' % f
181+
150182
# Fix '' values, and standardize formats.
151183
for header, column in data.iteritems():
152184
for idx, value in enumerate(column):
153185
if not value:
154186
data[header][idx] = 0 if types[header] != 'str' else ''
155187
else:
156188
data[header][idx] = unicode(data[header][idx]) if types[header] == 'str' else data[header][idx]
189+
157190
# Normalize numeric inputs to the 0 to 1 range (min-max scaling).
158191
norms = {}
159-
160192
for header, column in data.iteritems():
161193
if types[header] != 'str':
162194
floor = float(min(column))
163195
ceil = float(max(column))
164196
norms[header] = (floor, ceil)
165197
data[header] = [(x-floor)/(ceil - floor) for x in column]
198+
199+
# Run a final verification that all features have the same number of rows.
166200
length = 0
167201
for header, column in data.iteritems():
168202
if not length:
@@ -173,6 +207,7 @@ def load_csvs(file_list):
173207
return data, types, norms
174208

175209
def random_hex():
210+
""" Creates a random hex string ID """
176211
ran = random.randrange(16**HANDLE_LENGTH)
177212
return "%010x" % ran
178213

air/keras_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import tensorflow as tf
55

66
def activate(activation, tensor):
7+
""" Maps a string activation to the keras backend function """
78
if activation == 'tanh':
89
return K.tanh(tensor)
910
elif activation == 'sigmoid':

0 commit comments

Comments
 (0)