-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.py
120 lines (105 loc) · 3.95 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
@author: Gaetan Hadjeres
"""
import importlib
import os
import shutil
from datetime import datetime
import click
import torch
from transformer_bach.bach_dataloader import BachDataloaderGenerator
from transformer_bach.decoder_relative import TransformerBach
from transformer_bach.getters import get_data_processor
from transformer_bach.melodies import MARIO_MELODY, TETRIS_MELODY, LONG_TETRIS_MELODY
@click.command()
@click.option('-t', '--train', is_flag=True)
@click.option('-l', '--load', is_flag=True)
@click.option('-o', '--overfitted', is_flag=True)
@click.option('-c', '--config', type=click.Path(exists=True))
@click.option('-n', '--num_workers', type=int, default=0)
def main(train,
         load,
         overfitted,
         config,
         num_workers
         ):
    """Train and/or sample from a relative Transformer on Bach chorales.

    Flags:
        train: run training (copies the config into the model dir first).
        load: restore a previously trained model from the config's directory.
        overfitted: when loading, take the last checkpoint instead of the
            early-stopped one.
        config: path to a .py config file exposing a ``config`` dict.
        num_workers: dataloader worker count forwarded to training.
    """
    # Report all visible GPUs (informational only; device placement below
    # unconditionally uses 'cuda').
    gpu_ids = list(range(torch.cuda.device_count()))
    print(f'Using GPUs {gpu_ids}')

    # Load config: the .py file path is turned into a module path and the
    # module's `config` dict is imported.
    # NOTE(review): the '/'-to-'.' rewrite assumes POSIX-style paths.
    config_path = config
    config_module_name = os.path.splitext(config)[0].replace('/', '.')
    config = importlib.import_module(config_module_name).config

    # Compute time stamp: reuse the one pinned in the config if present,
    # otherwise stamp with the current time. .get() also tolerates configs
    # that omit the 'timestamp' key entirely (the old config['timestamp']
    # raised KeyError in that case).
    timestamp = config.get('timestamp')
    if timestamp is None:
        timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        config['timestamp'] = timestamp

    # When loading, the model dir is wherever the config file lives;
    # otherwise a fresh timestamped directory is created under models/.
    if load:
        model_dir = os.path.dirname(config_path)
    else:
        model_dir = f'models/{config["savename"]}_{timestamp}'

    # === Decoder ====
    dataloader_generator_kwargs = config['dataloader_generator_kwargs']
    dataloader_generator = BachDataloaderGenerator(
        sequences_size=dataloader_generator_kwargs['sequences_size']
    )

    data_processor = get_data_processor(
        dataloader_generator=dataloader_generator,
        data_processor_type=config['data_processor_type'],
        data_processor_kwargs=config['data_processor_kwargs']
    )

    decoder_kwargs = config['decoder_kwargs']
    # Chorales are 4-voice; events are grouped per chord (4 voices/event).
    num_channels = 4
    num_events_grouped = 4
    num_events = dataloader_generator_kwargs['sequences_size'] * 4
    transformer = TransformerBach(
        model_dir=model_dir,
        dataloader_generator=dataloader_generator,
        data_processor=data_processor,
        d_model=decoder_kwargs['d_model'],
        num_encoder_layers=decoder_kwargs['num_encoder_layers'],
        num_decoder_layers=decoder_kwargs['num_decoder_layers'],
        n_head=decoder_kwargs['n_head'],
        dim_feedforward=decoder_kwargs['dim_feedforward'],
        dropout=decoder_kwargs['dropout'],
        positional_embedding_size=decoder_kwargs['positional_embedding_size'],
        num_channels=num_channels,
        num_events=num_events,
        num_events_grouped=num_events_grouped
    )

    if load:
        # early_stopped=True restores the best (early-stopped) checkpoint;
        # --overfitted asks for the final weights instead.
        if overfitted:
            transformer.load(early_stopped=False)
        else:
            transformer.load(early_stopped=True)
        transformer.to('cuda')

    if train:
        # Copy .py config file in the save directory before training
        if not load:
            # exist_ok avoids the check-then-create race of the previous
            # os.path.exists guard.
            os.makedirs(model_dir, exist_ok=True)
            shutil.copy(config_path, f'{model_dir}/config.py')
        transformer.to('cuda')
        transformer.train_model(
            batch_size=config['batch_size'],
            num_batches=config['num_batches'],
            num_epochs=config['num_epochs'],
            lr=config['lr'],
            plot=True,
            num_workers=num_workers
        )

    # Generate a few chorales constrained to a fixed melody.
    melody_constraint = TETRIS_MELODY
    # melody_constraint = LONG_TETRIS_MELODY
    # melody_constraint = None
    scores = transformer.generate(temperature=0.9,
                                  top_p=0.7,
                                  batch_size=3,
                                  melody_constraint=melody_constraint,
                                  hard_constraint=True,
                                  show_debug_symbols=False,
                                  exclude_non_note_symbols=True
                                  )
# Script entry point: invoke the click-decorated CLI command.
if __name__ == '__main__':
    main()