Unverified Commit cd4e8a5c authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1862 from peastman/conv

Minor optimizations to GraphConvModel
parents b4cdf474 ad6379a9
Loading
Loading
Loading
Loading
+14 −14
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ def cumulative_sum_minus_last(l, offset=0):
  l: list
    List of integers. Typically small counts.
  """
  return np.delete(np.insert(np.cumsum(l), 0, 0), -1) + offset
  return np.delete(np.insert(np.cumsum(l, dtype=np.int32), 0, 0), -1) + offset


def cumulative_sum(l, offset=0):
@@ -96,7 +96,7 @@ class ConvMol(object):
    ]

    # Convert to numpy array
    self.deg_block_indices = np.array(deg_block_indices)
    self.deg_block_indices = np.array(deg_block_indices, dtype=np.int32)

  def get_atoms_with_deg(self, deg):
    """Retrieves atom_features with the specific degree"""
@@ -157,7 +157,7 @@ class ConvMol(object):
      to_cat = [self.canon_adj_list[i] for i in indices]
      if len(to_cat) > 0:
        adj_list = np.vstack([self.canon_adj_list[i] for i in indices])
        self.deg_adj_lists[deg - self.min_deg] = adj_list
        self.deg_adj_lists[deg - self.min_deg] = adj_list.astype(np.int32)

      else:
        self.deg_adj_lists[deg - self.min_deg] = np.zeros(
@@ -333,7 +333,7 @@ class ConvMol(object):
        for deg in range(min_deg, max_deg + 1)
    ]

    # Update the old adjcency lists with the new atom indices and then combine
    # Update the old adjacency lists with the new atom indices and then combine
    # all together
    for deg in range(min_deg, max_deg + 1):
      row = 0  # Initialize counter
+19 −23
Original line number Diff line number Diff line
@@ -96,14 +96,13 @@ class GraphConv(tf.keras.layers.Layer):
    # Get collection of modified atom features
    new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]

    split_features = tf.split(atom_features, deg_slice[:, 1])
    for deg in range(1, self.max_degree + 1):
      # Obtain relevant atoms for this degree
      rel_atoms = deg_summed[deg - 1]

      # Get self atoms
      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)
      self_atoms = split_features[deg - self.min_degree]

      # Apply hidden affine to relevant atoms and append
      rel_out = tf.matmul(rel_atoms, next(W)) + next(b)
@@ -114,16 +113,12 @@ class GraphConv(tf.keras.layers.Layer):

    # Determine the min_deg=0 case
    if self.min_degree == 0:
      deg = 0

      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)
      self_atoms = split_features[0]

      # Only use the self layer
      out = tf.matmul(self_atoms, next(W)) + next(b)

      new_rel_atoms_collection[deg - self.min_degree] = out
      new_rel_atoms_collection[0] = out

    # Combine all atoms back into the list
    atom_features = tf.concat(axis=0, values=new_rel_atoms_collection)
@@ -173,12 +168,15 @@ class GraphPool(tf.keras.layers.Layer):

    # Tensorflow correctly processes empty lists when using concat

    split_features = tf.split(atom_features, deg_slice[:, 1])
    for deg in range(1, self.max_degree + 1):
      # Get self atoms
      begin = tf.stack([deg_slice[deg - self.min_degree, 0], 0])
      size = tf.stack([deg_slice[deg - self.min_degree, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)
      self_atoms = split_features[deg - self.min_degree]

      if deg_adj_lists[deg - 1].shape[0] == 0:
        # There are no neighbors of this degree, so just create an empty tensor directly.
        maxed_atoms = tf.zeros((0, self_atoms.shape[-1]))
      else:
        # Expand dims
        self_atoms = tf.expand_dims(self_atoms, 1)

@@ -190,9 +188,7 @@ class GraphPool(tf.keras.layers.Layer):
      deg_maxed[deg - self.min_degree] = maxed_atoms

    if self.min_degree == 0:
      begin = tf.stack([deg_slice[0, 0], 0])
      size = tf.stack([deg_slice[0, 1], -1])
      self_atoms = tf.slice(atom_features, begin, size)
      self_atoms = split_features[0]
      deg_maxed[0] = self_atoms

    return tf.concat(axis=0, values=deg_maxed)