Commit e971ccb2 authored by miaecle's avatar miaecle
Browse files

unit test for DTNN

parent 2e4b0eee
Loading
Loading
Loading
Loading
+12.8 KiB

File added.

No diff preview for this file type.

+47 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ import sklearn
import shutil
import tensorflow as tf
import deepchem as dc
import scipy.io
from tensorflow.python.framework import test_util
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
@@ -661,6 +662,52 @@ class TestOverfit(test_util.TensorFlowTestCase):

    assert scores[classification_metric.name] < .2

  def test_DTNN_multitask_regression_overfit(self):
    """Test graph-conv multitask overfits tiny data."""
    np.random.seed(123)
    tf.set_random_seed(123)
    g = tf.Graph()
    sess = tf.Session(graph=g)

    # Load mini log-solubility dataset.
    input_file = os.path.join(self.current_dir, "example_DTNN.mat")
    dataset = scipy.io.loadmat(input_file)
    X = dataset['X']
    y = dataset['T']
    w = np.ones_like(y)
    dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids=None)
    regression_metric = dc.metrics.Metric(
        dc.metrics.mean_absolute_error, mode="regression", task_averager=np.mean)
    n_tasks = y.shape[1]
    n_feat = list(dataset.get_data_shape())
    batch_size = 10

    graph_model = dc.nn.SequentialDTNNGraph(max_n_atoms=n_feat[0])
    graph_model.add(dc.nn.DTNNEmbedding())
    graph_model.add(dc.nn.DTNNStep())
    graph_model.add(dc.nn.DTNNStep())
    graph_model.add(dc.nn.DTNNGather(n_tasks=n_tasks))

    model = dc.models.DTNNRegressor(
        graph_model,
        n_tasks=n_tasks,
        batch_size=batch_size,
        learning_rate=1e-2,
        learning_rate_decay_time=1000,
        optimizer_type="adam",
        beta1=.9,
        beta2=.999)

    # Fit trained model
    model.fit(dataset, nb_epoch=20)
    model.save()

    # Eval model on train
    scores = model.evaluate(dataset, [classification_metric])

    assert scores[classification_metric.name] < .2


  def test_siamese_singletask_classification_overfit(self):
    """Test siamese singletask model overfits tiny data."""
    np.random.seed(123)