Commit 98b329e2 authored by leswing's avatar leswing
Browse files

Pin Graph Gather

parent 4154d799
Loading
Loading
Loading
Loading
+31 −30
Original line number Diff line number Diff line
@@ -2445,6 +2445,7 @@ class GraphGather(Layer):
    super(GraphGather, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    with tf.device('/cpu'):
      inputs = self._get_input_tensors(in_layers)

      # x = [atom_features, deg_slice, membership, deg_adj_list placeholders...]