Commit f20016c8 authored by Michelle Gill's avatar Michelle Gill
Browse files

Convert DAGTensorGraph to DAGModel

parent ccd2f0fc
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ from deepchem.models.tensorgraph.IRV import TensorflowMultiTaskIRVClassifier
from deepchem.models.tensorgraph.robust_multitask import RobustMultitaskClassifier
from deepchem.models.tensorgraph.robust_multitask import RobustMultitaskRegressor
from deepchem.models.tensorgraph.progressive_multitask import ProgressiveMultitaskRegressor, ProgressiveMultitaskClassifier
from deepchem.models.tensorgraph.models.graph_models import WeaveModel, DTNNModel, DAGTensorGraph, GraphConvModel, MPNNModel
from deepchem.models.tensorgraph.models.graph_models import WeaveModel, DTNNModel, DAGModel, GraphConvModel, MPNNModel
from deepchem.models.tensorgraph.models.symmetry_function_regression import BPSymmetryFunctionRegression, ANIRegression

from deepchem.models.tensorgraph.models.seqtoseq import SeqToSeq
+3 −3
Original line number Diff line number Diff line
@@ -358,7 +358,7 @@ class DTNNModel(TensorGraph):
    return undo_transforms(retval, transformers)


class DAGTensorGraph(TensorGraph):
class DAGModel(TensorGraph):

  def __init__(self,
               n_tasks,
@@ -390,7 +390,7 @@ class DAGTensorGraph(TensorGraph):
    self.n_graph_feat = n_graph_feat
    self.n_outputs = n_outputs
    self.mode = mode
    super(DAGTensorGraph, self).__init__(**kwargs)
    super(DAGModel, self).__init__(**kwargs)
    self.build_graph()

  def build_graph(self):
@@ -508,7 +508,7 @@ class DAGTensorGraph(TensorGraph):
        yield feed_dict

  def predict_on_generator(self, generator, transformers=[], outputs=None):
    out = super(DAGTensorGraph, self).predict_on_generator(
    out = super(DAGModel, self).predict_on_generator(
        generator, transformers=[], outputs=outputs)
    if outputs is None:
      outputs = self.outputs
+1 −1
Original line number Diff line number Diff line
@@ -661,7 +661,7 @@ class TestOverfit(test_util.TensorFlowTestCase):
    transformer = dc.trans.DAGTransformer(max_atoms=50)
    dataset = transformer.transform(dataset)

    model = dc.models.DAGTensorGraph(
    model = dc.models.DAGModel(
        n_tasks,
        max_atoms=50,
        n_atom_feat=n_feat,
+2 −2
Original line number Diff line number Diff line
@@ -223,7 +223,7 @@ def benchmark_classification(train_dataset,
      test_dataset.reshard(reshard_size)
      test_dataset = transformer.transform(test_dataset)

    model = deepchem.models.DAGTensorGraph(
    model = deepchem.models.DAGModel(
        len(tasks),
        max_atoms=max_atoms,
        n_atom_feat=n_features,
@@ -558,7 +558,7 @@ def benchmark_regression(train_dataset,
      test_dataset.reshard(reshard_size)
      test_dataset = transformer.transform(test_dataset)

    model = deepchem.models.DAGTensorGraph(
    model = deepchem.models.DAGModel(
        len(tasks),
        max_atoms=max_atoms,
        n_atom_feat=n_features,
+1 −1
Original line number Diff line number Diff line
@@ -34,7 +34,7 @@ train_dataset = transformer.transform(train_dataset)
valid_dataset.reshard(reshard_size)
valid_dataset = transformer.transform(valid_dataset)

model = dc.models.DAGTensorGraph(
model = dc.models.DAGModel(
    len(delaney_tasks),
    max_atoms=max_atoms,
    n_atom_feat=n_atom_feat,
Loading