Commit 293073fb authored by miaecle's avatar miaecle
Browse files

wrap up

parent 0ee319d9
Loading
Loading
Loading
Loading
+21 −5
Original line number Diff line number Diff line
@@ -220,7 +220,7 @@ Index splitting
|           |MT-NN classification|0.934              |0.830              |
|           |Robust MT-NN        |0.949              |0.827              |
|           |Graph convolution   |0.946              |0.860              |
|           |Weave               |0.907              |0.879              |
|           |Weave               |0.942              |0.917              |
|hiv        |Logistic regression |0.864              |0.739              |
|           |Random forest       |0.999              |0.720              |
|           |XGBoost             |0.917              |0.745              |
@@ -228,11 +228,13 @@ Index splitting
|           |NN classification   |0.761              |0.652              |
|           |Robust NN           |0.780              |0.708              |
|           |Graph convolution   |0.876              |0.779              |
|           |Weave               |0.907              |0.753              |
|muv        |Logistic regression |0.963              |0.766              |
|           |XGBoost             |0.895              |0.714              |
|           |MT-NN classification|0.904              |0.764              |
|           |Robust MT-NN        |0.934              |0.781              |
|           |Graph convolution   |0.840              |0.823              |
|           |Weave               |0.762              |0.761              |
|pcba       |Logistic regression |0.809              |0.776              |
|           |XGBoost             |0.931              |0.847              |
|           |MT-NN classification|0.826              |0.802              |
@@ -245,6 +247,7 @@ Index splitting
|           |MT-NN classification|0.775              |0.634              |
|           |Robust MT-NN        |0.803              |0.632              |
|           |Graph convolution   |0.708              |0.594              |
|           |Weave               |0.591              |0.580              |
|tox21      |Logistic regression |0.903              |0.705              |
|           |Random forest       |0.999              |0.733              |
|           |XGBoost             |0.891              |0.753              |
@@ -258,6 +261,7 @@ Index splitting
|           |MT-NN classification|0.830              |0.678              |
|           |Robust MT-NN        |0.825              |0.680              |
|           |Graph convolution   |0.821              |0.720              |
|           |Weave               |0.766              |0.715              |

Random splitting

@@ -269,12 +273,14 @@ Random splitting
|           |NN classification   |0.877              |0.790              |
|           |Robust NN           |0.887              |0.864              |
|           |Graph convolution   |0.906              |0.861              |
|           |Weave               |0.807              |0.780              |
|bbbp       |Logistic regression |0.980              |0.876              |
|           |Random forest       |0.999              |0.918              |
|           |IRV                 |0.904              |0.917              |
|           |NN classification   |0.882              |0.915              |
|           |Robust NN           |0.878              |0.878              |
|           |Graph convolution   |0.962              |0.897              |
|           |Weave               |0.929              |0.934              |
|clintox    |Logistic regression |0.972              |0.725              |
|           |Random forest       |0.997              |0.670              |
|           |XGBoost             |0.886              |0.731              |
@@ -282,7 +288,7 @@ Random splitting
|           |MT-NN classification|0.951              |0.834              |
|           |Robust MT-NN        |0.959              |0.830              |
|           |Graph convolution   |0.975              |0.876              |
|           |Weave               |0.890              |0.738              |
|           |Weave               |0.945              |0.818              |
|hiv        |Logistic regression |0.860              |0.806              |
|           |Random forest       |0.999              |0.850              |
|           |XGBoost             |0.933              |0.841              |
@@ -290,11 +296,13 @@ Random splitting
|           |NN classification   |0.742              |0.715              |
|           |Robust NN           |0.753              |0.727              |
|           |Graph convolution   |0.847              |0.803              |
|           |Weave               |0.902              |0.825              |
|muv        |Logistic regression |0.957              |0.719              |
|           |XGBoost             |0.874              |0.696              |
|           |MT-NN classification|0.902              |0.734              |
|           |Robust MT-NN        |0.933              |0.732              |
|           |Graph convolution   |0.860              |0.730              |
|           |Weave               |0.763              |0.763              |
|pcba       |Logistic regression |0.808        	     |0.776              |
|           |MT-NN classification|0.811        	     |0.778              |
|           |Robust MT-NN        |0.811              |0.771              |
@@ -306,6 +314,7 @@ Random splitting
|           |MT-NN classification|0.777        	     |0.655              |
|           |Robust MT-NN        |0.804              |0.630              |
|           |Graph convolution   |0.705        	     |0.618              |
|           |Weave               |0.616              |0.645              |
|tox21      |Logistic regression |0.902              |0.715              |
|           |Random forest       |0.999              |0.764              |
|           |XGBoost             |0.874              |0.773              |
@@ -313,12 +322,13 @@ Random splitting
|           |MT-NN classification|0.844              |0.795              |
|           |Robust MT-NN        |0.855              |0.773              |
|           |Graph convolution   |0.865              |0.827              |
|           |Weave               |0.796              |0.781              |
|           |Weave               |0.837              |0.830              |
|toxcast    |Logistic regression |0.725        	     |0.586              |
|           |XGBoost             |0.738              |0.633              |
|           |MT-NN classification|0.836        	     |0.684              |
|           |Robust MT-NN        |0.822              |0.681              |
|           |Graph convolution   |0.820        	     |0.717              |
|           |Weave               |0.757              |0.729              |

Scaffold splitting

@@ -330,12 +340,14 @@ Scaffold splitting
|           |NN classification   |0.897              |0.743              |
|           |Robust NN           |0.910              |0.747              |
|           |Graph convolution   |0.920              |0.682              |
|           |Weave               |0.860              |0.629              |
|bbbp       |Logistic regression |0.980              |0.959              |
|           |Random forest       |0.999              |0.953              |
|           |IRV                 |0.914              |0.961              |
|           |NN classification   |0.899              |0.961              |
|           |Robust NN           |0.908              |0.956              |
|           |Graph convolution   |0.968              |0.950              |
|           |Weave               |0.925              |0.968              |
|clintox    |Logistic regression |0.965              |0.688              |
|           |Random forest       |0.993              |0.735              |
|           |XGBoost             |0.873              |0.850              |
@@ -343,7 +355,7 @@ Scaffold splitting
|           |MT-NN classification|0.937              |0.828              |
|           |Robust MT-NN        |0.956              |0.821              |
|           |Graph convolution   |0.965              |0.900              |
|           |Weave               |0.888              |0.873              |
|           |Weave               |0.950              |0.947              |
|hiv        |Logistic regression |0.858              |0.798              |
|           |Random forest       |0.946              |0.562              |
|           |XGBoost             |0.927              |0.830              |
@@ -351,11 +363,13 @@ Scaffold splitting
|           |NN classification   |0.775              |0.765              |
|           |Robust NN           |0.785              |0.748              |
|           |Graph convolution   |0.867              |0.769              |
|           |Weave               |0.875              |0.816              |
|muv        |Logistic regression |0.947              |0.767              |
|           |XGBoost             |0.875              |0.705              |
|           |MT-NN classification|0.899              |0.762              |
|           |Robust MT-NN        |0.944              |0.726              |
|           |Graph convolution   |0.872              |0.795              |
|           |Weave               |0.780              |0.773              |
|pcba       |Logistic regression |0.810              |0.742              |
|           |MT-NN classification|0.814              |0.760              |
|           |Robust MT-NN        |0.812              |0.756              |
@@ -367,6 +381,7 @@ Scaffold splitting
|           |MT-NN classification|0.776              |0.557              |
|           |Robust MT-NN        |0.797              |0.560              |
|           |Graph convolution   |0.722              |0.583              |
|           |Weave               |0.600              |0.529              |
|tox21      |Logistic regression |0.900              |0.650              |
|           |Random forest       |0.999              |0.629              |
|           |XGBoost             |0.881              |0.703              |
@@ -374,12 +389,13 @@ Scaffold splitting
|           |MT-NN classification|0.863              |0.703              |
|           |Robust MT-NN        |0.861              |0.710              |
|           |Graph convolution   |0.885              |0.732              |
|           |Weave               |0.812              |0.727              |
|           |Weave               |0.866              |0.773              |
|toxcast    |Logistic regression |0.716              |0.492              |
|           |XGBoost             |0.741              |0.587              |
|           |MT-NN classification|0.828              |0.617              |
|           |Robust MT-NN        |0.830              |0.614              |
|           |Graph convolution   |0.832              |0.638              |
|           |Weave               |0.766              |0.637              |


* Regression
+5 −12
Original line number Diff line number Diff line
@@ -525,11 +525,8 @@ class WeaveGraphTopology_v2(GraphTopology):
        shape=(None, self.n_pair_feat),
        name=self.name + '_pair_features')
    self.pair_split_placeholder = tf.placeholder(
        dtype='int32', shape=(self.max_atoms,), 
        dtype='int32', shape=(None,), 
        name=self.name + '_pair_split')
    self.pair_membership_placeholder = tf.placeholder(
        dtype='bool', shape=(self.max_atoms,), 
        name=self.name + '_pair_membership')
    self.atom_split_placeholder = tf.placeholder(
        dtype='int32', shape=(self.batch_size,), 
        name=self.name + '_atom_split')
@@ -538,8 +535,7 @@ class WeaveGraphTopology_v2(GraphTopology):
        name=self.name + '_atom_to_pair')
    
    # Define the list of tensors to be used as topology
    self.topology = [self.pair_split_placeholder, self.pair_membership_placeholder,
                     self.atom_split_placeholder, self.atom_to_pair_placeholder]
    self.topology = [self.pair_split_placeholder, self.atom_split_placeholder, self.atom_to_pair_placeholder]
    self.inputs = [self.atom_features_placeholder]
    self.inputs += self.topology

@@ -577,9 +573,10 @@ class WeaveGraphTopology_v2(GraphTopology):
      # index of pair features
      C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms))
      atom_to_pair.append(np.transpose(np.array([C1.flatten()+start, C0.flatten()+start])))
      start = start + n_atoms
      # number of pairs for each atom
      pair_split.extend([n_atoms]*n_atoms)
      pair_split.extend(C1.flatten()+start)
      start = start + n_atoms
    
      # atom features
      atom_feat.append(mol.get_atom_features())
      # pair features
@@ -590,15 +587,11 @@ class WeaveGraphTopology_v2(GraphTopology):
    pair_feat = np.concatenate(pair_feat, axis=0)
    atom_to_pair = np.concatenate(atom_to_pair, axis=0)
    atom_split = np.array(atom_split)
    n_pair = len(pair_split)
    pair_split = np.pad(pair_split, ((0, max_atoms-n_pair)), 'constant')
    pair_membership = np.array([True]*n_pair + [False]*(max_atoms-n_pair))
    # Generate dicts
    dict_DTNN = {
        self.atom_features_placeholder: atom_feat,
        self.pair_features_placeholder: pair_feat,
        self.pair_split_placeholder: pair_split,
        self.pair_membership_placeholder: pair_membership,
        self.atom_split_placeholder: atom_split,
        self.atom_to_pair_placeholder: atom_to_pair
    }
+1 −1
Original line number Diff line number Diff line
@@ -68,7 +68,7 @@ hps['dag'] = {
hps['weave'] = {
    'batch_size': 64,
    'nb_epoch': 40,
    'learning_rate': 0.0001,
    'learning_rate': 0.001,
    'n_graph_feat': 128,
    'n_pair_feat': 14,
    'seed': 123
+13 −13
Original line number Diff line number Diff line
@@ -271,14 +271,14 @@ def benchmark_classification(train_dataset,
    max_atoms_test = max([mol.get_num_atoms() for mol in test_dataset.X])
    max_atoms = max([max_atoms_train, max_atoms_valid, max_atoms_test])
    
    graph_model = deepchem.nn.SequentialWeaveGraph(max_atoms=max_atoms,
        n_atom_feat=n_features, n_pair_feat=n_pair_feat)
    graph_model.add(deepchem.nn.WeaveLayer(max_atoms, 75, 14))
    graph_model.add(deepchem.nn.WeaveLayer(max_atoms, 50, 50, update_pair=False))
    graph_model.add(deepchem.nn.WeaveConcat(batch_size, n_output=n_graph_feat))
    graph_model = deepchem.nn.SequentialWeaveGraph_v2(batch_size,
        max_atoms=max_atoms, n_atom_feat=n_features, n_pair_feat=n_pair_feat)
    graph_model.add(deepchem.nn.WeaveLayer_v2(max_atoms, 75, 14))
    graph_model.add(deepchem.nn.WeaveLayer_v2(max_atoms, 50, 50, update_pair=False))
    graph_model.add(deepchem.nn.Dense(n_graph_feat, 50, activation='tanh'))
    graph_model.add(deepchem.nn.BatchNormalization(epsilon=1e-5, mode=1))
    graph_model.add(
        deepchem.nn.WeaveGather(
        deepchem.nn.WeaveGather_v2(
            batch_size, n_input=n_graph_feat, gaussian_expand=True))

    model = deepchem.models.MultitaskGraphClassifier(
@@ -595,14 +595,14 @@ def benchmark_regression(train_dataset,
    max_atoms_test = max([mol.get_num_atoms() for mol in test_dataset.X])
    max_atoms = max([max_atoms_train, max_atoms_valid, max_atoms_test])
    
    graph_model = deepchem.nn.SequentialWeaveGraph(max_atoms=max_atoms,
        n_atom_feat=n_features, n_pair_feat=n_pair_feat)
    graph_model.add(deepchem.nn.WeaveLayer(max_atoms, 75, 14))
    graph_model.add(deepchem.nn.WeaveLayer(max_atoms, 50, 50, update_pair=False))
    graph_model.add(deepchem.nn.WeaveConcat(batch_size, n_output=n_graph_feat))
    graph_model = deepchem.nn.SequentialWeaveGraph_v2(batch_size,
        max_atoms=max_atoms, n_atom_feat=n_features, n_pair_feat=n_pair_feat)
    graph_model.add(deepchem.nn.WeaveLayer_v2(max_atoms, 75, 14))
    graph_model.add(deepchem.nn.WeaveLayer_v2(max_atoms, 50, 50, update_pair=False))
    graph_model.add(deepchem.nn.Dense(n_graph_feat, 50, activation='tanh'))
    graph_model.add(deepchem.nn.BatchNormalization(epsilon=1e-5, mode=1))
    graph_model.add(
        deepchem.nn.WeaveGather(
        deepchem.nn.WeaveGather_v2(
            batch_size, n_input=n_graph_feat, gaussian_expand=True))

    model = deepchem.models.MultitaskGraphRegressor(
+3 −6
Original line number Diff line number Diff line
@@ -202,17 +202,14 @@ class WeaveLayer_v2(WeaveLayer):
    pair_features = x[1]

    pair_split = x[2]
    pair_membership = x[3]
    atom_split = x[4]
    atom_to_pair = x[5]
    atom_split = x[3]
    atom_to_pair = x[4]

    AA = tf.matmul(atom_features, self.W_AA) + self.b_AA
    AA = self.activation(AA)
    PA = tf.matmul(pair_features, self.W_PA) + self.b_PA
    PA = self.activation(PA)
    PAs = tf.split(PA, pair_split, axis=0)
    PA = [tf.reduce_sum(molecule, 0) for molecule in PAs]
    PA = tf.boolean_mask(PA, pair_membership)
    PA = tf.segment_sum(PA, pair_split)
    
    A = tf.matmul(tf.concat([AA, PA], 1), self.W_A) + self.b_A
    A = self.activation(A)
Loading