Commit bf9e1f2b authored by peastman's avatar peastman
Browse files

Fixed failing test cases

parent b57dc3d8
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -320,6 +320,12 @@ class DAGLayer(KerasLayer):
        self.n_graph_feat, self.n_atom_feat, self.max_atoms, self.layer_sizes,
        self.init, self.activation, self.dropout, self.batch_size)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    training = kwargs['training'] if 'training' in kwargs else 1.0
    inputs.append(training)
    return super(DAGLayer, self).create_tensor(inputs, set_tensors, **kwargs)


class DAGGather(KerasLayer):
  """ TensorGraph style implementation
@@ -368,6 +374,12 @@ class DAGGather(KerasLayer):
        self.n_graph_feat, self.n_outputs, self.max_atoms, self.layer_sizes,
        self.init, self.activation, self.dropout)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    training = kwargs['training'] if 'training' in kwargs else 1.0
    inputs.append(training)
    return super(DAGGather, self).create_tensor(inputs, set_tensors, **kwargs)


class MessagePassing(KerasLayer):
  """ General class for MPNN
+1 −1
Original line number Diff line number Diff line
@@ -1112,7 +1112,7 @@ class TestLayers(test_util.TensorFlowTestCase):
          activation=activation,
          init=init_method,
          layer_sizes=layer_sizes)
      dag_gather.create_tensor(in_layers=[atom_features_tf, membership_tf])
      dag_gather.create_tensor(in_layers=[atom_features_tf, membership_tf, 0])

      sess.run(tf.global_variables_initializer())
      output = dag_gather.out_tensor.eval()