Unverified Commit 718bb940 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1335 from lilleswing/fast-gc

Fast GraphConvs
parents 8abbf5d2 8738a7b5
Loading
Loading
Loading
Loading
+10 −15
Original line number Diff line number Diff line
@@ -86,6 +86,10 @@ class ConvMol(object):
        for deg in range(self.min_deg, self.max_deg + 1)
    ]

    self.degree_list = []
    for i, deg in enumerate(range(self.min_deg, self.max_deg + 1)):
      self.degree_list.extend([deg] * deg_size[i])

    # Get the start indices for items in each block
    self.deg_start = cumulative_sum(deg_size)

@@ -264,17 +268,11 @@ class ConvMol(object):

    num_mols = len(mols)

    atoms_per_mol = [mol.get_num_atoms() for mol in mols]

    # Get atoms by degree
    atoms_by_deg = [
        mol.get_atoms_with_deg(deg)
        for deg in range(min_deg, max_deg + 1)
        for mol in mols
    ]

    # stack the atoms
    all_atoms = np.vstack(atoms_by_deg)
    # Results should be sorted by (atom_degree, mol_index)
    atoms_by_deg = np.concatenate([x.atom_features for x in mols])
    degree_vector = np.concatenate([x.degree_list for x in mols], axis=0)
    # Mergesort is a "stable" sort, so the array maintains its secondary sort of mol_index
    all_atoms = atoms_by_deg[degree_vector.argsort(kind='mergesort')]

    # Sort all atoms by degree.
    # Get the size of each atom list separated by molecule id, then by degree
@@ -297,8 +295,7 @@ class ConvMol(object):

    # Determines the membership (atom i belongs to membership[i] molecule)
    membership = [
        k
        for deg in range(min_deg, max_deg + 1) for k in range(num_mols)
        k for deg in range(min_deg, max_deg + 1) for k in range(num_mols)
        for i in range(mol_deg_sz[deg][k])
    ]

@@ -371,7 +368,6 @@ class MultiConvMol(object):
  """

  def __init__(self, nodes, deg_adj_lists, deg_slice, membership, num_mols):

    self.nodes = nodes
    self.deg_adj_lists = deg_adj_lists
    self.deg_slice = deg_slice
@@ -398,7 +394,6 @@ class WeaveMol(object):
  """

  def __init__(self, nodes, pairs):

    self.nodes = nodes
    self.pairs = pairs
    self.num_atoms = self.nodes.shape[0]
+41 −0
Original line number Diff line number Diff line
from unittest import TestCase

import numpy as np
from deepchem.feat import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
from deepchem.molnet import load_bace_classification


class TestConvMol(TestCase):
  """Checks the atom ordering produced by ``ConvMol.agglomerate_mols``."""

  def get_molecules(self):
    """Load BACE classification with the "Raw" featurizer.

    Returns the ``X`` attribute of the first dataset in the loader's
    second return value.
    """
    _, datasets, _ = load_bace_classification(featurizer="Raw")
    return datasets[0].X

  def test_mol_ordering(self):
    """Agglomerated atoms must be grouped by degree, then by molecule.

    Each atom's degree is prepended as feature column 0 before the
    molecules are merged, so the degree of every row can be read back
    from the merged feature matrix afterwards.
    """
    featurized_mols = ConvMolFeaturizer().featurize(self.get_molecules())

    # Tag every atom row with its degree in a new leading column.
    for mol in featurized_mols:
      degree_column = np.expand_dims(mol.degree_list, axis=1)
      mol.atom_features = np.concatenate(
          [degree_column, mol.atom_features], axis=1)

    conv_mol = ConvMol.agglomerate_mols(featurized_mols)

    # Within one degree slice the molecule indices must be non-decreasing.
    for lo, hi in conv_mol.deg_slice.tolist():
      slice_members = conv_mol.membership[lo:hi]
      expected_order = np.array(sorted(slice_members))
      actual_order = np.array(slice_members)
      self.assertTrue(np.array_equal(expected_order, actual_order))

    merged_features = conv_mol.get_atom_features()

    # The i-th degree slice may only contain atoms tagged with degree i.
    for degree, (lo, hi) in enumerate(conv_mol.deg_slice.tolist()):
      tagged_degrees = merged_features[lo:hi][:, 0]
      self.assertTrue(np.all(tagged_degrees == degree))
+4 −19
Original line number Diff line number Diff line
@@ -2737,27 +2737,12 @@ class GraphGather(Layer):
    # Extract graph topology
    membership = inputs[2]

    # Perform the mol gather

    assert self.batch_size > 1, "graph_gather requires batches larger than 1"

    # Obtain the partitions for each of the molecules
    activated_par = tf.dynamic_partition(atom_features, membership,
    sparse_reps = tf.unsorted_segment_sum(atom_features, membership,
                                          self.batch_size)
    max_reps = tf.unsorted_segment_max(atom_features, membership,
                                       self.batch_size)

    # Sum over atoms for each molecule
    sparse_reps = [
        tf.reduce_mean(activated, 0, keepdims=True)
        for activated in activated_par
    ]
    max_reps = [
        tf.reduce_max(activated, 0, keepdims=True)
        for activated in activated_par
    ]

    # Get the final sparse representations
    sparse_reps = tf.concat(axis=0, values=sparse_reps)
    max_reps = tf.concat(axis=0, values=max_reps)
    mol_features = tf.concat(axis=1, values=[sparse_reps, max_reps])

    if self.activation_fn is not None:
+2 −3
Original line number Diff line number Diff line
@@ -355,8 +355,8 @@ class DTNNModel(TensorGraph):
          start = start + num_atoms[im]
        feed_dict[self.atom_number] = np.concatenate(atom_number)
        distance = np.concatenate(distance, 0)
        feed_dict[self.distance] = np.exp(-np.square(distance - self.steps) /
                                          (2 * self.step_size**2))
        feed_dict[self.distance] = np.exp(
            -np.square(distance - self.steps) / (2 * self.step_size**2))
        feed_dict[self.distance_membership_i] = np.concatenate(
            distance_membership_i)
        feed_dict[self.distance_membership_j] = np.concatenate(
@@ -598,7 +598,6 @@ class GraphConvModel(TensorGraph):
    self.mode = mode
    self.dense_layer_size = dense_layer_size
    self.graph_conv_layers = graph_conv_layers
    kwargs['use_queue'] = False
    self.number_atom_features = number_atom_features
    self.n_classes = n_classes
    self.uncertainty = uncertainty