Unverified Commit 698519d7 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2365 from MariBerry/master

Feature: Atom-based interpretation of GraphConvModel plus tutorial
parents a1d38afb b474155c
Loading
Loading
Loading
Loading
+122 −22
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ from deepchem.feat.complex_featurizers import ComplexNeighborListFragmentAtomicC
from deepchem.feat.mol_graphs import ConvMol, WeaveMol
from deepchem.data import DiskDataset
import logging
from typing import Optional, List
from typing import Optional, List, Union, Iterable
from deepchem.utils.typing import RDKitMol, RDKitAtom


@@ -630,6 +630,25 @@ class ConvMolFeaturizer(MolecularFeaturizer):
  Duvenaud graph convolutions [1]_ construct a vector of descriptors for each
  atom in a molecule. The featurizer computes that vector of local descriptors.

  Examples
  ---------
  >>> import deepchem as dc
  >>> smiles = ["C", "CCC"]
  >>> featurizer=dc.feat.ConvMolFeaturizer(per_atom_fragmentation=False)
  >>> f = featurizer.featurize(smiles)
  >>> # Using ConvMolFeaturizer to create featurized fragments derived from molecules of interest.
  ... # This is used only in the context of performing interpretation of models using atomic
  ... # contributions (atom-based model interpretation)
  ... smiles = ["C", "CCC"]
  >>> featurizer=dc.feat.ConvMolFeaturizer(per_atom_fragmentation=True)
  >>> f = featurizer.featurize(smiles)
  >>> len(f) # contains 2 lists with  featurized fragments from 2 mols
  2

  See Also
  --------
  Detailed examples of `GraphConvModel` interpretation are provided in Tutorial #28

  References
  ---------

@@ -643,8 +662,11 @@ class ConvMolFeaturizer(MolecularFeaturizer):
  """
  name = ['conv_mol']

  def __init__(self, master_atom=False, use_chirality=False,
               atom_properties=[]):
  def __init__(self,
               master_atom: bool = False,
               use_chirality: bool = False,
               atom_properties: Iterable[str] = [],
               per_atom_fragmentation: bool = False):
    """
    Parameters
    ----------
@@ -668,6 +690,13 @@ class ConvMolFeaturizer(MolecularFeaturizer):
      provided in atom_properties.  So "atom 00000000 sasa" would be the
      name of the molecule level property in mol where the solvent
      accessible surface area of atom 0 would be stored.
    per_atom_fragmentation: Boolean
      If True, then multiple "atom-depleted" versions of each molecule will be created (using featurize() method). 
      For each molecule, atoms are removed one at a time and the resulting molecule is featurized. 
      The result is a list of ConvMol objects,
      one with each heavy atom removed. This is useful for subsequent model interpretation: finding atoms
      favorable/unfavorable for (modelled) activity. This option is typically used in combination
      with a FlatteningTransformer to split the lists into separate samples.

    Since ConvMol is an object and not a numpy array, need to set dtype to
    object.
@@ -676,6 +705,43 @@ class ConvMolFeaturizer(MolecularFeaturizer):
    self.master_atom = master_atom
    self.use_chirality = use_chirality
    self.atom_properties = list(atom_properties)
    self.per_atom_fragmentation = per_atom_fragmentation

  def featurize(
      self,
      molecules: Union[RDKitMol, str, Iterable[RDKitMol], Iterable[str]],
      log_every_n: int = 1000) -> np.ndarray:
    """
    Override parent: aim is to add handling atom-depleted molecules featurization
    
    Parameters
    ----------
    molecules: rdkit.Chem.rdchem.Mol / SMILES string / iterable
      RDKit Mol, or SMILES string or iterable sequence of RDKit mols/SMILES
      strings.
    log_every_n: int, default 1000
      Logging messages reported every `log_every_n` samples.

    Returns
    -------
    features: np.ndarray
      A numpy array containing a featurized representation of `datapoints`.
    """
    features = super(ConvMolFeaturizer, self).featurize(
        molecules, log_every_n=1000)
    if self.per_atom_fragmentation:
      # create temporary valid ids serving to filter out failed featurizations from every sublist
      # of features (i.e. every molecules' frags list), and also totally failed sublists.
      # This makes output digestable by Loaders
      valid_frag_inds = [[
          True if np.array(elt).size > 0 else False for elt in f
      ] for f in features]
      features = [[elt
                   for (is_valid, elt) in zip(l, m)
                   if is_valid]
                  for (l, m) in zip(valid_frag_inds, features)
                  if any(l)]
    return features

  def _get_atom_properties(self, atom):
    """
@@ -700,7 +766,39 @@ class ConvMolFeaturizer(MolecularFeaturizer):
    return np.array(values)

  def _featurize(self, mol):
    """Encodes mol as a ConvMol object."""
    """Encodes mol as a ConvMol object.
    If per_atom_fragmentation is True,
    then for each molecule a list of ConvMolObjects
    will be created"""

    def per_atom(n, a):
      """
      Enumerates fragments resulting from mol object,
      s.t. each fragment = mol with single atom removed (all possible removals are enumerated)
      Goes over nodes, deletes one at a time and updates adjacency list of lists (removes connections to that node)

      Parameters
      ----------
      n: np.array of nodes (number_of_nodes X number_of_features)
      a: list of nested lists of adjacent node pairs

      """
      for i in range(n.shape[0]):
        new_n = np.delete(n, (i), axis=0)
        new_a = []
        for j, node_pair in enumerate(a):
          if i != j:  # don't need this pair, no more connections to deleted node
            tmp_node_pair = []
            for v in node_pair:
              if v < i:
                tmp_node_pair.append(v)
              elif v > i:
                tmp_node_pair.append(
                    v -
                    1)  # renumber node, because of offset after node deletion
            new_a.append(tmp_node_pair)
        yield new_n, new_a

    # Get the node features
    idx_nodes = [(a.GetIdx(),
                  np.concatenate((atom_features(
@@ -721,7 +819,6 @@ class ConvMolFeaturizer(MolecularFeaturizer):
    edge_list = [
        (b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()
    ]

    # Get canonical adjacency list
    canon_adj_list = [[] for mol_id in range(len(nodes))]
    for edge in edge_list:
@@ -733,7 +830,10 @@ class ConvMolFeaturizer(MolecularFeaturizer):
      for index in range(len(nodes) - 1):
        canon_adj_list[index].append(fake_atom_index)

    if not self.per_atom_fragmentation:
      return ConvMol(nodes, canon_adj_list)
    else:
      return [ConvMol(n, a) for n, a in per_atom(nodes, canon_adj_list)]

  def feature_length(self):
    return 75 + len(self.atom_properties)
+11 −0
Original line number Diff line number Diff line
@@ -88,6 +88,17 @@ class TestConvMolFeaturizer(unittest.TestCase):
    assert np.array_equal(deg_adj_lists[5], np.zeros([0, 5], dtype=np.int32))
    assert np.array_equal(deg_adj_lists[6], np.zeros([0, 6], dtype=np.int32))

  def test_per_atom_fragmentation(self):
    """checks if instantiating featurizer with per_atom_fragmentation=True
    leads to  as many fragments' features, as many atoms mol has for any mol"""
    import rdkit.Chem
    raw_smiles = ['CC(CO)Cc1ccccc1', 'CC']
    mols = [rdkit.Chem.MolFromSmiles(m) for m in raw_smiles]
    featurizer = ConvMolFeaturizer(per_atom_fragmentation=True)
    feat = featurizer.featurize(mols)
    for i, j in zip(feat, mols):
      assert len(i) == j.GetNumHeavyAtoms()


class TestAtomicConvFeaturizer(unittest.TestCase):

+1 −0
Original line number Diff line number Diff line
@@ -21,4 +21,5 @@ from deepchem.trans.transformers import FeaturizationTransformer
from deepchem.trans.transformers import ImageTransformer
from deepchem.trans.transformers import DataTransforms
from deepchem.trans.transformers import Transformer
from deepchem.trans.transformers import FlatteningTransformer
from deepchem.trans.duplicate import DuplicateBalancingTransformer
+218 −0
Original line number Diff line number Diff line
10_filipski_40
     RDKit          3D

 48 50  0  0  1  0  0  0  0  0999 V2000
    9.1378   -7.4697   -1.1731 C   0  0  0  0  0  0  0  0  0  0  0  0
    9.0300   -8.7563   -1.7553 C   0  0  0  0  0  0  0  0  0  0  0  0
   10.1829   -9.4791   -2.1168 C   0  0  0  0  0  0  0  0  0  0  0  0
   11.4593   -8.9144   -1.9184 C   0  0  0  0  0  0  0  0  0  0  0  0
   11.5888   -7.6306   -1.3431 C   0  0  0  0  0  0  0  0  0  0  0  0
   10.4211   -6.9229   -0.9733 C   0  0  0  0  0  0  0  0  0  0  0  0
    8.0685   -6.6893   -0.7812 O   0  0  0  0  0  0  0  0  0  0  0  0
    6.7356   -7.1730   -0.9323 C   0  0  0  0  0  0  0  0  0  0  0  0
    5.8194   -5.9457   -0.8867 C   0  0  0  0  0  0  0  0  0  0  0  0
    6.3937   -8.1606    0.1955 C   0  0  0  0  0  0  0  0  0  0  0  0
   10.0417  -10.7213   -2.6806 O   0  0  0  0  0  0  0  0  0  0  0  0
   10.6226  -11.7880   -2.0428 C   0  0  0  0  0  0  0  0  0  0  0  0
   11.4794  -12.6365   -2.7738 C   0  0  0  0  0  0  0  0  0  0  0  0
   12.0777  -13.7503   -2.1503 C   0  0  0  0  0  0  0  0  0  0  0  0
   11.8056  -14.0231   -0.7953 C   0  0  0  0  0  0  0  0  0  0  0  0
   10.9593  -13.1740   -0.0542 C   0  0  0  0  0  0  0  0  0  0  0  0
   10.3610  -12.0614   -0.6807 C   0  0  0  0  0  0  0  0  0  0  0  0
   12.5981  -15.4211    0.0061 S   0  0  0  0  0  0  0  0  0  0  0  0
   14.1883  -14.7546    0.5873 C   0  0  0  0  0  0  0  0  0  0  0  0
   11.8095  -15.8020    1.1921 O   0  0  0  0  0  0  0  0  0  0  0  0
   12.8865  -16.4503   -1.0091 O   0  0  0  0  0  0  0  0  0  0  0  0
   12.9447   -7.0276   -1.1268 C   0  0  0  0  0  0  0  0  0  0  0  0
   14.1048   -7.6753   -1.5778 N   0  0  0  0  0  0  0  0  0  0  0  0
   15.3664   -7.2188   -1.4378 C   0  0  0  0  0  0  0  0  0  0  0  0
   15.4761   -5.9335   -0.7477 C   0  0  0  0  0  0  0  0  0  0  0  0
   14.3478   -5.3279   -0.3229 C   0  0  0  0  0  0  0  0  0  0  0  0
   13.0801   -5.8841   -0.5185 N   0  0  0  0  0  0  0  0  0  0  0  0
   16.3235   -7.8662   -1.8727 O   0  0  0  0  0  0  0  0  0  0  0  0
   17.0235   -5.2108   -0.4863 Cl  0  0  0  0  0  0  0  0  0  0  0  0
    8.0727   -9.2223   -1.9323 H   0  0  0  0  0  0  0  0  0  0  0  0
   12.3294   -9.4833   -2.2114 H   0  0  0  0  0  0  0  0  0  0  0  0
   10.5000   -5.9395   -0.5309 H   0  0  0  0  0  0  0  0  0  0  0  0
    6.5963   -7.6418   -1.9072 H   0  0  0  0  0  0  0  0  0  0  0  0
    4.7728   -6.2316   -0.9963 H   0  0  0  0  0  0  0  0  0  0  0  0
    5.9216   -5.4076    0.0563 H   0  0  0  0  0  0  0  0  0  0  0  0
    6.0566   -5.2512   -1.6930 H   0  0  0  0  0  0  0  0  0  0  0  0
    7.0376   -9.0392    0.1822 H   0  0  0  0  0  0  0  0  0  0  0  0
    6.4989   -7.6921    1.1742 H   0  0  0  0  0  0  0  0  0  0  0  0
    5.3655   -8.5122    0.1058 H   0  0  0  0  0  0  0  0  0  0  0  0
   11.6797  -12.4320   -3.8159 H   0  0  0  0  0  0  0  0  0  0  0  0
   12.7400  -14.3980   -2.7059 H   0  0  0  0  0  0  0  0  0  0  0  0
   10.7684  -13.3823    0.9883 H   0  0  0  0  0  0  0  0  0  0  0  0
    9.7026  -11.4187   -0.1132 H   0  0  0  0  0  0  0  0  0  0  0  0
   14.7527  -14.3892   -0.2677 H   0  0  0  0  0  0  0  0  0  0  0  0
   13.9992  -13.9328    1.2743 H   0  0  0  0  0  0  0  0  0  0  0  0
   14.7461  -15.5395    1.0917 H   0  0  0  0  0  0  0  0  0  0  0  0
   13.9997   -8.5573   -2.0516 H   0  0  0  0  0  0  0  0  0  0  0  0
   14.3815   -4.3776    0.1907 H   0  0  0  0  0  0  0  0  0  0  0  0
  1  2  2  0
  1  6  1  0
  1  7  1  0
  2  3  1  0
  2 30  1  0
  3  4  2  0
  3 11  1  0
  4  5  1  0
  4 31  1  0
  5  6  2  0
  5 22  1  0
  6 32  1  0
  7  8  1  0
  8  9  1  0
  8 10  1  0
  8 33  1  0
  9 34  1  0
  9 35  1  0
  9 36  1  0
 10 37  1  0
 10 38  1  0
 10 39  1  0
 11 12  1  0
 12 13  2  0
 12 17  1  0
 13 14  1  0
 13 40  1  0
 14 15  2  0
 14 41  1  0
 15 16  1  0
 15 18  1  0
 16 17  2  0
 16 42  1  0
 17 43  1  0
 18 19  1  0
 18 20  2  0
 18 21  2  0
 19 44  1  0
 19 45  1  0
 19 46  1  0
 22 23  1  0
 22 27  2  0
 23 24  1  0
 23 47  1  0
 24 25  1  0
 24 28  2  0
 25 26  2  0
 25 29  1  0
 26 27  1  0
 26 48  1  0
M  END
>  <LogP(RRCK)>  (1) 
-5.08

$$$$
10_filipski_42
     RDKit          3D

 50 52  0  0  1  0  0  0  0  0999 V2000
    8.8247   -7.3140   -1.2684 C   0  0  0  0  0  0  0  0  0  0  0  0
    8.7978   -8.6432   -1.7590 C   0  0  0  0  0  0  0  0  0  0  0  0
    9.9897   -9.2996   -2.1198 C   0  0  0  0  0  0  0  0  0  0  0  0
   11.2249   -8.6287   -2.0043 C   0  0  0  0  0  0  0  0  0  0  0  0
   11.2728   -7.3060   -1.5122 C   0  0  0  0  0  0  0  0  0  0  0  0
   10.0677   -6.6605   -1.1523 C   0  0  0  0  0  0  0  0  0  0  0  0
    7.7116   -6.5895   -0.8917 O   0  0  0  0  0  0  0  0  0  0  0  0
    6.4156   -7.1795   -0.9644 C   0  0  0  0  0  0  0  0  0  0  0  0
    5.4127   -6.0219   -0.9784 C   0  0  0  0  0  0  0  0  0  0  0  0
    6.1822   -8.1019    0.2432 C   0  0  0  0  0  0  0  0  0  0  0  0
    9.9229  -10.5823   -2.6015 O   0  0  0  0  0  0  0  0  0  0  0  0
   10.6835  -11.5390   -1.9805 C   0  0  0  0  0  0  0  0  0  0  0  0
   11.6377  -12.2535   -2.7336 C   0  0  0  0  0  0  0  0  0  0  0  0
   12.4273  -13.2459   -2.1186 C   0  0  0  0  0  0  0  0  0  0  0  0
   12.2573  -13.5245   -0.7480 C   0  0  0  0  0  0  0  0  0  0  0  0
   11.3014  -12.8190    0.0104 C   0  0  0  0  0  0  0  0  0  0  0  0
   10.5113  -11.8283   -0.6083 C   0  0  0  0  0  0  0  0  0  0  0  0
   13.2741  -14.7862    0.0266 S   0  0  0  0  0  0  0  0  0  0  0  0
   14.8007  -14.0207    0.1586 N   0  0  0  0  0  0  0  0  0  0  0  0
   12.8065  -15.0295    1.4016 O   0  0  0  0  0  0  0  0  0  0  0  0
   13.4508  -15.9197   -0.8955 O   0  0  0  0  0  0  0  0  0  0  0  0
   12.5842   -6.5952   -1.3827 C   0  0  0  0  0  0  0  0  0  0  0  0
   13.7938   -7.3067   -1.3965 N   0  0  0  0  0  0  0  0  0  0  0  0
   15.0231   -6.7649   -1.2806 C   0  0  0  0  0  0  0  0  0  0  0  0
   15.0328   -5.3156   -1.1306 C   0  0  0  0  0  0  0  0  0  0  0  0
   13.8624   -4.6467   -1.1141 C   0  0  0  0  0  0  0  0  0  0  0  0
   12.6344   -5.3016   -1.2387 N   0  0  0  0  0  0  0  0  0  0  0  0
   16.0372   -7.4655   -1.3016 O   0  0  0  0  0  0  0  0  0  0  0  0
   15.8470  -14.7127    0.9154 C   0  0  0  0  0  0  0  0  0  0  0  0
    7.8748   -9.1913   -1.8675 H   0  0  0  0  0  0  0  0  0  0  0  0
   12.1279   -9.1407   -2.3028 H   0  0  0  0  0  0  0  0  0  0  0  0
   10.0889   -5.6466   -0.7773 H   0  0  0  0  0  0  0  0  0  0  0  0
    6.2849   -7.7285   -1.8974 H   0  0  0  0  0  0  0  0  0  0  0  0
    4.3881   -6.3892   -1.0449 H   0  0  0  0  0  0  0  0  0  0  0  0
    5.4925   -5.4156   -0.0756 H   0  0  0  0  0  0  0  0  0  0  0  0
    5.5848   -5.3681   -1.8339 H   0  0  0  0  0  0  0  0  0  0  0  0
    6.8956   -8.9248    0.2754 H   0  0  0  0  0  0  0  0  0  0  0  0
    6.2739   -7.5525    1.1802 H   0  0  0  0  0  0  0  0  0  0  0  0
    5.1840   -8.5392    0.2119 H   0  0  0  0  0  0  0  0  0  0  0  0
   11.7650  -12.0392   -3.7854 H   0  0  0  0  0  0  0  0  0  0  0  0
   13.1598  -13.7962   -2.6907 H   0  0  0  0  0  0  0  0  0  0  0  0
   11.1770  -13.0389    1.0604 H   0  0  0  0  0  0  0  0  0  0  0  0
    9.7750  -11.2901   -0.0280 H   0  0  0  0  0  0  0  0  0  0  0  0
   15.1266  -13.7716   -0.7786 H   0  0  0  0  0  0  0  0  0  0  0  0
   13.7507   -8.3089   -1.4905 H   0  0  0  0  0  0  0  0  0  0  0  0
   15.9705   -4.7876   -1.0331 H   0  0  0  0  0  0  0  0  0  0  0  0
   13.8284   -3.5714   -1.0035 H   0  0  0  0  0  0  0  0  0  0  0  0
   16.0696  -15.6784    0.4597 H   0  0  0  0  0  0  0  0  0  0  0  0
   16.7610  -14.1182    0.9298 H   0  0  0  0  0  0  0  0  0  0  0  0
   15.5270  -14.8822    1.9443 H   0  0  0  0  0  0  0  0  0  0  0  0
  1  2  2  0
  1  6  1  0
  1  7  1  0
  2  3  1  0
  2 30  1  0
  3  4  2  0
  3 11  1  0
  4  5  1  0
  4 31  1  0
  5  6  2  0
  5 22  1  0
  6 32  1  0
  7  8  1  0
  8  9  1  0
  8 10  1  0
  8 33  1  0
  9 34  1  0
  9 35  1  0
  9 36  1  0
 10 37  1  0
 10 38  1  0
 10 39  1  0
 11 12  1  0
 12 13  2  0
 12 17  1  0
 13 14  1  0
 13 40  1  0
 14 15  2  0
 14 41  1  0
 15 16  1  0
 15 18  1  0
 16 17  2  0
 16 42  1  0
 17 43  1  0
 18 19  1  0
 18 20  2  0
 18 21  2  0
 19 29  1  0
 19 44  1  0
 22 23  1  0
 22 27  2  0
 23 24  1  0
 23 45  1  0
 24 25  1  0
 24 28  2  0
 25 26  2  0
 25 46  1  0
 26 27  1  0
 26 47  1  0
 29 48  1  0
 29 49  1  0
 29 50  1  0
M  END
>  <LogP(RRCK)>  (2) 
-4.82

$$$$
+63 −0
Original line number Diff line number Diff line
import tempfile
import os
import deepchem as dc
import numpy as np


def test_flattening_with_csv_load_withtask():
  fin = tempfile.NamedTemporaryFile(mode='w', delete=False)
  fin.write("smiles,endpoint\nc1ccccc1,1")
  fin.close()
  loader = dc.data.CSVLoader(
      ["endpoint"],
      feature_field="smiles",
      featurizer=dc.feat.ConvMolFeaturizer(per_atom_fragmentation=True))
  frag_dataset = loader.create_dataset(fin.name)
  transformer = dc.trans.FlatteningTransformer(dataset=frag_dataset)
  frag_dataset = transformer.transform(frag_dataset)
  assert len(frag_dataset) == 6
  assert np.shape(frag_dataset.y) == (6,
                                      1)  # y should be expanded up to X shape
  assert np.shape(frag_dataset.w) == (6,
                                      1)  # w should be expanded up to X shape


def test_flattening_with_csv_load_notask():
  fin = tempfile.NamedTemporaryFile(mode='w', delete=False)
  fin.write("smiles,endpoint\nc1ccccc1,1")
  fin.close()
  loader = dc.data.CSVLoader(
      [],
      feature_field="smiles",
      featurizer=dc.feat.ConvMolFeaturizer(per_atom_fragmentation=True))
  frag_dataset = loader.create_dataset(fin.name)
  transformer = dc.trans.FlatteningTransformer(dataset=frag_dataset)
  frag_dataset = transformer.transform(frag_dataset)
  assert len(frag_dataset) == 6


def test_flattening_with_sdf_load_withtask():
  cur_dir = os.path.dirname(os.path.realpath(__file__))
  featurizer = dc.feat.ConvMolFeaturizer(per_atom_fragmentation=True)
  loader = dc.data.SDFLoader(
      ["LogP(RRCK)"], featurizer=featurizer, sanitize=True)
  dataset = loader.create_dataset(
      os.path.join(cur_dir, "membrane_permeability.sdf"))
  transformer = dc.trans.FlatteningTransformer(dataset=dataset)
  frag_dataset = transformer.transform(dataset)
  assert len(frag_dataset) == 98
  assert np.shape(frag_dataset.y) == (98,
                                      1)  # y should be expanded up to X shape
  assert np.shape(frag_dataset.w) == (98,
                                      1)  # w should be expanded up to X shape


def test_flattening_with_sdf_load_notask():
  cur_dir = os.path.dirname(os.path.realpath(__file__))
  featurizer = dc.feat.ConvMolFeaturizer(per_atom_fragmentation=True)
  loader = dc.data.SDFLoader([], featurizer=featurizer, sanitize=True)
  dataset = loader.create_dataset(
      os.path.join(cur_dir, "membrane_permeability.sdf"))
  transformer = dc.trans.FlatteningTransformer(dataset=dataset)
  frag_dataset = transformer.transform(dataset)
  assert len(frag_dataset) == 98
Loading