Skip to content

Commit

Permalink
Merge pull request #425 from corochann/refactor_wle
Browse files Browse the repository at this point in the history
Refactor wle
  • Loading branch information
corochann authored Aug 14, 2020
2 parents 1f12b9b + 883e5c0 commit ef7dc40
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 152 deletions.
3 changes: 2 additions & 1 deletion chainer_chemistry/dataset/parsers/data_frame_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def parse(self, df, return_smiles=False, target_index=None,
else:
raise NotImplementedError

smileses = numpy.array(smiles_list) if return_smiles else None
smileses = numpy.array(
smiles_list, dtype=object) if return_smiles else None
if return_is_successful:
is_successful = numpy.array(is_successful_list)
else:
Expand Down
3 changes: 2 additions & 1 deletion chainer_chemistry/dataset/parsers/sdf_file_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def parse(self, filepath, return_smiles=False, target_index=None,
# Spec not finalized yet for general case
result = pp.process(filepath)

smileses = numpy.array(smiles_list) if return_smiles else None
smileses = numpy.array(
smiles_list, dtype=object) if return_smiles else None
if return_is_successful:
is_successful = numpy.array(is_successful_list)
else:
Expand Down
7 changes: 1 addition & 6 deletions chainer_chemistry/datasets/molnet/molnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,7 @@ def postprocess_label(label_list):
raise TypeError("split must be None, str or instance of"
" BaseSplitter, but got {}".format(type(split)))

if isinstance(splitter, ScaffoldSplitter):
get_smiles = True
else:
get_smiles = return_smiles

if isinstance(splitter, DeepChemScaffoldSplitter):
if isinstance(splitter, (ScaffoldSplitter, DeepChemScaffoldSplitter)):
get_smiles = True
else:
get_smiles = return_smiles
Expand Down
12 changes: 11 additions & 1 deletion examples/molnet_wle/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ In this directory, we provide an implementaion of [Weisfeiler-Lehman Embedding (

## How to run the code

### Test run command

```bash
# Training tox21 dataset using RSGCN-CWLE model. Short 3 epoch for testing.
python train_molnet_wle.py --dataset tox21 --method rsgcn_cwle --epoch 3 --device 0

# Prediction with trained model
python predict_molnet_wle.py --dataset tox21 --method rsgcn_cwle --in-dir result --device 0
```

### Train the model by specifying dataset

Basically, no changes from the original molnet examples (examples/molnet/train_molnet.py).
Expand All @@ -14,7 +24,7 @@ To test WLE, choose one of 'xxx_wle', 'xxx_cwle', and 'xxx_gwle' where 'xxx' is
- xxx_cwle (recommended): apply the Concat WLE to the GNN 'xxx'
- xxx_gwle: apply the Gated-sum WLE to the GNN 'xxx'

#### additional options
#### Additional options

Introducing the WLE, we have some more additional options.
In general you do not need to specify these options (use the default values!).
Expand Down
2 changes: 0 additions & 2 deletions examples/molnet_wle/predict_molnet_wle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
from __future__ import print_function

import argparse
import json
import os

import chainer
from chainer.iterators import SerialIterator
from chainer.training.extensions import Evaluator
from chainer_chemistry.training.extensions.roc_auc_evaluator import ROCAUCEvaluator # NOQA
from chainer import cuda
# Proposed by Ishiguro
# ToDo: consider go/no-go with following modification
# Re-load the best-validation score snapshot using serializers
Expand Down
133 changes: 0 additions & 133 deletions examples/molnet_wle/summary_eval_molnet.py

This file was deleted.

9 changes: 1 addition & 8 deletions examples/molnet_wle/train_molnet_wle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from chainer.training import extensions as E

from chainer_chemistry.dataset.converters import converter_method_dict
from chainer_chemistry.config import MAX_ATOMIC_NUM
from chainer_chemistry.dataset.converters import concat_mols
from chainer_chemistry.dataset.preprocessors import preprocess_method_dict, wle
from chainer_chemistry import datasets as D
from chainer_chemistry.datasets.molnet.molnet_config import molnet_default_config # NOQA
Expand All @@ -27,7 +25,6 @@
from chainer_chemistry.models.prediction import Classifier
from chainer_chemistry.models.prediction import Regressor
from chainer_chemistry.models.prediction import set_up_predictor
from chainer_chemistry.training.extensions import BatchEvaluator, ROCAUCEvaluator # NOQA
from chainer_chemistry.training.extensions.auto_print_report import AutoPrintReport # NOQA
from chainer_chemistry.utils import save_json
from chainer_chemistry.models.cwle.cwle_graph_conv_model import MAX_WLE_NUM
Expand Down Expand Up @@ -156,11 +153,7 @@ def download_entire_dataset(dataset_name, num_data, labels, method, cache_dir, a
labels=labels,
split=dc_scaffold_splitter,
target_index=target_index)
# To use the splitter defined in the config file
#dataset_parts = D.molnet.get_molnet_dataset(dataset_name, preprocessor,
# labels=labels,
# target_index=target_index)


dataset_parts = dataset_parts['dataset']

# Cache the downloaded dataset.
Expand Down

0 comments on commit ef7dc40

Please sign in to comment.