Commit 85124381 authored by Michelle Gill's avatar Michelle Gill
Browse files

PetroskiSuchTensorGraph to PetroskiSuchModel

parent 591ae6e8
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -519,7 +519,7 @@ class DAGTensorGraph(TensorGraph):
    return out


class PetroskiSuchTensorGraph(TensorGraph):
class PetroskiSuchModel(TensorGraph):
  """
      Model from Robust Spatial Filtering with Graph Convolutional Neural Networks
      https://arxiv.org/abs/1703.00792
@@ -545,7 +545,7 @@ class PetroskiSuchTensorGraph(TensorGraph):
    self.error_bars = True if 'error_bars' in kwargs and kwargs['error_bars'] else False
    self.dropout = dropout
    kwargs['use_queue'] = False
    super(PetroskiSuchTensorGraph, self).__init__(**kwargs)
    super(PetroskiSuchModel, self).__init__(**kwargs)
    self.build_graph()

  def build_graph(self):
+2 −2
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ import tensorflow as tf
tf.set_random_seed(123)
import deepchem as dc
from deepchem.molnet import load_tox21
from deepchem.models.tensorgraph.models.graph_models import PetroskiSuchTensorGraph
from deepchem.models.tensorgraph.models.graph_models import PetroskiSuchModel

model_dir = "/tmp/graph_conv"

@@ -32,7 +32,7 @@ metric = dc.metrics.Metric(
# Batch size of models
batch_size = 128

model = PetroskiSuchTensorGraph(
model = PetroskiSuchModel(
    len(tox21_tasks), batch_size=batch_size, mode='classification')

model.fit(train_dataset, nb_epoch=10)