Commit db43fe5c authored by Vignesh
Browse files

Fixes to remove batch size dependence

parent fd536039
Loading
Loading
Loading
Loading
+5 −7
Original line number Diff line number Diff line
@@ -54,11 +54,9 @@ class AdaptiveFilter(Layer):
  def _build(self):
    # NOTE(review): this span is a unified-diff hunk whose +/- markers were
    # lost in extraction, so the pre-change and post-change lines appear
    # interleaved. As written it is not runnable Python: the first
    # self.init(...) call repeats the `shape=` keyword. Per the commit
    # message ("Fixes to remove batch size dependence") the change drops
    # the leading batch dimension from the weight tensor Q.
    if self.combine_method == "linear":
      self.Q = self.init(
          # removed line(s): Q carried an explicit batch dimension
          shape=(self.batch_size, self.num_nodes,
                 self.num_nodes + self.num_node_features))
          # added line: batch-independent 2-D weight — presumably consumed
          # via tf.tensordot downstream (see the second hunk); TODO confirm
          shape=(self.num_nodes + self.num_node_features, self.num_nodes))
    else:
      self.Q = self.init(
          # removed line(s): batch-sized weight
          shape=(self.batch_size, self.num_node_features, self.num_nodes))
      # added line: batch-independent 2-D weight
      self.Q = self.init(shape=(self.num_node_features, self.num_nodes))

    self.trainable_weights = [self.Q]

@@ -74,10 +72,10 @@ class AdaptiveFilter(Layer):

    if self.combine_method == "linear":
      concatenated = tf.concat([A_tilda_k, X], axis=2)
      transposed = tf.transpose(concatenated, perm=[0, 2, 1])
      adp_fn_val = act_fn(tf.matmul(self.trainable_weights[0], transposed))
      adp_fn_val = act_fn(
          tf.tensordot(concatenated, self.trainable_weights[0], axes=1))
    else:
      adp_fn_val = act_fn(tf.matmul(A_tilda_k, tf.matmul(X, self.Q)))
      adp_fn_val = act_fn(tf.matmul(A_tilda_k, tf.tensordot(X, self.Q, axes=1)))
    out_tensor = adp_fn_val
    if set_tensors:
      self.variables = self.trainable_weights
+32 −32
Original line number Diff line number Diff line
@@ -68,11 +68,11 @@ class HAGCN(TensorGraph):
          Feature(
              name="graph_adjacency_{}".format(k),
              dtype=tf.float32,
              shape=[self.batch_size, self.max_nodes, self.max_nodes]))
              shape=[None, self.max_nodes, self.max_nodes]))
    self.X = Feature(
        name='atom_features',
        dtype=tf.float32,
        shape=[self.batch_size, self.max_nodes, self.num_node_features])
        shape=[None, self.max_nodes, self.num_node_features])

    graph_layers = list()
    adaptive_filters = list()
@@ -195,33 +195,33 @@ class HAGCN(TensorGraph):

        yield feed_dict

  def predict(self, dataset, transformers=None, outputs=None):
    """
    Uses self to make predictions on provided Dataset object.

    Prediction runs with pad_batches=True, so the generator pads the final
    batch up to self.batch_size; the padding rows are stripped from the
    result before returning.

    Parameters
    ----------
    dataset: dc.data.Dataset
      Dataset to make prediction on
    transformers: list
      List of dc.trans.Transformers. Defaults to an empty list (a None
      default is used to avoid the shared-mutable-default pitfall).
    outputs: object
      If outputs is None, then will assume outputs=self.default_outputs. If outputs is
      a Layer/Tensor, then will evaluate and return as a single ndarray. If
      outputs is a list of Layers/Tensors, will return a list of ndarrays.

    Returns
    -------
    results: numpy ndarray or list of numpy ndarrays
    """
    # Normalize the sentinel so callers relying on the old `[]` default
    # see identical behavior.
    if transformers is None:
      transformers = []
    generator = self.default_generator(dataset, predict=True, pad_batches=True)
    preds = self.predict_on_generator(generator, transformers, outputs)
    # The last batch was padded from `remainder` real samples up to a full
    # batch, so exactly batch_size - remainder fake rows sit at the end.
    # (Equivalent to the old after_pad - remainder - closest arithmetic.)
    remainder = len(dataset.y) % self.batch_size
    if remainder == 0:
      return preds
    num_added = self.batch_size - remainder
    if isinstance(preds, list):
      # When `outputs` is a list, predict_on_generator returns a list of
      # ndarrays; strip the padding rows from each array rather than
      # slicing arrays off the list (the old code got this wrong).
      return [p[:-num_added] for p in preds]
    return preds[:-num_added]
  # def predict(self, dataset, transformers=[], outputs=None):
  #   """
  #   Uses self to make predictions on provided Dataset object.
  #
  #   Parameters
  #   ----------
  #   dataset: dc.data.Dataset
  #     Dataset to make prediction on
  #   transformers: list
  #     List of dc.trans.Transformers.
  #   outputs: object
  #     If outputs is None, then will assume outputs=self.default_outputs. If outputs is
  #     a Layer/Tensor, then will evaluate and return as a single ndarray. If
  #     outputs is a list of Layers/Tensors, will return a list of ndarrays.
  #
  #   Returns
  #   -------
  #   results: numpy ndarray or list of numpy ndarrays
  #   """
  #   generator = self.default_generator(dataset, predict=True, pad_batches=True)
  #   preds = self.predict_on_generator(generator, transformers, outputs)
  #   if len(dataset.y) % self.batch_size == 0:
  #     return preds
  #   else:
  #     after_pad = (len(dataset.y) // self.batch_size + 1) * self.batch_size
  #     closest = (len(dataset.y) // self.batch_size) * self.batch_size
  #     remainder = len(dataset.y) % self.batch_size
  #     num_added = after_pad - remainder - closest
  #     preds = preds[:-num_added]
  #     return preds