-
Notifications
You must be signed in to change notification settings - Fork 51
/
Copy pathtrainer.py
58 lines (44 loc) · 1.67 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import pdb
import json
import pandas as pd
import pickle
import argparse
from pathlib import Path
from kafka import KafkaConsumer
from utils.messages_utils import publish_traininig_completed
from utils.preprocess_data import build_train
# Kafka connection settings: start() subscribes to RETRAIN_TOPIC using
# KAFKA_HOST as the bootstrap server.
KAFKA_HOST = 'localhost:9092'
RETRAIN_TOPIC = 'retrain_topic'

# Root of the on-disk layout; every path below is derived from it.
PATH = Path('data/')
# Training CSV handed to build_train().
TRAIN_DATA = PATH.joinpath('train/train.csv')
# Directory handed to build_train() alongside the training CSV.
DATAPROCESSORS_PATH = PATH.joinpath('dataprocessors')
# Directory handed to LGBOptimizer.
MODELS_PATH = PATH.joinpath('models')
# Directory in which start() looks up the per-batch 'messages_<batch_id>_.txt' file.
MESSAGES_PATH = PATH.joinpath('messages')
def train(model_id, messages, hyper, maxevals=2):
    """Rebuild the training set and run the selected LightGBM optimizer.

    The training data is produced by ``build_train`` from ``TRAIN_DATA``,
    ``DATAPROCESSORS_PATH``, ``model_id`` and ``messages``; the result is
    handed to the ``LGBOptimizer`` implementation chosen by ``hyper``
    together with ``MODELS_PATH``.

    Args:
        model_id: Identifier forwarded to ``build_train`` and to
            ``LGBOptimizer.optimize``.
        messages: Path of the batch messages file, forwarded to
            ``build_train``.
        hyper: Optimizer backend, either ``"hyperopt"`` or
            ``"hyperparameterhunter"``.
        maxevals: Value forwarded to ``LGBOptimizer.optimize``. Defaults to
            2, the value that used to be hard-coded here.

    Raises:
        ValueError: If ``hyper`` is not one of the supported backends.
    """
    supported = ("hyperopt", "hyperparameterhunter")
    # Validate before doing any work: previously an unknown backend only
    # failed (with an unbound LGBOptimizer) after build_train had already run.
    if hyper not in supported:
        raise ValueError(
            "Unsupported hyper backend {!r}; expected one of {}".format(
                hyper, ", ".join(supported)
            )
        )

    print("RETRAINING STARTED (model id: {})".format(model_id))
    dtrain = build_train(TRAIN_DATA, DATAPROCESSORS_PATH, model_id, messages)

    # The optimizer modules are imported lazily so only the selected backend
    # has to be importable. Earlier revisions referenced the non-MLflow
    # modules train.train_hyperopt / train.train_hyperparameterhunter here.
    if hyper == "hyperopt":
        from train.train_hyperopt_mlflow import LGBOptimizer
    else:
        # NOTE(review): module name is spelled "mlfow" — confirm it matches
        # the file name on disk before "correcting" it.
        from train.train_hyperparameterhunter_mlfow import LGBOptimizer

    LGBOpt = LGBOptimizer(dtrain, MODELS_PATH)
    LGBOpt.optimize(maxevals=maxevals, model_id=model_id)
    print("RETRAINING COMPLETED (model id: {})".format(model_id))
def start(hyper):
    """Consume retrain requests from Kafka and retrain on each valid one.

    Subscribes to ``RETRAIN_TOPIC`` on ``KAFKA_HOST`` and iterates over the
    consumer. For every message whose JSON payload is an object with a truthy
    ``'retrain'`` entry, the batch messages file under ``MESSAGES_PATH`` is
    resolved from ``'batch_id'``, ``train`` is called, and
    ``publish_traininig_completed`` is invoked with the model id.

    Messages that cannot be decoded as JSON, are not JSON objects, or lack
    ``'model_id'`` / ``'batch_id'`` are reported and skipped so that a single
    bad message does not terminate the consumer loop. Errors raised by
    ``train`` itself are not caught.

    Args:
        hyper: Optimizer backend name forwarded unchanged to ``train``.
    """
    consumer = KafkaConsumer(RETRAIN_TOPIC, bootstrap_servers=KAFKA_HOST)
    for msg in consumer:
        try:
            message = json.loads(msg.value)
        except (TypeError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # TypeError covers a payload that is not str/bytes (e.g. None).
            print("SKIPPING UNDECODABLE MESSAGE: {}".format(err))
            continue

        # Only JSON objects carrying a truthy 'retrain' flag are requests.
        if not isinstance(message, dict) or not message.get('retrain'):
            continue

        try:
            model_id = message['model_id']
            batch_id = message['batch_id']
        except KeyError as err:
            print("SKIPPING RETRAIN REQUEST, MISSING KEY: {}".format(err))
            continue

        message_fname = 'messages_{}_.txt'.format(batch_id)
        messages = MESSAGES_PATH/message_fname
        train(model_id, messages, hyper)
        publish_traininig_completed(model_id)
def main(argv=None):
    """Parse command-line options and start the retraining consumer.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``, which is what the script entry point uses.
    """
    parser = argparse.ArgumentParser()
    # Restrict --hyper to the backends train() can import, so a typo is
    # rejected at start-up rather than when the first retrain request arrives.
    parser.add_argument(
        "--hyper",
        type=str,
        default="hyperopt",
        choices=["hyperopt", "hyperparameterhunter"],
        help="hyper-parameter optimisation backend (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    start(args.hyper)


if __name__ == '__main__':
    main()