forked from nyu-dl/conditional-molecular-design-ssvae
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
86 lines (70 loc) · 2.28 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from __future__ import print_function
import os
import tensorflow as tf
import numpy as np
import pandas as pds
from preprocessing import ZINC
import SSVAE
# FLAGS
# Module-wide handle through which every flag value below is read.
FLAGS = tf.app.flags.FLAGS
# Experiment name
tf.app.flags.DEFINE_string('output_dir', './output', 'output folder.')
tf.app.flags.DEFINE_string('experiment_name', 'dbg300',
                           'All outputs of this experiment is'
                           ' saved under a folder with the same name.')
tf.app.flags.DEFINE_bool('debug', True, 'debug mode.')

# Folder for this run: <output_dir>/<experiment_name>.
experiment_dir = os.path.join(FLAGS.output_dir, FLAGS.experiment_name)
# MakeDirs also creates any missing intermediate directories, so a nested
# output_dir (or an experiment_name containing a path separator) works;
# MkDir fails when the parent directory does not exist yet.
if not tf.gfile.IsDirectory(experiment_dir):
    tf.gfile.MakeDirs(experiment_dir)
# pre-defined parameters
# frac and beta are forwarded unchanged to model.trainXY and SSVAE.Model.
frac = 0.5
beta = 10000.
data_uri = './data/ZINC_310k.csv'
# Checkpoint path handed to SSVAE.Model as save_uri.
save_uri = os.path.join(experiment_dir, 'models', 'model.ckpt')

# Follow the --debug flag instead of a hard-coded True, so that
# --debug=false actually selects the full-size configuration. The flag
# defaults to True, which keeps the previous default behaviour.
debug = FLAGS.debug
if debug:
    # Small configuration for quick runs.
    ntrn = 300
    ntst = 100
    frac_val = 0.1
    dim_z = 10
    dim_h = 25
    n_hidden = 2
    batch_size = 10
    max_epoch = 10
else:
    # Full-size configuration.
    ntrn = 300000
    ntst = 10000
    frac_val = 0.05
    dim_z = 100
    dim_h = 250
    n_hidden = 3
    batch_size = 200
    max_epoch = 300
# data preparation
print('::: data preparation')
# Parse the CSV once and take both slices from the same array; only the first
# ntrn + ntst rows are used. `.values` replaces DataFrame.as_matrix(), which
# was deprecated in pandas 0.23 and removed in pandas 1.0.
table = pds.read_csv(data_uri).values[:ntrn + ntst]
smiles = table[:, 0]  # 0: SMILES
Y = np.asarray(table[:, 1:], dtype=np.float32)  # 1: MolWT, 2: LogP, 3: QED

# Turn the SMILES strings into the two arrays the model consumes.
list_seq = ZINC.smiles_to_seq(smiles)
Xs, X = ZINC.vectorize(list_seq)

# Model dimensions are read off the prepared arrays.
seqlen_x = X.shape[1]
dim_x = X.shape[2]
dim_y = Y.shape[1]

model = SSVAE.Model(seqlen_x=seqlen_x, dim_x=dim_x, dim_y=dim_y,
                    dim_z=dim_z, dim_h=dim_h, n_hidden=n_hidden,
                    batch_size=batch_size, beta=float(beta),
                    char_set=ZINC.char_set, save_uri=save_uri)
# NOTE(review): the original indentation was lost; training and both sampling
# loops are assumed to run inside the session context -- confirm.
with model.session:
    ## model training
    print('::: model training')
    model.trainXY(max_epoch, X, Xs, Y, ntrn, ntst, frac, frac_val,
                  experiment_dir)

    ## unconditional generation
    # Draw ten samples and print each with the properties computed for it.
    for trial in range(10):
        sampled = model.sampling_unconditional()
        properties = ZINC.get_property(sampled)
        print([trial, sampled, properties])

    ## conditional generation (e.g. MolWt=250)
    yid = 0          # index of the property to condition on
    ytarget = 250.   # requested value for that property
    for trial in range(10):
        sampled = model.sampling_conditional_transform(yid, ytarget)
        properties = ZINC.get_property(sampled)
        print([trial, sampled, properties])