Commit 1c4e9ba6 authored by peastman's avatar peastman
Browse files

Optimizations to GraphConvModel

parent ab097fde
Loading
Loading
Loading
Loading
+19 −23
Original line number Diff line number Diff line
@@ -96,14 +96,13 @@ class GraphConv(tf.keras.layers.Layer):
    # Get collection of modified atom features
    new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]

    split_features = tf.split(atom_features, deg_slice[:, 1])
    for deg in range(1, self.max_degree + 1):
      # Obtain relevant atoms for this degree
      rel_atoms = deg_summed[deg - 1]

      # Get self atoms
      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)
      self_atoms = split_features[deg - self.min_degree]

      # Apply hidden affine to relevant atoms and append
      rel_out = tf.matmul(rel_atoms, next(W)) + next(b)
@@ -114,16 +113,12 @@ class GraphConv(tf.keras.layers.Layer):

    # Determine the min_deg=0 case
    if self.min_degree == 0:
      deg = 0

      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)
      self_atoms = split_features[0]

      # Only use the self layer
      out = tf.matmul(self_atoms, next(W)) + next(b)

      new_rel_atoms_collection[deg - self.min_degree] = out
      new_rel_atoms_collection[0] = out

    # Combine all atoms back into the list
    atom_features = tf.concat(axis=0, values=new_rel_atoms_collection)
@@ -173,12 +168,15 @@ class GraphPool(tf.keras.layers.Layer):

    # Tensorflow correctly processes empty lists when using concat

    split_features = tf.split(atom_features, deg_slice[:, 1])
    for deg in range(1, self.max_degree + 1):
      # Get self atoms
      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)
      self_atoms = split_features[deg - self.min_degree]

      if deg_adj_lists[deg - 1].shape[0] == 0:
        # There are no neighbors of this degree, so just create an empty tensor directly.
        maxed_atoms = tf.zeros((0, self_atoms.shape[-1]))
      else:
        # Expand dims
        self_atoms = tf.expand_dims(self_atoms, 1)

@@ -190,9 +188,7 @@ class GraphPool(tf.keras.layers.Layer):
      deg_maxed[deg - self.min_degree] = maxed_atoms

    if self.min_degree == 0:
      begin = tf.stack([deg_slice[0, 0], 0])
      size = tf.stack([deg_slice[0, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)
      self_atoms = split_features[0]
      deg_maxed[0] = self_atoms

    return tf.concat(axis=0, values=deg_maxed)