Commit 10d73eb2 authored by miaecle's avatar miaecle
Browse files

building weave layers

parent 02d5b59c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ from deepchem.feat.base_classes import Featurizer
from deepchem.feat.base_classes import ComplexFeaturizer
from deepchem.feat.base_classes import UserDefinedFeaturizer
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.graph_features import WeaveFeaturizer
from deepchem.feat.fingerprints import CircularFingerprint
from deepchem.feat.basic import RDKitDescriptors
from deepchem.feat.coulomb_matrices import CoulombMatrix
+149 −68
Original line number Diff line number Diff line
@@ -7,7 +7,8 @@ from rdkit import Chem
import itertools, operator

from deepchem.feat import Featurizer
from deepchem.feat.mol_graphs import ConvMol
from deepchem.feat.mol_graphs import ConvMol, WeaveMol


def one_of_k_encoding(x, allowable_set):
  if x not in allowable_set:
@@ -15,12 +16,14 @@ def one_of_k_encoding(x, allowable_set):
        "input {0} not in allowable set{1}:".format(x, allowable_set))
  return list(map(lambda s: x == s, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
  """One-hot encode x over allowable_set with a catch-all last slot.

  If x is not a member of allowable_set it is replaced by the final element
  of allowable_set before encoding. The result is a list of booleans, one
  per entry of allowable_set, each the outcome of comparing the (possibly
  replaced) value against that entry.
  """
  target = x if x in allowable_set else allowable_set[-1]
  return [target == candidate for candidate in allowable_set]


def get_intervals(l):
  """For list of lists, gets the cumulative products of the lengths"""
  intervals = len(l) * [0]
@@ -31,6 +34,7 @@ def get_intervals(l):

  return intervals


def safe_index(l, e):
  """Gets the index of e in l, providing an index of len(l) if not found"""
  try:
@@ -38,22 +42,26 @@ def safe_index(l, e):
  except:
    return len(l)

possible_atom_list = ['C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Mg', 'Na', 'Br',
                      'Fe', 'Ca', 'Cu', 'Mc', 'Pd', 'Pb',
                      'K','I','Al','Ni','Mn']

possible_atom_list = [
    'C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Mg', 'Na', 'Br', 'Fe', 'Ca', 'Cu',
    'Mc', 'Pd', 'Pb', 'K', 'I', 'Al', 'Ni', 'Mn'
]
possible_numH_list = [0, 1, 2, 3, 4]
possible_valence_list = [0, 1, 2, 3, 4, 5, 6]
possible_formal_charge_list = [-3, -2, -1, 0, 1, 2, 3]
possible_hybridization_list = [Chem.rdchem.HybridizationType.SP,
                               Chem.rdchem.HybridizationType.SP2,
                               Chem.rdchem.HybridizationType.SP3,
                               Chem.rdchem.HybridizationType.SP3D,
                               Chem.rdchem.HybridizationType.SP3D2]
possible_hybridization_list = [
    Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2
]
possible_number_radical_e_list = [0, 1, 2]

reference_lists = [possible_atom_list, possible_numH_list,
                   possible_valence_list, possible_formal_charge_list,
                   possible_number_radical_e_list, possible_hybridization_list]
reference_lists = [
    possible_atom_list, possible_numH_list, possible_valence_list,
    possible_formal_charge_list, possible_number_radical_e_list,
    possible_hybridization_list
]

intervals = get_intervals(reference_lists)

@@ -70,6 +78,7 @@ def get_feature_list(atom):

  return features


def features_to_id(features, intervals):
  """Convert list of features into index using spacings provided in intervals"""
  id = 0
@@ -80,6 +89,7 @@ def features_to_id(features, intervals):
  id = id + 1
  return id


def id_to_features(id, intervals):
  features = 6 * [0]

@@ -94,62 +104,128 @@ def id_to_features(id, intervals):
  features[0] = id
  return features


def atom_to_id(atom):
  """Return the integer id for the atom's feature combination.

  The atom is first converted to a feature list with get_feature_list, and
  that list is then turned into a single id by features_to_id using the
  module-level intervals.
  """
  return features_to_id(get_feature_list(atom), intervals)


def atom_features(atom, bool_id_feat=False):
  if bool_id_feat:
    return np.array([atom_to_id(atom)])
  else:
    return np.array(one_of_k_encoding_unk(
        atom.GetSymbol(),
        ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
         'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',
         'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',    # H?
         'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
         'Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) +
        one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
        one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
        one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) +
        [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] +
    return np.array(
        one_of_k_encoding_unk(
            atom.GetHybridization(),
            [Chem.rdchem.HybridizationType.SP,
             Chem.rdchem.HybridizationType.SP2,
             Chem.rdchem.HybridizationType.SP3,
             Chem.rdchem.HybridizationType.SP3D,
             Chem.rdchem.HybridizationType.SP3D2]) +
        [atom.GetIsAromatic()])
            atom.GetSymbol(),
            [
                'C',
                'N',
                'O',
                'S',
                'F',
                'Si',
                'P',
                'Cl',
                'Br',
                'Mg',
                'Na',
                'Ca',
                'Fe',
                'As',
                'Al',
                'I',
                'B',
                'V',
                'K',
                'Tl',
                'Yb',
                'Sb',
                'Sn',
                'Ag',
                'Pd',
                'Co',
                'Se',
                'Ti',
                'Zn',
                'H',  # H?
                'Li',
                'Ge',
                'Cu',
                'Au',
                'Ni',
                'Cd',
                'In',
                'Mn',
                'Zr',
                'Cr',
                'Pt',
                'Hg',
                'Pb',
                'Unknown'
            ]) + one_of_k_encoding(atom.GetDegree(), [
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
            ]) + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
        one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])
        + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] +
        one_of_k_encoding_unk(atom.GetHybridization(), [
            Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.
            SP3D, Chem.rdchem.HybridizationType.SP3D2
        ]) + [atom.GetIsAromatic()])


def bond_features(bond):
  bt = bond.GetBondType()
  return np.array([bt == Chem.rdchem.BondType.SINGLE,
                   bt == Chem.rdchem.BondType.DOUBLE,
                   bt == Chem.rdchem.BondType.TRIPLE,
                   bt == Chem.rdchem.BondType.AROMATIC,
                   bond.GetIsConjugated(),
                   bond.IsInRing()])
  
def pair_features(mol, canon_adj_list):
  features = np.zeros((mol.GetNumAtoms(), mol.GetNumAtoms(), 12))
  for a1 in mol.GetAtoms():
    a1_id = a1.GetIdx()
    for a2 in mol.GetAtoms():
      a2_id = a2.GetIdx()
      if a2_id in canon_adj_list[a1_id]:
        bt = bond_features(mol.GetBondBetweenAtoms(a1_id, a2_id))
  return np.array([bt == Chem.rdchem.BondType.SINGLE,
                   bt == Chem.rdchem.BondType.DOUBLE,
                   bt == Chem.rdchem.BondType.TRIPLE,
                   bt == Chem.rdchem.BondType.AROMATIC,
                   bond.GetIsConjugated(),
                   bond.IsInRing()])
  return np.array([
      bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE,
      bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC,
      bond.GetIsConjugated(), bond.IsInRing()
  ])


def pair_features(mol, edge_list, canon_adj_list, bt_len=6, max_distance=7):
  """Build the dense atom-pair feature tensor for one molecule.

  The last axis of the result is laid out as:
    [0, bt_len)        bond features copied from edge_list (bonded pairs only)
    bt_len             1 when the two distinct atoms share a ring
    (bt_len, end)      per-hop distance indicators from find_distance

  Parameters
  ----------
  mol: object
    Molecule exposing GetNumAtoms() and GetRingInfo().AtomRings()
    (presumably an RDKit Mol -- confirm against the caller).
  edge_list: dict
    Maps a sorted (atom_idx, atom_idx) tuple to that bond's feature vector
    of length bt_len.
  canon_adj_list: list of lists
    canon_adj_list[i] holds the indices of the atoms bonded to atom i.
  bt_len: int, optional
    Number of bond features per bonded pair.
  max_distance: int, optional
    Number of graph-distance channels; forwarded to find_distance.
    Previously fixed at 7 inside the function.

  Returns
  -------
  features: np.ndarray
    Array of shape (num_atoms, num_atoms, bt_len + max_distance + 1).
  """
  num_atoms = mol.GetNumAtoms()
  features = np.zeros((num_atoms, num_atoms, bt_len + max_distance + 1))
  rings = mol.GetRingInfo().AtomRings()
  for a1 in range(num_atoms):
    # Bond features: edge_list is keyed on the sorted index pair.
    for a2 in canon_adj_list[a1]:
      features[a1, a2, :bt_len] = np.asarray(
          edge_list[tuple(sorted((a1, a2)))], dtype=float)
    # Same-ring flag: mark every member of each ring containing a1, then
    # clear the diagonal so an atom is not flagged against itself.
    for ring in rings:
      if a1 in ring:
        features[a1, ring, bt_len] = 1
        features[a1, a1, bt_len] = 0.
    # Graph distance between a1 and every other atom.
    features[a1, :, bt_len + 1:] = find_distance(
        a1, num_atoms, canon_adj_list, max_distance=max_distance)

  return features


def find_distance(a1, num_atoms, canon_adj_list, max_distance=7):
  """Mark, per hop count, the atoms first reached from atom a1.

  Expands outwards from a1 one bond at a time through canon_adj_list. The
  returned array has shape (num_atoms, max_distance); entry [j, d] is 1 when
  atom j is first reached after d + 1 hops and 0 otherwise. Atoms that are
  not reached within max_distance hops keep an all-zero row.
  """
  distance = np.zeros((num_atoms, max_distance))
  visited = {a1}
  frontier = set(canon_adj_list[a1])
  for hop in range(max_distance):
    distance[list(frontier), hop] = 1
    visited |= frontier
    # Gather every neighbour of the current shell, then drop the atoms that
    # were already reached at a smaller hop count.
    reachable = set()
    for atom in frontier:
      reachable.update(canon_adj_list[atom])
    frontier = reachable - visited
  return distance


class ConvMolFeaturizer(Featurizer):

  name = ['conv_mol']

  def __init__(self):
    # Since ConvMol is an object and not a numpy array, need to set dtype to
    # object.
@@ -166,7 +242,8 @@ class ConvMolFeaturizer(Featurizer):
    nodes = np.vstack(nodes)

    # Get bond lists with reverse edges included
    edge_list = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
    edge_list = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                 for b in mol.GetBonds()]

    # Get canonical adjacency list
    canon_adj_list = [[] for mol_id in range(len(nodes))]
@@ -180,13 +257,13 @@ class ConvMolFeaturizer(Featurizer):
class WeaveFeaturizer(Featurizer):

  name = ['weave_mol']

  def __init__(self):
    # Since ConvMol is an object and not a numpy array, need to set dtype to
    # object.
    # Set dtype
    self.dtype = object

  def _featurize(self, mol):
    """Encodes mol as a ConvMol object."""
    """Encodes mol as a WeaveMol object."""
    # Atom features
    idx_nodes = [(a.GetIdx(), atom_features(a)) for a in mol.GetAtoms()]
    idx_nodes.sort()  # Sort by ind to ensure same order as rd_kit
@@ -195,15 +272,19 @@ class WeaveFeaturizer(Featurizer):
    # Stack nodes into an array
    nodes = np.vstack(nodes)

    # Get bond lists with reverse edges included
    edge_list = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
    # Get bond lists
    edge_list = {}
    for b in mol.GetBonds():
      edge_list[tuple(sorted([b.GetBeginAtomIdx(), b.GetEndAtomIdx()
                             ]))] = bond_features(b)

    # Get canonical adjacency list
    canon_adj_list = [[] for mol_id in range(len(nodes))]
    for edge in edge_list:
    for edge in edge_list.keys():
      canon_adj_list[edge[0]].append(edge[1])
      canon_adj_list[edge[1]].append(edge[0])

    pairs = pair_features(mol, canon_adj_list)
    # Calculate pair features
    pairs = pair_features(mol, edge_list, canon_adj_list, bt_len=6)

    return ConvMol(nodes, canon_adj_list)
 No newline at end of file
    return WeaveMol(nodes, pairs)
+25 −0
Original line number Diff line number Diff line
@@ -388,3 +388,28 @@ class MultiConvMol(object):

  def get_num_molecules(self):
    return self.num_mols


class WeaveMol(object):
  """Holds information about a molecule
  Molecule struct used in weave models

  Stores the atom feature array and the pair feature array together with
  the atom count and per-atom feature count read from the shape of nodes.
  """

  def __init__(self, nodes, pairs):
    """
    Parameters
    ----------
    nodes: array with a 2-element shape
      Atom features; shape[0] is taken as the number of atoms and shape[1]
      as the number of features per atom.
    pairs: array
      Pair features, stored as given.
    """
    self.nodes = nodes
    self.pairs = pairs
    # get_num_atoms() reads `num_atoms`; the original only assigned
    # `num_atom`, so the getter raised AttributeError.
    self.num_atoms = self.nodes.shape[0]
    # Keep the original attribute name for any code that reads it directly.
    self.num_atom = self.num_atoms
    self.n_features = self.nodes.shape[1]

  def get_pair_features(self):
    """Return the pair feature array passed to the constructor."""
    return self.pairs

  def get_atom_features(self):
    """Return the atom feature array passed to the constructor."""
    return self.nodes

  def get_num_atoms(self):
    """Return the number of atoms (first dimension of nodes)."""
    return self.num_atoms

  def get_num_features(self):
    """Return the number of features per atom (second dimension of nodes)."""
    return self.n_features
 No newline at end of file
+2 −0
Original line number Diff line number Diff line
@@ -31,6 +31,8 @@ def load_delaney(featurizer='ECFP', split='index'):
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
  elif featurizer == 'GraphConv':
    featurizer = deepchem.feat.ConvMolFeaturizer()
  elif featurizer == 'Weave':
    featurizer = deepchem.feat.WeaveFeaturizer()
  elif featurizer == 'Raw':
    featurizer = deepchem.feat.RawFeaturizer()

+251 −0
Original line number Diff line number Diff line
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 30 14:02:04 2017

@author: michael
"""

from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import numpy as np
import tensorflow as tf
from deepchem.nn import activations
from deepchem.nn import initializations
from deepchem.nn import model_ops
from deepchem.nn.copy import Layer


class WeaveLayer(Layer):
  """Weave layer: one joint update of atom features and atom-pair features.

  The atom output combines an atom->atom transform (AA) with a pair->atom
  transform (PA) that is summed over axis 2 of the pair tensor. The pair
  output combines an atom->pair transform (AP) of concatenated, symmetrised
  atom features with a pair->pair transform (PP).
  """

  def __init__(self,
               n_atom_input_feat=75,
               n_pair_input_feat=14,
               n_atom_output_feat=50,
               n_pair_output_feat=50,
               n_hidden_AA=50,
               n_hidden_PA=50,
               n_hidden_AP=50,
               n_hidden_PP=50,
               init='glorot_uniform',
               activation='relu',
               dropout=None,
               **kwargs):
    """
    Parameters
    ----------
    n_atom_input_feat: int
      Number of features for each atom in input.
    n_pair_input_feat: int
      Number of features for each pair of atoms in input.
    n_atom_output_feat: int
      Number of features for each atom in output.
    n_pair_output_feat: int
      Number of features for each pair of atoms in output.
    n_hidden_AA, n_hidden_PA, n_hidden_AP, n_hidden_PP: int
      Number of units in the atom->atom, pair->atom, atom->pair and
      pair->pair hidden transforms respectively.
    init: str, optional
      Weight initialization for filters.
    activation: str, optional
      Activation function; looked up and stored on the layer.
      NOTE(review): call() does not apply it yet -- confirm intent.
    dropout: float, optional
      Dropout probability, not supported here

    """
    super(WeaveLayer, self).__init__(**kwargs)

    self.init = initializations.get(init)  # Set weight initialization
    self.activation = activations.get(activation)  # Get activations
    self.n_hidden_AA = n_hidden_AA
    self.n_hidden_PA = n_hidden_PA
    self.n_hidden_AP = n_hidden_AP
    self.n_hidden_PP = n_hidden_PP
    # Widths of the concatenated hidden representations fed to W_A / W_P.
    self.n_hidden_A = n_hidden_AA + n_hidden_PA
    self.n_hidden_P = n_hidden_AP + n_hidden_PP

    self.n_atom_input_feat = n_atom_input_feat
    self.n_pair_input_feat = n_pair_input_feat
    self.n_atom_output_feat = n_atom_output_feat
    self.n_pair_output_feat = n_pair_output_feat

  def build(self):
    """Construct internal trainable weights.
    """

    self.W_AA = self.init([self.n_atom_input_feat, self.n_hidden_AA])
    self.b_AA = model_ops.zeros(shape=[
        self.n_hidden_AA,
    ])

    self.W_PA = self.init([self.n_pair_input_feat, self.n_hidden_PA])
    self.b_PA = model_ops.zeros(shape=[
        self.n_hidden_PA,
    ])

    self.W_A = self.init([self.n_hidden_A, self.n_atom_output_feat])
    self.b_A = model_ops.zeros(shape=[
        self.n_atom_output_feat,
    ])

    # The atom->pair transform sees two atoms' features concatenated.
    self.W_AP = self.init([self.n_atom_input_feat * 2, self.n_hidden_AP])
    self.b_AP = model_ops.zeros(shape=[
        self.n_hidden_AP,
    ])

    self.W_PP = self.init([self.n_pair_input_feat, self.n_hidden_PP])
    self.b_PP = model_ops.zeros(shape=[
        self.n_hidden_PP,
    ])

    self.W_P = self.init([self.n_hidden_P, self.n_pair_output_feat])
    self.b_P = model_ops.zeros(shape=[
        self.n_pair_output_feat,
    ])

    # Collect the variables in a list. Chaining them with `+` would add the
    # tensors element-wise (and fail on mismatched shapes) rather than list
    # them.
    self.trainable_weights = [
        self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A,
        self.W_AP, self.b_AP, self.W_PP, self.b_PP, self.W_P, self.b_P
    ]

  def call(self, x, mask=None):
    """Execute this layer on input tensors.

    x = [atom_features, pair_features, atom_mask, pair_mask]

    atom_features is contracted over axis 2 and pair_features over axis 3;
    presumably they are (batch, max_atoms, n_atom_input_feat) and
    (batch, max_atoms, max_atoms, n_pair_input_feat) -- confirm with caller.

    Parameters
    ----------
    x: list
      list of Tensors of form described above.
    mask: bool, optional
      Ignored. Present only to shadow superclass call() method.

    Returns
    -------
    A: Tensor
      Tensor of atom_features
    P: Tensor
      Tensor of pair_features
    """
    # Add trainable weights
    self.build()

    atom_features = x[0]
    pair_features = x[1]

    atom_mask = x[2]
    pair_mask = x[3]
    max_atoms = atom_features.get_shape().as_list()[1]

    # Atom update: atom->atom term plus pair->atom term summed over axis 2.
    AA = tf.tensordot(atom_features, self.W_AA, [[2], [0]]) + self.b_AA
    PA = tf.reduce_sum(
        tf.tensordot(pair_features, self.W_PA, [[3], [0]]) + self.b_PA, axis=2)
    A = tf.tensordot(tf.concat([AA, PA], 2), self.W_A, [[2], [0]]) + self.b_A

    # Pair update: tile atom features along axes 2 and 1 so each (i, j) slot
    # holds both atoms' features, then add the transpose so the input to
    # W_AP is the same for (i, j) and (j, i).
    AP_combine = tf.concat([
        tf.stack([atom_features] * max_atoms, axis=2),
        tf.stack([atom_features] * max_atoms, axis=1)
    ], 3)
    AP_combine_t = tf.transpose(AP_combine, perm=[0, 2, 1, 3])
    AP = tf.tensordot(AP_combine + AP_combine_t, self.W_AP,
                      [[3], [0]]) + self.b_AP
    PP = tf.tensordot(pair_features, self.W_PP, [[3], [0]]) + self.b_PP
    P = tf.tensordot(tf.concat([AP, PP], 3), self.W_P, [[3], [0]]) + self.b_P

    # Multiply by the masks (broadcast over the feature axis).
    A = tf.multiply(A, tf.expand_dims(atom_mask, axis=2))
    P = tf.multiply(P, tf.expand_dims(pair_mask, axis=3))
    return A, P


class WeaveGather(Layer):
  """Gather layer for weave models.

  Applies a linear transform to the atom features, optionally passes the
  result through gaussian_histogram, and sums over axis 1 to produce one
  feature vector per molecule.
  """

  def __init__(self,
               n_atom_input_feat=50,
               n_hidden=128,
               init='glorot_uniform',
               activation='relu',
               gaussian_expand=True,
               dropout=None,
               **kwargs):
    """
    Parameters
    ----------
    n_atom_input_feat: int
      Number of features for each atom in input.
    n_hidden: int
      Number of units in the linear transform applied to each atom.
    init: str, optional
      Weight initialization for filters.
    activation: str, optional
      Activation function; looked up and stored on the layer.
      NOTE(review): call() does not apply it yet -- confirm intent.
    gaussian_expand: boolean. optional
      Whether to expand each dimension of atomic features by gaussian histogram
    dropout: float, optional
      Dropout probability, not supported here

    """
    # Must name this class: super(WeaveLayer, self) raises TypeError because
    # WeaveGather does not derive from WeaveLayer.
    super(WeaveGather, self).__init__(**kwargs)

    self.init = initializations.get(init)  # Set weight initialization
    self.activation = activations.get(activation)  # Get activations
    self.n_hidden = n_hidden
    self.n_atom_input_feat = n_atom_input_feat
    self.gaussian_expand = gaussian_expand
    if gaussian_expand:
      # NOTE(review): presumably 11 histogram bins per hidden unit; the
      # current gaussian_histogram is a pass-through, so the real output
      # width is still n_hidden -- confirm once it is implemented.
      self.n_outputs = self.n_hidden * 11
    else:
      self.n_outputs = self.n_hidden

  def build(self):
    """Construct internal trainable weights.
    """

    self.W = self.init([self.n_atom_input_feat, self.n_hidden])
    self.b = model_ops.zeros(shape=[
        self.n_hidden,
    ])

    # A list of the variables; `self.W + self.b` would add the two tensors.
    self.trainable_weights = [self.W, self.b]

  def call(self, x, mask=None):
    """Execute this layer on input tensors.

    Parameters
    ----------
    x: Tensor
      Tensors of atom features
    mask: bool, optional
      Ignored. Present only to shadow superclass call() method.

    Returns
    -------
    outputs: Tensor
      Tensor of molecular features
    """
    # Add trainable weights
    self.build()

    outputs = tf.tensordot(x, self.W, [[2], [0]]) + self.b
    if self.gaussian_expand:
      outputs = self.gaussian_histogram(outputs)
    # Sum over axis 1 (presumably the atom axis -- confirm with caller).
    outputs = tf.reduce_sum(outputs, axis=1)
    return outputs

  @staticmethod
  def gaussian_histogram(x):
    """Placeholder for the gaussian histogram expansion; returns x unchanged.

    Declared static so the one-argument signature is kept while
    self.gaussian_histogram(x) in call() no longer passes self as x.
    """
    return x