Commit 3fb43dc0 authored by cc's avatar cc
Browse files

Added ChEMBL dataset with examples

parent 02a498dc
Loading
Loading
Loading
Loading

datasets/chembl.csv

0 → 100644
+19107 −0

File added.

Preview size limit exceeded, changes collapsed.

+0 −0

Empty file added.

+74 −0
Original line number Diff line number Diff line
"""
ChEMBL dataset loader.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import os
import time
import numpy as np
import deepchem as dc
import sys

sys.path.append(".")
from chembl_tasks import chembl_tasks

def remove_missing_entries(dataset):
    """Drop all-zero feature rows from every shard of `dataset`, in place.

    Missing entries can sneak into some datasets as zero'd out feature
    vectors. For each shard yielded by ``dataset.itershards()``, the rows
    whose feature vector has no non-zero element are removed from X, y, w
    and ids, and the filtered shard is written back with
    ``dataset.set_shard``. The number of removed rows is printed per shard.
    """
    for shard_idx, shard in enumerate(dataset.itershards()):
        features, labels, weights, identifiers = shard
        # A row is kept when at least one of its feature values is non-zero.
        keep = features.any(axis=1)
        n_missing = np.count_nonzero(~keep)
        print("Shard %d has %d missing entries." % (shard_idx, n_missing))
        dataset.set_shard(
            shard_idx,
            features[keep],
            labels[keep],
            weights[keep],
            identifiers[keep])


# Set shard size low to avoid memory problems.
def load_chembl(shard_size=2000, featurizer="ECFP", split='random'):
    """Load the ChEMBL dataset, featurize, normalize and split it.

    The CSV is read from ``../../datasets/chembl.csv`` relative to this
    file, featurized from its ``smiles`` column, passed through a
    ``NormalizationTransformer`` with ``transform_y=True``, and then split
    into train/valid/test sets.

    Parameters
    ----------
    shard_size: int
        Passed to the loader's ``featurize`` call as ``shard_size``.
    featurizer: str or featurizer object
        'ECFP' selects ``dc.feat.CircularFingerprint(size=1024)``;
        'GraphConv' selects ``dc.feat.ConvMolFeaturizer()``. A non-string
        value is handed to ``dc.data.CSVLoader`` unchanged.
    split: str
        One of 'index', 'random' or 'scaffold'.

    Returns
    -------
    tuple
        ``(chembl_tasks, (train, valid, test), transformers)`` where
        ``transformers`` is the list of transformers applied to the data.

    Raises
    ------
    KeyError
        If `split` is not one of the supported splitter names.
    ValueError
        If `featurizer` is a string other than 'ECFP' or 'GraphConv'.
    """
    # Resolve the splitter first so that an unknown `split` fails (with the
    # same KeyError as before) before the expensive featurization runs.
    splitters = {'index': dc.splits.IndexSplitter(),
                 'random': dc.splits.RandomSplitter(),
                 'scaffold': dc.splits.ScaffoldSplitter()}
    splitter = splitters[split]

    # Locate the dataset relative to this file.
    current_dir = os.path.dirname(os.path.realpath(__file__))

    # Load dataset
    print("About to load ChEMBL dataset.")
    dataset_path = os.path.join(
        current_dir, "../../datasets/chembl.csv")

    # Featurize ChEMBL dataset
    print("About to featurize ChEMBL dataset.")
    if featurizer == 'ECFP':
        featurizer = dc.feat.CircularFingerprint(size=1024)
    elif featurizer == 'GraphConv':
        featurizer = dc.feat.ConvMolFeaturizer()
    elif isinstance(featurizer, str):
        # An unrecognized name would otherwise be forwarded to CSVLoader as a
        # plain string; reject it here with a clear message instead.
        raise ValueError(
            "Unknown featurizer %r; expected 'ECFP' or 'GraphConv'."
            % (featurizer,))

    loader = dc.data.CSVLoader(
        tasks=chembl_tasks, smiles_field="smiles", featurizer=featurizer)

    dataset = loader.featurize(dataset_path, shard_size=shard_size)

    # Initialize transformers
    print("About to transform data")
    transformers = [
        dc.trans.NormalizationTransformer(transform_y=True, dataset=dataset)]
    for transformer in transformers:
        dataset = transformer.transform(dataset)

    print("Performing new split.")
    train, valid, test = splitter.train_valid_test_split(dataset)

    return chembl_tasks, (train, valid, test), transformers
 No newline at end of file
+70 −0
Original line number Diff line number Diff line
"""
Script that trains graph-conv models on the ChEMBL dataset.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import numpy as np
import tensorflow as tf
import deepchem as dc
from keras import backend as K
from chembl_datasets import load_chembl

# Only for debug!
# Seed NumPy's global RNG with a fixed value.
np.random.seed(123)

# Create a dedicated graph, open a session on it and register that session
# with the Keras backend.
# NOTE(review): `sess` is rebound further down by `with tf.Session() as sess`;
# the session created here is never explicitly closed in this script.
g = tf.Graph()
sess = tf.Session(graph=g)
K.set_session(sess)

with g.as_default():
  # Seed TensorFlow while `g` is the default graph.
  tf.set_random_seed(123)
  # Load ChEMBL with the 'GraphConv' featurizer and the 'index' splitter.
  chembl_tasks, datasets, transformers = load_chembl(
      featurizer='GraphConv', split='index')
  train_dataset, valid_dataset, test_dataset = datasets

  # Fit models
  # Evaluation metric: pearson_r2_score, with np.mean passed as the second
  # Metric argument (presumably the cross-task aggregator -- TODO confirm).
  metric = dc.metrics.Metric(dc.metrics.pearson_r2_score, np.mean)

  # Do setup required for tf/keras models
  # Number of features on conv-mols
  n_feat = 71
  # Batch size of models
  batch_size = 128
  # Layer stack: two (GraphConv -> BatchNormalization -> GraphPool) stages,
  # then Dense -> BatchNormalization -> GraphGather.
  graph_model = dc.nn.SequentialGraph(n_feat)
  graph_model.add(dc.nn.GraphConv(128, activation='relu'))
  graph_model.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
  graph_model.add(dc.nn.GraphPool())
  graph_model.add(dc.nn.GraphConv(128, activation='relu'))
  graph_model.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
  graph_model.add(dc.nn.GraphPool())
  # Gather Projection
  graph_model.add(dc.nn.Dense(256, activation='relu'))
  graph_model.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
  graph_model.add(dc.nn.GraphGather(batch_size, activation="tanh"))
  # Dense post-processing layer
  # NOTE(review): no layer is actually added after the comment above.

  # Open a new session (rebinding `sess`) that is closed when the block exits.
  with tf.Session() as sess:
    # One regression output per entry in chembl_tasks.
    model = dc.models.MultitaskGraphRegressor(
      sess, graph_model, len(chembl_tasks), batch_size=batch_size,
      learning_rate=1e-3, learning_rate_decay_time=1000,
      optimizer_type="adam", beta1=.9, beta2=.999)

    # Train the model on the training split for 10 epochs.
    model.fit(train_dataset, nb_epoch=10)

    # Score every split; the loader's transformers are passed to evaluate().
    print("Evaluating model")
    train_scores = model.evaluate(train_dataset, [metric], transformers)
    valid_scores = model.evaluate(valid_dataset, [metric], transformers)
    # Only use for final evaluation
    test_scores = model.evaluate(test_dataset, [metric], transformers)

    print("Train scores")
    print(train_scores)

    print("Validation scores")
    print(valid_scores)

    print("Test scores")
    print(test_scores)
+14 −0
Original line number Diff line number Diff line
# Task names for the ChEMBL dataset: a list of ChEMBL identifier strings.
# NOTE(review): presumably each entry is a column header in
# datasets/chembl.csv (one regression task per identifier) -- confirm
# against the CSV before adding or removing entries.
chembl_tasks = ['CHEMBL1075194', 'CHEMBL1075263', 'CHEMBL1287626', 'CHEMBL1649052', 'CHEMBL1681617', 'CHEMBL1764942',
                'CHEMBL1798', 'CHEMBL1800', 'CHEMBL1804', 'CHEMBL1821', 'CHEMBL1836', 'CHEMBL1878', 'CHEMBL1887',
                'CHEMBL1901', 'CHEMBL1914263', 'CHEMBL1941', 'CHEMBL1942', 'CHEMBL1945', 'CHEMBL1963', 'CHEMBL1983',
                'CHEMBL2016428', 'CHEMBL2024', 'CHEMBL2107', 'CHEMBL211', 'CHEMBL2114', 'CHEMBL2115', 'CHEMBL2121',
                'CHEMBL2176850', 'CHEMBL2181', 'CHEMBL223', 'CHEMBL2286', 'CHEMBL2384898', 'CHEMBL2434', 'CHEMBL2488',
                'CHEMBL250', 'CHEMBL2514', 'CHEMBL253', 'CHEMBL255', 'CHEMBL2564', 'CHEMBL2605', 'CHEMBL263',
                'CHEMBL2647', 'CHEMBL2718', 'CHEMBL276', 'CHEMBL2798', 'CHEMBL287', 'CHEMBL2955', 'CHEMBL2967',
                'CHEMBL2978', 'CHEMBL302', 'CHEMBL317', 'CHEMBL3197', 'CHEMBL3228', 'CHEMBL327', 'CHEMBL3278',
                'CHEMBL3297639', 'CHEMBL3322', 'CHEMBL3359', 'CHEMBL3360', 'CHEMBL3374', 'CHEMBL339', 'CHEMBL3459',
                'CHEMBL3562166', 'CHEMBL3602', 'CHEMBL3611961', 'CHEMBL3636', 'CHEMBL3772', 'CHEMBL3798', 'CHEMBL3875',
                'CHEMBL3942', 'CHEMBL4123', 'CHEMBL4124', 'CHEMBL4333', 'CHEMBL4383', 'CHEMBL4552', 'CHEMBL4649',
                'CHEMBL4731', 'CHEMBL4761', 'CHEMBL4762', 'CHEMBL4843', 'CHEMBL4845', 'CHEMBL5136', 'CHEMBL5226',
                'CHEMBL5411', 'CHEMBL5412', 'CHEMBL5414', 'CHEMBL5739', 'CHEMBL5781', 'CHEMBL5844', 'CHEMBL5862',
                'CHEMBL5936', 'CHEMBL5968']
 No newline at end of file
Loading