Unverified Commit ee8ab952 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1033 from lilleswing/master-graph-atom

Master Atom In Graph Featurizer
parents fe11a1d0 67a35f37
Loading
Loading
Loading
Loading
+88 −68
Original line number Diff line number Diff line
@@ -12,8 +12,8 @@ from deepchem.feat.mol_graphs import ConvMol, WeaveMol

def one_of_k_encoding(x, allowable_set):
  if x not in allowable_set:
    raise Exception(
        "input {0} not in allowable set{1}:".format(x, allowable_set))
    raise Exception("input {0} not in allowable set{1}:".format(
        x, allowable_set))
  return list(map(lambda s: x == s, allowable_set))


@@ -247,13 +247,25 @@ def find_distance(a1, num_atoms, canon_adj_list, max_distance=7):


class ConvMolFeaturizer(Featurizer):

  name = ['conv_mol']

  def __init__(self):
    # Since ConvMol is an object and not a numpy array, need to set dtype to
    # object.
  def __init__(self, master_atom=False):
    """
    Parameters
    ----------
    master_atom: Boolean
      if true create a fake atom with bonds to every other atom.
      the initialization is the mean of the other atom features in
      the molecule.  This technique is briefly discussed in
      Neural Message Passing for Quantum Chemistry
      https://arxiv.org/pdf/1704.01212.pdf


    Since ConvMol is an object and not a numpy array, need to set dtype to
    object.
    """
    self.dtype = object
    self.master_atom = master_atom

  def _featurize(self, mol):
    """Encodes mol as a ConvMol object."""
@@ -264,10 +276,14 @@ class ConvMolFeaturizer(Featurizer):

    # Stack nodes into an array
    nodes = np.vstack(nodes)
    if self.master_atom:
      master_atom_features = np.expand_dims(np.mean(nodes, axis=0), axis=0)
      nodes = np.concatenate([nodes, master_atom_features], axis=0)

    # Get bond lists with reverse edges included
    edge_list = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                 for b in mol.GetBonds()]
    edge_list = [
        (b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()
    ]

    # Get canonical adjacency list
    canon_adj_list = [[] for mol_id in range(len(nodes))]
@@ -275,11 +291,15 @@ class ConvMolFeaturizer(Featurizer):
      canon_adj_list[edge[0]].append(edge[1])
      canon_adj_list[edge[1]].append(edge[0])

    if self.master_atom:
      fake_atom_index = len(nodes) - 1
      for index in range(len(nodes) - 1):
        canon_adj_list[index].append(fake_atom_index)

    return ConvMol(nodes, canon_adj_list)


class WeaveFeaturizer(Featurizer):

  name = ['weave_mol']

  def __init__(self, graph_distance=True, explicit_H=False):