Skip to content

Commit

Permalink
Merge pull request #391 from corochann/fix_example_test
Browse files Browse the repository at this point in the history
support more models in example
  • Loading branch information
mottodora authored Sep 11, 2019
2 parents 2a45d68 + 266f08f commit de85501
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 6 deletions.
3 changes: 2 additions & 1 deletion examples/molnet/predict_molnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
def parse_arguments():
# Lists of supported preprocessing methods/models.
method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn', 'relgcn',
'gin', 'nfp_gwm', 'ggnn_gwm', 'rsgcn_gwm', 'gin_gwm']
'relgat', 'gin', 'gnnfilm',
'nfp_gwm', 'ggnn_gwm', 'rsgcn_gwm', 'gin_gwm']
# scale_list = ['standardize', 'none']
dataset_names = list(molnet_default_config.keys())

Expand Down
2 changes: 1 addition & 1 deletion examples/molnet/train_molnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
def parse_arguments():
# Lists of supported preprocessing methods/models and datasets.
method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn', 'relgcn',
'relgat', 'gin',
'relgat', 'gin', 'gnnfilm',
'nfp_gwm', 'ggnn_gwm', 'rsgcn_gwm', 'gin_gwm']
dataset_names = list(molnet_default_config.keys())
scale_list = ['standardize', 'none']
Expand Down
3 changes: 2 additions & 1 deletion examples/qm9/predict_qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
def parse_arguments():
# Lists of supported preprocessing methods/models.
method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn', 'relgcn',
'relgat']
'relgat', 'gin', 'gnnfilm',
'nfp_gwm', 'ggnn_gwm', 'rsgcn_gwm', 'gin_gwm']
label_names = ['A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2',
'zpve', 'U0', 'U', 'H', 'G', 'Cv']
scale_list = ['standardize', 'none']
Expand Down
2 changes: 1 addition & 1 deletion examples/qm9/test_qm9.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
set -e

# List of available graph convolution methods.
methods=(nfp ggnn schnet weavenet rsgcn relgcn relgat)
methods=(nfp ggnn schnet weavenet rsgcn relgcn relgat gin gnnfilm nfp_gwm ggnn_gwm rsgcn_gwm gin_gwm)

# device identifier; set it to -1 to train on the CPU (default).
device=${1:--1}
Expand Down
3 changes: 2 additions & 1 deletion examples/qm9/train_qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def rmse(x0, x1):
def parse_arguments():
# Lists of supported preprocessing methods/models.
method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn', 'relgcn',
'relgat', 'gnnfilm']
'relgat', 'gin', 'gnnfilm',
'nfp_gwm', 'ggnn_gwm', 'rsgcn_gwm', 'gin_gwm']
label_names = ['A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2',
'zpve', 'U0', 'U', 'H', 'G', 'Cv']
scale_list = ['standardize', 'none']
Expand Down
2 changes: 1 addition & 1 deletion examples/tox21/test_tox21.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ device=${1:--1}
# Preprocessor parse result must contain both pos/neg samples
tox21_num_data=100

for method in nfp ggnn schnet weavenet rsgcn relgcn relgat
for method in nfp ggnn
do
if [ ! -f "input" ]; then
rm -rf input
Expand Down

0 comments on commit de85501

Please sign in to comment.