-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
47 lines (42 loc) · 2.17 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import utils
import implicit
import pandas as pd
from pathlib import Path
from config import config
from model import LightGCN, TopNModel, TopNPersonalized, TopNNearestModel
from dataloader import GowallaLightGCNDataset, GowallaTopNDataset, GowallaALSDataset
# Model names handled by the shared Top-N train/eval path.
_TOP_N_MODELS = ('TopNModel', 'TopNPersonalized', 'TopNNearestModel')
# Every value of config['MODEL'] this script knows how to run.
_SUPPORTED_MODELS = ('LightGCN',) + _TOP_N_MODELS + ('iALS',)


def _run_lightgcn(dataset_path):
    """Build LightGCN on the `_custom` train split and fit it for TRAIN_EPOCHS.

    The `_custom` test split is passed to `fit` alongside the epoch count
    (presumably for evaluation during training -- confirm in model.LightGCN).
    """
    train_dataset = GowallaLightGCNDataset(f'{dataset_path}_custom.train')
    test_dataset = GowallaLightGCNDataset(f'{dataset_path}_custom.test', train=False)
    model = LightGCN(train_dataset)
    model.fit(config['TRAIN_EPOCHS'], test_dataset)


def _run_top_n(dataset_path, model_name):
    """Load the Top-N train/test splits, build `model_name`, then fit and evaluate it.

    `model_name` must be one of `_TOP_N_MODELS`.
    """
    train_dataset = GowallaTopNDataset(f'{dataset_path}.train')
    test_dataset = GowallaTopNDataset(f'{dataset_path}.test', train=False)
    if model_name == 'TopNNearestModel':
        # The nearest-neighbour helper is computed over train + test rows combined.
        combined_df = pd.concat([train_dataset.df, test_dataset.df])
        model = TopNNearestModel(config['TOP_N'], utils.calc_nearest(combined_df))
    elif model_name == 'TopNPersonalized':
        model = TopNPersonalized(config['TOP_N'])
    else:
        model = TopNModel(config['TOP_N'])
    model.fit(train_dataset)
    model.eval(test_dataset)


def _run_ials(dataset_path):
    """Train implicit's AlternatingLeastSquares with a callback built by utils.eval_als_model."""
    # The first element returned by the train split (the raw training data) is not needed here.
    _, user_item_data, item_user_data = GowallaALSDataset(
        f'{dataset_path}.train').get_dataset()
    gowalla_test = GowallaALSDataset(f'{dataset_path}.test', train=False).get_dataset()
    model = implicit.als.AlternatingLeastSquares(
        iterations=config['ALS_N_ITERATIONS'], factors=config['ALS_N_FACTORS'])
    # NOTE(review): the `fit_callback` attribute and passing an item-user matrix to
    # `fit` follow older `implicit` releases; newer releases take a user-item
    # matrix and a `callback=` argument to `fit`. Confirm against the pinned version.
    model.fit_callback = utils.eval_als_model(model, user_item_data, gowalla_test)
    model.fit(item_user_data)


def main():
    """Train and evaluate the model selected by config['MODEL'] on config['DATASET'].

    Raises:
        ValueError: if config['MODEL'] is not one of `_SUPPORTED_MODELS`.
    """
    dataset_path = Path('dataset') / config['DATASET'] / config['DATASET']
    model_name = config['MODEL']
    if model_name == 'LightGCN':
        _run_lightgcn(dataset_path)
    elif model_name in _TOP_N_MODELS:
        _run_top_n(dataset_path, model_name)
    elif model_name == 'iALS':
        _run_ials(dataset_path)
    else:
        # Previously an unknown name silently did nothing; fail loudly instead.
        raise ValueError(
            f"Unknown config['MODEL'] {model_name!r}; "
            f"expected one of {', '.join(_SUPPORTED_MODELS)}")


if __name__ == '__main__':
    main()