Commit 7500bcd5 authored by miaecle's avatar miaecle
Browse files

fix bug in benchmark

parent 1a3c14b0
Loading
Loading
Loading
Loading
+12 −4
Original line number Diff line number Diff line
@@ -245,9 +245,13 @@ def benchmark_classification(train_dataset,

    tf.set_random_seed(seed)
    graph_model = deepchem.nn.SequentialDAGGraph(
        n_features, batch_size=batch_size, max_atoms=max_atoms)
        n_features, max_atoms=max_atoms)
    graph_model.add(
        deepchem.nn.DAGLayer(n_graph_feat, n_features, max_atoms=max_atoms))
        deepchem.nn.DAGLayer(
            n_graph_feat,
            n_features,
            max_atoms=max_atoms,
            batch_size=batch_size))
    graph_model.add(deepchem.nn.DAGGather(max_atoms=max_atoms))

    model = deepchem.models.MultitaskGraphClassifier(
@@ -574,9 +578,13 @@ def benchmark_regression(train_dataset,

    tf.set_random_seed(seed)
    graph_model = deepchem.nn.SequentialDAGGraph(
        n_features, batch_size=batch_size, max_atoms=max_atoms)
        n_features, max_atoms=max_atoms)
    graph_model.add(
        deepchem.nn.DAGLayer(n_graph_feat, n_features, max_atoms=max_atoms))
        deepchem.nn.DAGLayer(
            n_graph_feat,
            n_features,
            max_atoms=max_atoms,
            batch_size=batch_size))
    graph_model.add(deepchem.nn.DAGGather(max_atoms=max_atoms))

    model = deepchem.models.MultitaskGraphRegressor(