Commit 18e74be0 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #411 from miaecle/IRV

IRV classifier
parents 4ee9d6a6 5ba2e919
Loading
Loading
Loading
Loading
+69 −23
Original line number Diff line number Diff line
# DeepChem
# DeepChem

DeepChem aims to provide a high quality open-source toolchain that
democratizes the use of deep-learning in drug discovery, materials science, and quantum
@@ -190,6 +190,8 @@ Index splitting
|Dataset    |Model               |Train score/ROC-AUC|Valid score/ROC-AUC|
|-----------|--------------------|-------------------|-------------------|
|tox21      |logistic regression |0.903              |0.705              |
|           |Random Forest       |0.999              |0.733              |
|           |IRV                 |0.811              |0.767              |
|           |Multitask network   |0.856              |0.763              |
|           |robust MT-NN        |0.857              |0.767              |
|           |graph convolution   |0.872              |0.798              |
@@ -202,6 +204,8 @@ Index splitting
|           |robust MT-NN        |0.809              |0.783              |
|           |graph convolution   |0.876              |0.852              |
|sider      |logistic regression |0.933              |0.620              |
|           |Random Forest       |0.999              |0.670              |
|           |IRV                 |0.649              |0.642              |
|           |Multitask network   |0.775              |0.634              |
|           |robust MT-NN        |0.803              |0.632              |
|           |graph convolution   |0.708              |0.594              |
@@ -210,6 +214,8 @@ Index splitting
|           |robust MT-NN        |0.825              |0.680              |
|           |graph convolution   |0.821              |0.720              |
|clintox    |logistic regression |0.967              |0.676              |
|           |Random Forest       |0.995              |0.776              |
|           |IRV                 |0.763              |0.814              |
|           |Multitask network   |0.934              |0.830              |
|           |robust MT-NN        |0.949              |0.827              |
|           |graph convolution   |0.946              |0.860              |
@@ -219,6 +225,8 @@ Random splitting
|Dataset    |Model               |Train score/ROC-AUC|Valid score/ROC-AUC|
|-----------|--------------------|-------------------|-------------------|
|tox21      |logistic regression |0.902              |0.715              |
|           |Random Forest       |0.999              |0.764              |
|           |IRV                 |0.808              |0.767              |
|           |Multitask network   |0.844              |0.795              |
|           |robust MT-NN        |0.855              |0.773              |
|           |graph convolution   |0.865              |0.827              |
@@ -231,6 +239,8 @@ Random splitting
|           |robust MT-NN        |0.811              |0.771              |
|           |graph convolution   |0.872       	     |0.844              |
|sider      |logistic regression |0.929        	     |0.656              |
|           |Random Forest       |0.999              |0.665              |
|           |IRV                 |0.648              |0.596              |
|           |Multitask network   |0.777        	     |0.655              |
|           |robust MT-NN        |0.804              |0.630              |
|           |graph convolution   |0.705        	     |0.618              |
@@ -239,6 +249,8 @@ Random splitting
|           |robust MT-NN        |0.822              |0.681              |
|           |graph convolution   |0.820        	     |0.717              |
|clintox    |logistic regression |0.972              |0.725              |
|           |Random Forest       |0.997              |0.670              |
|           |IRV                 |0.809              |0.846              |
|           |Multitask network   |0.951              |0.834              |
|           |robust MT-NN        |0.959              |0.830              |
|           |graph convolution   |0.975              |0.876              |
@@ -248,6 +260,8 @@ Scaffold splitting
|Dataset    |Model               |Train score/ROC-AUC|Valid score/ROC-AUC|
|-----------|--------------------|-------------------|-------------------|
|tox21      |logistic regression |0.900              |0.650              |
|           |Random Forest       |0.999              |0.629              |
|           |IRV                 |0.823              |0.708              |
|           |Multitask network   |0.863              |0.703              |
|           |robust MT-NN        |0.861              |0.710              |
|           |graph convolution   |0.885              |0.732              |
@@ -260,6 +274,8 @@ Scaffold splitting
|           |robust MT-NN        |0.812              |0.756              |
|           |graph convolution   |0.874              |0.817              |
|sider      |logistic regression |0.926              |0.592              |
|           |Random Forest       |0.999              |0.619              |
|           |IRV                 |0.639              |0.599              |
|           |Multitask network   |0.776              |0.557              |
|           |robust MT-NN        |0.797              |0.560              |
|           |graph convolution   |0.722              |0.583              |
@@ -268,6 +284,8 @@ Scaffold splitting
|           |robust MT-NN        |0.830              |0.614              |
|           |graph convolution   |0.832              |0.638              |
|clintox    |logistic regression |0.960              |0.803              |
|           |Random Forest       |0.993              |0.735              |
|           |IRV                 |0.793              |0.718              |
|           |Multitask network   |0.947              |0.862              |
|           |robust MT-NN        |0.953              |0.890              |
|           |graph convolution   |0.957              |0.823              |
@@ -276,40 +294,54 @@ Scaffold splitting

|Dataset         |Model               |Splitting   |Train score/R2|Valid score/R2|
|----------------|--------------------|------------|--------------|--------------|
|delaney         |MT-NN regression    |Index       |0.868         |0.578         |
|delaney         |Random Forest       |Index       |0.953         |0.626         |
|                |NN regression       |Index       |0.868         |0.578         |
|                |graphconv regression|Index       |0.967         |0.790         |
|                |MT-NN regression    |Random      |0.865         |0.574         |
|                |Random Forest       |Random      |0.951         |0.684         |
|                |NN regression       |Random      |0.865         |0.574         |
|                |graphconv regression|Random      |0.964         |0.782         |
|                |MT-NN regression    |Scaffold    |0.866         |0.342         |
|                |Random Forest       |Scaffold    |0.953         |0.284         |
|                |NN regression       |Scaffold    |0.866         |0.342         |
|                |graphconv regression|Scaffold    |0.967         |0.606         |
|sampl           |MT-NN regression    |Index       |0.917         |0.764         |
|sampl           |Random Forest       |Index       |0.968         |0.736         |
|                |NN regression       |Index       |0.917         |0.764         |
|                |graphconv regression|Index       |0.982         |0.864         |
|                |MT-NN regression    |Random      |0.908         |0.830         |
|                |Random Forest       |Random      |0.967         |0.752         |
|                |NN regression       |Random      |0.908         |0.830         |
|                |graphconv regression|Random      |0.987         |0.868         |
|                |MT-NN regression    |Scaffold    |0.891         |0.217         |
|                |Random Forest       |Scaffold    |0.966         |0.473         |
|                |NN regression       |Scaffold    |0.891         |0.217         |
|                |graphconv regression|Scaffold    |0.985         |0.666         |
|nci             |MT-NN regression    |Index       |0.171         |0.062         |
|nci             |NN regression       |Index       |0.171         |0.062         |
|                |graphconv regression|Index       |0.123         |0.048         |
|                |MT-NN regression    |Random      |0.168         |0.085         |
|                |NN regression       |Random      |0.168         |0.085         |
|                |graphconv regression|Random      |0.117         |0.076         |
|                |MT-NN regression    |Scaffold    |0.180         |0.052         |
|                |NN regression       |Scaffold    |0.180         |0.052         |
|                |graphconv regression|Scaffold    |0.131         |0.046         |
|pdbbind(core)   |MT-NN regression    |Random      |0.973         |0.494         |
|pdbbind(refined)|MT-NN regression    |Random      |0.987         |0.503         |
|pdbbind(full)   |MT-NN regression    |Random      |0.983         |0.528         |
|pdbbind(core)   |Random Forest       |Random      |0.969         |0.445         |
|                |NN regression       |Random      |0.973         |0.494         |
|pdbbind(refined)|Random Forest       |Random      |0.963         |0.511         |
|                |NN regression       |Random      |0.987         |0.503         |
|pdbbind(full)   |Random Forest       |Random      |0.965         |0.493         |
|                |NN regression       |Random      |0.983         |0.528         |
|chembl          |MT-NN regression    |Index       |0.443         |0.427         |
|                |MT-NN regression    |Random      |0.464         |0.434         |
|                |MT-NN regression    |Scaffold    |0.484         |0.361         |
|gdb7            |MT-NN regression    |Index       |0.994         |0.010         |
|                |MT-NN regression    |Random      |0.860         |0.773         |
|                |MT-NN regression    |User-defined|0.996         |0.996         | 
|qm7             |NN regression       |Index       |0.994         |0.969         |
|                |NN regression       |Random      |0.995         |0.992         |
|                |NN regression       |Stratified  |0.992         |0.992         | 
|qm7b            |MT-NN regression    |Index       |0.883         |0.785         |
|                |MT-NN regression    |Random      |0.864         |0.838         |
|                |MT-NN regression    |Stratified  |0.871         |0.847         | 
|kaggle          |MT-NN regression    |User-defined|0.748         |0.452         |

|Dataset         |Model            |Splitting   |Train score/MAE(kcal/mol)|Valid score/MAE(kcal/mol)|
|----------------|--------------------|------------|-------------------------|-------------------------|
|gdb7            |MT-NN regression    |Index       |18.3                     |172.0                    |
|                |MT-NN regression    |Random      |44.3                     |59.1                     |
|                |MT-NN regression    |User-defined|9.0                      |9.5                      |
|----------------|-----------------|------------|-------------------------|-------------------------|
|qm7             |NN regression    |Index       |22.1                     |23.2                     |
|                |NN regression    |Random      |16.2                     |17.7                     |
|                |NN regression    |Stratified  |20.5                     |20.8                     |
|                |NN regression    |User-defined|9.0                      |9.5                      |


* General features

@@ -331,7 +363,8 @@ Number of tasks and examples in the datasets
|pdbbind(refined)|1          |3706       |
|pdbbind(full)   |1          |11908      |
|chembl(5thresh) |691        |23871      |
|gdb7            |1          |7165       |
|qm7             |1          |7165       |
|qm7b            |14         |7211       |



@@ -342,6 +375,8 @@ Time needed for benchmark test(~20h in total)
|tox21           |logistic regression |30              |60             |
|                |Multitask network   |30              |60             |
|                |robust MT-NN        |30              |90             |
|                |random forest       |30              |6000           |
|                |IRV                 |30              |650            |
|                |graph convolution   |40              |160            |
|muv             |logistic regression |600             |450            |
|                |Multitask network   |600             |400            |
@@ -354,22 +389,33 @@ Time needed for benchmark test(~20h in total)
|sider           |logistic regression |15              |80             |
|                |Multitask network   |15              |75             |
|                |robust MT-NN        |15              |150            |
|                |random forest       |15              |2200           |
|                |IRV                 |15              |150            |
|                |graph convolution   |20              |50             |
|toxcast         |logistic regression |80              |2600           |
|                |Multitask network   |80              |2300           |
|                |robust MT-NN        |80              |4000           |
|                |graph convolution   |80              |900            |
|clintox         |logistic regression |15              |10             |
|                |Multitask network   |15              |20             |
|                |robust MT-NN        |15              |30             |
|                |random forest       |15              |200            |
|                |IRV                 |15              |10             |
|                |graph convolution   |20              |130            |
|delaney         |MT-NN regression    |10              |40             |
|                |graphconv regression|10              |40             |
|                |random forest       |10              |30             |
|sampl           |MT-NN regression    |10              |30             |
|                |graphconv regression|10              |40             |
|                |random forest       |10              |20             |
|nci             |MT-NN regression    |400             |1200           |
|                |graphconv regression|400             |2500           |
|pdbbind(core)   |MT-NN regression    |0(featurized)   |30             |
|pdbbind(refined)|MT-NN regression    |0(featurized)   |40             |
|pdbbind(full)   |MT-NN regression    |0(featurized)   |60             |
|chembl          |MT-NN regression    |200             |9000           |
|gdb7            |MT-NN regression    |10              |110            |
|qm7             |MT-NN regression    |10              |400            |
|qm7b            |MT-NN regression    |10              |600            |
|kaggle          |MT-NN regression    |2200            |3200           |


+1 −0
Original line number Diff line number Diff line
@@ -20,3 +20,4 @@ from deepchem.models.tensorflow_models.robust_multitask import RobustMultitaskCl
from deepchem.models.tensorflow_models.lr import TensorflowLogisticRegression
from deepchem.models.tensorflow_models.progressive_multitask import ProgressiveMultitaskRegressor
from deepchem.models.tensorflow_models.progressive_joint import ProgressiveJointRegressor
from deepchem.models.tensorflow_models.IRV import TensorflowMultiTaskIRVClassifier
 No newline at end of file
+114 −0
Original line number Diff line number Diff line
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import time
import numpy as np
import tensorflow as tf

from deepchem.utils.save import log
from deepchem.models.tensorflow_models import TensorflowGraph
from deepchem.models.tensorflow_models import TensorflowGraphModel
from deepchem.models.tensorflow_models.lr import TensorflowLogisticRegression


class TensorflowMultiTaskIRVClassifier(TensorflowLogisticRegression):

  def __init__(self,
               n_tasks,
               K=10,
               logdir=None,
               n_classes=2,
               penalty=0.0,
               penalty_type="l2",
               learning_rate=0.001,
               momentum=.8,
               optimizer="adam",
               batch_size=50,
               verbose=True,
               seed=None,
               **kwargs):
    """Initialize TensorflowMultiTaskIRVClassifier
    
    Parameters
    ----------
    n_tasks: int
      Number of tasks
    K: int
      Number of nearest neighbours used in classification
    logdir: str
      Location to save data
    n_classes: int
      number of different labels
    penalty: float
      Amount of penalty (l2 or l1 applied)
    penalty_type: str
      Either "l2" or "l1"
    learning_rate: float
      Learning rate for model.
    momentum: float
      Momentum. Only applied if optimizer=="momentum"
    optimizer: str
      Type of optimizer applied.
    batch_size: int
      Size of minibatches for training.
    verbose: True 
      Perform logging.
    seed: int
      If not none, is used as random seed for tensorflow.        

    """

    self.n_tasks = n_tasks
    self.K = K
    self.n_features = 2 * self.K * self.n_tasks
    print("n_features after fit_transform: %d" % int(self.n_features))
    TensorflowGraphModel.__init__(
        self,
        n_tasks,
        self.n_features,
        logdir=logdir,
        layer_sizes=None,
        weight_init_stddevs=None,
        bias_init_consts=None,
        penalty=penalty,
        penalty_type=penalty_type,
        dropouts=None,
        n_classes=n_classes,
        learning_rate=learning_rate,
        momentum=momentum,
        optimizer=optimizer,
        batch_size=batch_size,
        pad_batches=False,
        verbose=verbose,
        seed=seed,
        **kwargs)

  def build(self, graph, name_scopes, training):
    """Constructs the graph architecture of IRV as described in:
       
       https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2750043/
    """
    placeholder_scope = TensorflowGraph.get_placeholder_scope(graph,
                                                              name_scopes)
    K = self.K
    with graph.as_default():
      output = []
      with placeholder_scope:
        self.features = tf.placeholder(
            tf.float32, shape=[None, self.n_features], name='mol_features')
      with tf.name_scope('variable'):
        V = tf.Variable(tf.constant([0.01, 1.]), name="vote", dtype=tf.float32)
        W = tf.Variable(tf.constant([1., 1.]), name="w", dtype=tf.float32)
        b = tf.Variable(tf.constant([0.01]), name="b", dtype=tf.float32)
        b2 = tf.Variable(tf.constant([0.01]), name="b2", dtype=tf.float32)
      for count in range(self.n_tasks):
        similarity = self.features[:, 2 * K * count:(2 * K * count + K)]
        ys = tf.to_int32(
            self.features[:, (2 * K * count + K):2 * K * (count + 1)])
        R = b + W[0] * similarity + W[1] * tf.constant(
            np.arange(K) + 1, dtype=tf.float32)
        R = tf.sigmoid(R)
        z = tf.reduce_sum(R * tf.gather(V, ys), axis=1) + b2
        output.append(tf.reshape(z, shape=[-1, 1]))
    return output
+217 −105

File changed.

Preview size limit exceeded, changes collapsed.

+1 −0
Original line number Diff line number Diff line
@@ -14,3 +14,4 @@ from deepchem.trans.transformers import BalancingTransformer
from deepchem.trans.transformers import CDFTransformer
from deepchem.trans.transformers import PowerTransformer
from deepchem.trans.transformers import CoulombFitTransformer
from deepchem.trans.transformers import IRVTransformer
Loading