Unverified Commit 28bcc4f3 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1081 from lilleswing/pin-graphgather

Pin Graph Gather To the CPU
parents 4154d799 a7e4cc68
Loading
Loading
Loading
Loading
+32 −31
Original line number Diff line number Diff line
@@ -1547,7 +1547,7 @@ class SoftMaxCrossEntropy(Layer):
    if len(inputs) != 2:
      raise ValueError()
    labels, logits = inputs[0], inputs[1]
    out_tensor = tf.nn.softmax_cross_entropy_with_logits(
    out_tensor = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=labels)
    if set_tensors:
      self.out_tensor = out_tensor
@@ -2445,6 +2445,7 @@ class GraphGather(Layer):
    super(GraphGather, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    with tf.device('/cpu'):
      inputs = self._get_input_tensors(in_layers)

      # x = [atom_features, deg_slice, membership, deg_adj_list placeholders...]
@@ -2463,11 +2464,11 @@ class GraphGather(Layer):

      # Sum over atoms for each molecule
      sparse_reps = [
        tf.reduce_mean(activated, 0, keep_dims=True)
          tf.reduce_mean(activated, 0, keepdims=True)
          for activated in activated_par
      ]
      max_reps = [
        tf.reduce_max(activated, 0, keep_dims=True)
          tf.reduce_max(activated, 0, keepdims=True)
          for activated in activated_par
      ]