Commit 79655a7e authored by leswing's avatar leswing
Browse files

Revert Graph Conv For tox21

parent b45b0373
Loading
Loading
Loading
Loading
+27 −6
Original line number Diff line number Diff line
@@ -6,7 +6,6 @@ from __future__ import division
from __future__ import unicode_literals

import numpy as np

np.random.seed(123)
import tensorflow as tf
tf.set_random_seed(123)
@@ -21,13 +20,35 @@ train_dataset, valid_dataset, test_dataset = tox21_datasets
metric = dc.metrics.Metric(
    dc.metrics.roc_auc_score, np.mean, mode="classification")

# Number of features on conv-mols
n_feat = 75
# Batch size of models
batch_size = 50

model = dc.models.tensorgraph.models.graph_conv_model(batch_size,
                                                      len(tox21_tasks))

model.fit(train_dataset, nb_epoch=10, checkpoint_interval=10)
graph_model = dc.nn.SequentialGraph(n_feat)
graph_model.add(dc.nn.GraphConv(64, n_feat, activation='relu'))
graph_model.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
graph_model.add(dc.nn.GraphPool())
graph_model.add(dc.nn.GraphConv(64, 64, activation='relu'))
graph_model.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
graph_model.add(dc.nn.GraphPool())
# Gather Projection
graph_model.add(dc.nn.Dense(128, 64, activation='relu'))
graph_model.add(dc.nn.BatchNormalization(epsilon=1e-5, mode=1))
graph_model.add(dc.nn.GraphGather(batch_size, activation="tanh"))

model = dc.models.MultitaskGraphClassifier(
    graph_model,
    len(tox21_tasks),
    n_feat,
    batch_size=batch_size,
    learning_rate=1e-3,
    learning_rate_decay_time=1000,
    optimizer_type="adam",
    beta1=.9,
    beta2=.999)

# Fit trained model
model.fit(train_dataset, nb_epoch=10)

print("Evaluating model")
train_scores = model.evaluate(train_dataset, [metric], transformers)