Skip to content

Commit 2162102

Browse files
committed
Initial
0 parents  commit 2162102

File tree

8 files changed

+271
-0
lines changed

8 files changed

+271
-0
lines changed

.gitignore

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
.cache
2+
.idea
3+
.python-version
4+
*.egg-info
5+
*.pyc

LICENSE

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Copyright 2017 Andrej Ocenas
2+
3+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4+
5+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6+
7+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

README.md

Whitespace-only changes.

docker-compose.yml

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
version: "2"
2+
services:
3+
elastic:
4+
image: elasticsearch
5+
ports:
6+
- 9200:9200
7+
- 9300:9300
8+
9+
kibana:
10+
image: kibana
11+
ports:
12+
- 5601:5601
13+
links:
14+
- elastic:elasticsearch

keras_elastic_callback/__init__.py

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
from __future__ import print_function
2+
from __future__ import division
3+
from __future__ import unicode_literals
4+
5+
from datetime import datetime
6+
from keras.callbacks import Callback
7+
from elasticsearch import Elasticsearch
8+
from elasticsearch.helpers import bulk
9+
import numpy as np
10+
11+
12+
class ElasticCallback(Callback):
13+
"""
14+
Sends all data to elasticsearch.
15+
"""
16+
def __init__(
17+
self,
18+
run_name,
19+
index_name='keras',
20+
event_data=None,
21+
url=None,
22+
es_client=None,
23+
buffer_length=1,
24+
):
25+
"""
26+
:param str run_name: Name of the run that can be searched for.
27+
:param str index_name: Name of index in elasticsearch. Must be lowercase
28+
and cannot contain special characters.
29+
:param Dict event_data: Additional data that will be merged each each
30+
event.
31+
:param str url: Elasticsearch url <host>[:<port>]
32+
:param object es_client: Elasticsearch client
33+
:param int buffer_length: Length of event buffer, before events are sent
34+
to Elasticsearch. If length is 1, event is sent after and before
35+
each batch, which can slow down learning. If buffer_length is 0, events
36+
are flushed only on epoch end. If buffer_length is > 1, number of events
37+
will be buffered before flushed, but still everything is flushed on
38+
epoch end.
39+
"""
40+
super(ElasticCallback, self).__init__()
41+
if es_client:
42+
self.es = es_client
43+
else:
44+
self.es = Elasticsearch([url])
45+
self.index = index_name
46+
self.run_name = run_name
47+
self.data = event_data or {}
48+
self.batch_start_time = None
49+
self.epoch_start_time = None
50+
self.event_buffer = []
51+
self.buffer_length = buffer_length
52+
53+
def on_train_begin(self, logs={}):
54+
self._index(
55+
'train_begin',
56+
logs,
57+
)
58+
59+
def on_epoch_begin(self, epoch, logs={}):
60+
self.epoch_start_time = datetime.now()
61+
self._index(
62+
'epoch_begin',
63+
logs,
64+
epoch=epoch,
65+
)
66+
67+
def on_batch_begin(self, batch, logs={}):
68+
self.batch_start_time = datetime.now()
69+
self._add_to_queue(
70+
'batch_begin',
71+
logs,
72+
)
73+
74+
def on_batch_end(self, batch, logs={}):
75+
self._add_to_queue(
76+
'batch_end',
77+
logs,
78+
duration=(datetime.now() - self.batch_start_time).seconds,
79+
)
80+
81+
def on_epoch_end(self, epoch, logs={}):
82+
self._flush_queue()
83+
self._index(
84+
'epoch_end',
85+
logs,
86+
duration=(datetime.now() - self.epoch_start_time).seconds,
87+
epoch=epoch,
88+
)
89+
90+
def on_train_end(self, logs={}):
91+
self._index('train_end', logs)
92+
93+
def _add_to_queue(self, doc_type, logs, **kw):
94+
if self.buffer_length == 1:
95+
self._index(
96+
doc_type,
97+
logs,
98+
**kw
99+
)
100+
else:
101+
self.event_buffer.append((
102+
doc_type,
103+
self._create_event_body(doc_type, logs, **kw)
104+
))
105+
106+
if len(self.event_buffer) == self.buffer_length:
107+
self._flush_queue()
108+
109+
def _flush_queue(self):
110+
if len(self.event_buffer):
111+
bulk(self.es, self._map_actions(self.event_buffer))
112+
self.event_buffer = []
113+
114+
def _map_actions(self, events):
115+
def mapper(event):
116+
return {
117+
'_op_type': 'index',
118+
'_index': self.index,
119+
'_type': event[0],
120+
'_source': event[1],
121+
}
122+
return map(mapper, events)
123+
124+
def _index(self, doc_type, logs, **kw):
125+
self.es.index(
126+
index=self.index,
127+
doc_type=doc_type,
128+
body=self._create_event_body(doc_type, logs, **kw),
129+
)
130+
131+
def _create_event_body(self, doc_type, logs, **kw):
132+
body = dict(
133+
timestamp=datetime.utcnow().isoformat(),
134+
event=doc_type,
135+
run_name=self.run_name,
136+
**self._convert_np_arrays(logs)
137+
)
138+
body.update(kw)
139+
body.update(self.data)
140+
return body
141+
142+
@staticmethod
143+
def _convert_np_arrays(data):
144+
"""
145+
Convert numpy ndarrays in a dictionary to list, so it can be serialized
146+
to JSON.
147+
"""
148+
return {
149+
k: v.tolist() if type(v) == np.ndarray else v
150+
for k, v
151+
in data.items()
152+
}

setup.cfg

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[metadata]
2+
description-file = README.md

setup.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from setuptools import setup
2+
3+
setup(
4+
name='keras_elastic_callback',
5+
version='0.1',
6+
description='Keras callbacks that sends all events and data to elasticsearch',
7+
url='http://github.com/aocenas/keras_elastic_callback',
8+
author='Andrej Ocenas',
9+
author_email='[email protected]',
10+
license='MIT',
11+
packages=['keras_elastic_callback'],
12+
install_requires=[
13+
'keras',
14+
'numpy',
15+
'elasticsearch',
16+
],
17+
tests_require=[
18+
'pytest',
19+
'mock',
20+
],
21+
keywords=['keras', 'elasticsearch', 'machine learning'],
22+
zip_safe=False
23+
)

tests/test_callback.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import sys
2+
print sys.path
3+
from keras_elastic_callback import ElasticCallback
4+
from elasticsearch import Elasticsearch
5+
from mock import NonCallableMock
6+
7+
8+
def test_data():
9+
es_mock = NonCallableMock(spec=Elasticsearch)
10+
callback = ElasticCallback(
11+
'test_run',
12+
'test_index',
13+
event_data={'data_key': 'data_value'},
14+
es_client=es_mock,
15+
)
16+
17+
callback.on_epoch_begin(1)
18+
es_mock.index.assert_called_once()
19+
args = es_mock.index.call_args[1]
20+
assert args['index'] == 'test_index'
21+
assert args['doc_type'] == 'epoch_begin'
22+
assert args['body']['event'] == 'epoch_begin'
23+
assert args['body']['run_name'] == 'test_run'
24+
assert args['body']['data_key'] == 'data_value'
25+
assert args['body']['epoch'] == 1
26+
27+
28+
events_names = [
29+
'train_begin',
30+
'epoch_begin',
31+
'batch_begin',
32+
'batch_end',
33+
'epoch_end',
34+
'train_end',
35+
]
36+
37+
38+
def test_all_events():
39+
es_mock = NonCallableMock(spec=Elasticsearch)
40+
callback = ElasticCallback(
41+
'test_run',
42+
'test_index',
43+
es_client=es_mock,
44+
)
45+
46+
for event in events_names:
47+
48+
func = getattr(callback, 'on_' + event)
49+
if 'batch' in event or 'epoch' in event:
50+
func(1)
51+
else:
52+
func()
53+
args = es_mock.index.call_args[1]
54+
assert args['doc_type'] == event
55+
56+
57+
def test_buffer():
58+
es_mock = NonCallableMock(spec=Elasticsearch)
59+
callback = ElasticCallback(
60+
'test_run',
61+
'test_index',
62+
es_client=es_mock,
63+
buffer_length=10
64+
)
65+
66+
callback.on_batch_begin(1)
67+
callback.on_batch_end(1)
68+
es_mock.index.assert_not_called()

0 commit comments

Comments
 (0)