Commit b3e5e256 authored by Bharath Ramsundar, committed by GitHub
Browse files

Merge pull request #358 from CCXD/master

Delaney, kaggle and nci example bugfixes and ChEMBL dataset
parents 4c1c2bb8 98c63fbf
Loading
Loading
Loading
Loading
+988 KiB

File added.

No diff preview for this file type.

+8.11 MiB

File added.

No diff preview for this file type.

+0 −0

Empty file added.

+98 −0
Original line number Diff line number Diff line
"""
ChEMBL dataset loader.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import os
import time
import numpy as np
import deepchem as dc
import sys

sys.path.append(".")
from chembl_tasks import chembl_tasks

# Set shard size low to avoid memory problems.
def load_chembl(shard_size=2000, featurizer="ECFP", set="5thresh", split="random"):
    """Load, featurize, normalize and split the ChEMBL dataset.

    Parameters
    ----------
    shard_size: int
        Forwarded to ``loader.featurize`` as ``shard_size``.
    featurizer: str or featurizer object
        ``"ECFP"`` selects ``dc.feat.CircularFingerprint(size=1024)`` and
        ``"GraphConv"`` selects ``dc.feat.ConvMolFeaturizer()``. Any other
        value is handed to ``dc.data.CSVLoader`` unchanged.
    set: str
        Name fragment of the CSV files to read (``chembl_<set>...csv.gz``).
        The name shadows the builtin but is kept for caller compatibility.
    split: str
        ``"index"``, ``"random"`` or ``"scaffold"`` split a single dataset
        file with the matching ``dc.splits`` splitter; ``"year"`` reads the
        pre-split train/valid/test files under ``year_sets/``.

    Returns
    -------
    tuple
        ``(chembl_tasks, (train, valid, test), transformers)``.

    Raises
    ------
    ValueError
        If ``split`` is not one of the supported values.
    """
    # Reject unknown splits before any expensive work. Previously an unknown
    # value featurized the whole dataset and only then failed at the return
    # statement because train/valid/test were never bound.
    valid_splits = ("index", "random", "scaffold", "year")
    if split not in valid_splits:
        raise ValueError(
            "Unknown split %r; expected one of: %s"
            % (split, ", ".join(valid_splits)))

    current_dir = os.path.dirname(os.path.realpath(__file__))

    # Resolve input paths
    print("About to load ChEMBL dataset.")
    if split == "year":
        train_files = os.path.join(current_dir,
                                   "year_sets/chembl_%s_ts_train.csv.gz" % set)
        valid_files = os.path.join(current_dir,
                                   "year_sets/chembl_%s_ts_valid.csv.gz" % set)
        test_files = os.path.join(current_dir,
                                  "year_sets/chembl_%s_ts_test.csv.gz" % set)
    else:
        dataset_path = os.path.join(
            current_dir, "../../datasets/chembl_%s.csv.gz" % set)

    # Featurize ChEMBL dataset
    print("About to featurize ChEMBL dataset.")
    if featurizer == 'ECFP':
        featurizer = dc.feat.CircularFingerprint(size=1024)
    elif featurizer == 'GraphConv':
        featurizer = dc.feat.ConvMolFeaturizer()

    loader = dc.data.CSVLoader(
        tasks=chembl_tasks, smiles_field="smiles", featurizer=featurizer)

    if split == "year":
        print("Featurizing train datasets")
        train = loader.featurize(train_files, shard_size=shard_size)

        print("Featurizing valid datasets")
        valid = loader.featurize(valid_files, shard_size=shard_size)

        print("Featurizing test datasets")
        test = loader.featurize(test_files, shard_size=shard_size)

        # Fit the normalization on the training data only, then apply it to
        # all three sets. The transformed datasets are the ones returned, so
        # they agree with the returned transformers and nothing is
        # featurized a second time.
        print("About to transform data")
        transformers = [
            dc.trans.NormalizationTransformer(transform_y=True, dataset=train)]
        for transformer in transformers:
            train = transformer.transform(train)
            valid = transformer.transform(valid)
            test = transformer.transform(test)

        return chembl_tasks, (train, valid, test), transformers

    dataset = loader.featurize(dataset_path, shard_size=shard_size)

    # Initialize transformers
    print("About to transform data")
    transformers = [
        dc.trans.NormalizationTransformer(transform_y=True, dataset=dataset)]
    for transformer in transformers:
        dataset = transformer.transform(dataset)

    splitters = {'index': dc.splits.IndexSplitter(),
                 'random': dc.splits.RandomSplitter(),
                 'scaffold': dc.splits.ScaffoldSplitter()}
    print("Performing new split.")
    train, valid, test = splitters[split].train_valid_test_split(dataset)

    return chembl_tasks, (train, valid, test), transformers
+69 −0
Original line number Diff line number Diff line
"""
Script that trains graph-conv models on ChEMBL dataset.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import numpy as np
import tensorflow as tf
import deepchem as dc
from keras import backend as K
from chembl_datasets import load_chembl

# Only for debug!
np.random.seed(123)

g = tf.Graph()
# A single session is shared by Keras and the model. The original opened a
# second tf.Session() for the model, which left this one unclosed and meant
# Keras and the model did not run in the same session.
sess = tf.Session(graph=g)
K.set_session(sess)

with g.as_default():
  tf.set_random_seed(123)
  chembl_tasks, datasets, transformers = load_chembl(shard_size=2000,
    featurizer="GraphConv", set="5thresh", split="random")
  train_dataset, valid_dataset, test_dataset = datasets

  # Score with the mean over tasks of the squared Pearson correlation
  metric = dc.metrics.Metric(dc.metrics.pearson_r2_score, np.mean)

  # Do setup required for tf/keras models
  # Number of features on conv-mols
  n_feat = 75
  # Batch size of models
  batch_size = 128
  graph_model = dc.nn.SequentialGraph(n_feat)
  graph_model.add(dc.nn.GraphConv(128, activation='relu'))
  graph_model.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
  graph_model.add(dc.nn.GraphPool())
  graph_model.add(dc.nn.GraphConv(128, activation='relu'))
  graph_model.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
  graph_model.add(dc.nn.GraphPool())
  # Gather Projection
  graph_model.add(dc.nn.Dense(256, activation='relu'))
  graph_model.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
  graph_model.add(dc.nn.GraphGather(batch_size, activation="tanh"))

  # Entering the session makes it the default session for this scope and
  # closes it when the block exits.
  with sess:
    model = dc.models.MultitaskGraphRegressor(
      sess, graph_model, len(chembl_tasks), batch_size=batch_size,
      learning_rate=1e-3, learning_rate_decay_time=1000,
      optimizer_type="adam", beta1=.9, beta2=.999)

    # Train the model
    model.fit(train_dataset, nb_epoch=20)

    print("Evaluating model")
    train_scores = model.evaluate(train_dataset, [metric], transformers)
    valid_scores = model.evaluate(valid_dataset, [metric], transformers)
    test_scores = model.evaluate(test_dataset, [metric], transformers)

    print("Train scores")
    print(train_scores)

    print("Validation scores")
    print(valid_scores)

    print("Test scores")
    print(test_scores)
Loading