Commit b93dec1a authored by miaecle's avatar miaecle
Browse files

fix test failure

parent 1e646c4e
Loading
Loading
Loading
Loading
+7 −7
Original line number Diff line number Diff line
@@ -677,15 +677,15 @@ class TestOverfit(test_util.TensorFlowTestCase):
    regression_metric = dc.metrics.Metric(
        dc.metrics.r2_score, task_averager=np.mean)
    n_tasks = y.shape[1]
    n_feat = list(dataset.get_data_shape())
    max_n_atoms = list(dataset.get_data_shape())
    batch_size = 10

    graph_model = dc.nn.SequentialDTNNGraph(max_n_atoms=n_feat[0])
    graph_model.add(dc.nn.DTNNEmbedding())
    graph_model.add(dc.nn.DTNNStep())
    graph_model.add(dc.nn.DTNNStep())
    graph_model.add(dc.nn.DTNNGather(n_tasks=n_tasks))

    graph_model = dc.nn.SequentialDTNNGraph(max_n_atoms=max_n_atoms)
    graph_model.add(dc.nn.DTNNEmbedding(n_embedding=20))
    graph_model.add(dc.nn.DTNNStep(n_embedding=20))
    graph_model.add(dc.nn.DTNNStep(n_embedding=20))
    graph_model.add(dc.nn.DTNNGather(n_embedding=20))
    n_feat = 20
    model = dc.models.DTNNGraphRegressor(
        graph_model,
        n_tasks,