Unverified commit 049779ec, authored by Bharath Ramsundar and committed by GitHub
Browse files

Merge pull request #1476 from VIGS25/atomic-conv-featurizer

#1362: Add Atomic Convolution Featurization for PDB Complexes
parents c43482fd 0ff47b57
Loading
Loading
Loading
Loading
+152 −0
Original line number Diff line number Diff line
@@ -4,8 +4,18 @@ from __future__ import unicode_literals
import numpy as np
from rdkit import Chem

import deepchem as dc
from deepchem.feat import Featurizer
from deepchem.feat.atomic_coordinates import ComplexNeighborListFragmentAtomicCoordinates
from deepchem.feat.mol_graphs import ConvMol, WeaveMol
from deepchem.data import DiskDataset
import multiprocessing
import logging


def _featurize_complex(featurizer, mol_pdb_file, protein_pdb_file, log_message):
  """Log a progress message, then featurize one ligand/protein pair.

  Thin wrapper that delegates to ``featurizer._featurize_complex``. It is
  the callable submitted to the worker pool via ``apply_async``.

  Parameters
  ----------
  featurizer: object
    Any object exposing ``_featurize_complex(mol_pdb_file, protein_pdb_file)``.
  mol_pdb_file: str
    Ligand file handed to the featurizer.
  protein_pdb_file: str
    Protein file handed to the featurizer.
  log_message: str
    Message emitted at INFO level before featurization starts.

  Returns
  -------
  Whatever ``featurizer._featurize_complex`` returns for this pair.
  """
  logging.info(log_message)
  complex_features = featurizer._featurize_complex(mol_pdb_file,
                                                   protein_pdb_file)
  return complex_features


def one_of_k_encoding(x, allowable_set):
@@ -424,3 +434,145 @@ class WeaveFeaturizer(Featurizer):
        graph_distance=self.graph_distance)

    return WeaveMol(nodes, pairs)


class AtomicConvFeaturizer(ComplexNeighborListFragmentAtomicCoordinates):
  """Computes Atomic Convolution features for ligand/protein complexes.

  The featurizer first builds the neighbor-list / coordinate features of
  its parent class, fits an ``AtomicConvModel`` on them against the
  supplied labels, and then returns the activations of the model's
  ``AtomicConvolution`` layers as the final features.
  """

  # TODO (VIGS25): Complete the description

  name = ['atomic_conv']

  def __init__(self,
               labels,
               neighbor_cutoff,
               frag1_num_atoms=70,
               frag2_num_atoms=634,
               complex_num_atoms=701,
               max_num_neighbors=12,
               batch_size=24,
               atom_types=None,
               radial=None,
               layer_sizes=None,
               strip_hydrogens=True,
               learning_rate=0.001,
               epochs=10):
    """
    Parameters

    labels: numpy.ndarray
      Labels which we want to predict using the model
    neighbor_cutoff: int
      Neighbor cutoff forwarded to the parent featurizer (presumably a
      distance in angstroms). TODO (VIGS25): Add description
    frag1_num_atoms: int
      Number of atoms in first fragment
    frag2_num_atoms: int
      Number of atoms in second fragment
    complex_num_atoms: int
      Presumably the number of atoms in the whole complex.
      TODO (VIGS25) : Add description
    max_num_neighbors: int
      Maximum number of neighbors possible for an atom
    batch_size: int
      Batch size used for training and evaluation
    atom_types: list, optional
      List of atoms recognized by model. Atoms are indicated by their
      nuclear numbers. When None, the default list
      [6, 7, 8, 9, 11, 12, 15, 16, 17, 20, 25, 30, 35, 53, -1] is used.
    radial: list, optional
      Forwarded unchanged to AtomicConvModel. When None, the default
      [[1.5, 2.0, ..., 12.0], [0.0, 4.0, 8.0], [0.4]] is used.
      TODO (VIGS25): Add description
    layer_sizes: list, optional
      List of layer sizes for the AtomicConvolutional Network. When
      None, [32, 32, 16] is used.
    strip_hydrogens: bool
      Whether to remove hydrogens while computing neighbor features
    learning_rate: float
      Learning rate for training the model
    epochs: int
      Number of epochs to train the model for
    """
    # Build fresh default lists per instance so that no mutable default
    # object is shared between featurizers (or with the model).
    if atom_types is None:
      atom_types = [
          6, 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.,
          -1.
      ]
    if radial is None:
      radial = [[
          1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
          8.0, 8.5, 9.0, 9.5, 10.0, 10.5, 11.0, 11.5, 12.0
      ], [0.0, 4.0, 8.0], [0.4]]
    if layer_sizes is None:
      layer_sizes = [32, 32, 16]

    self.atomic_conv_model = dc.models.tensorgraph.models.atomic_conv.AtomicConvModel(
        frag1_num_atoms=frag1_num_atoms,
        frag2_num_atoms=frag2_num_atoms,
        complex_num_atoms=complex_num_atoms,
        max_num_neighbors=max_num_neighbors,
        batch_size=batch_size,
        atom_types=atom_types,
        radial=radial,
        layer_sizes=layer_sizes,
        learning_rate=learning_rate)

    super(AtomicConvFeaturizer, self).__init__(
        frag1_num_atoms=frag1_num_atoms,
        frag2_num_atoms=frag2_num_atoms,
        complex_num_atoms=complex_num_atoms,
        max_num_neighbors=max_num_neighbors,
        neighbor_cutoff=neighbor_cutoff,
        strip_hydrogens=strip_hydrogens)

    self.epochs = epochs
    self.labels = labels

  def featurize_complexes(self, mol_files, protein_files):
    """Featurize complexes, fit the model, and extract conv-layer features.

    Note that every call (re)fits ``self.atomic_conv_model`` on the
    successfully featurized complexes for ``self.epochs`` epochs.

    Parameters
    ----------
    mol_files: list
      Ligand files, one per complex.
    protein_files: list
      Protein files, paired positionally with ``mol_files``.

    Returns
    -------
    atomic_conv_features: numpy.ndarray
      Concatenated outputs of the AtomicConvolution layers, one entry per
      successfully featurized complex (singleton dimensions are squeezed
      out).
    failures: list
      Indices (into the input lists) of complexes whose featurization
      returned None.
    """
    pool = multiprocessing.Pool()
    try:
      num_complexes = len(mol_files)
      results = []
      for i, (mol_file, protein_pdb) in enumerate(
          zip(mol_files, protein_files)):
        log_message = "Featurizing %d / %d" % (i, num_complexes)
        results.append(
            pool.apply_async(_featurize_complex,
                             (self, mol_file, protein_pdb, log_message)))
      pool.close()
      outputs = [result.get() for result in results]
    finally:
      # Always reclaim the worker processes, even if a task raised.
      pool.terminate()
      pool.join()

    features = []
    failures = []
    for ind, new_features in enumerate(outputs):
      # Handle loading failures which return None
      if new_features is not None:
        features.append(new_features)
      else:
        failures.append(ind)

    features = np.asarray(features)
    labels = np.delete(self.labels, failures)
    dataset = DiskDataset.from_numpy(features, labels)

    # Fit atomic conv model
    self.atomic_conv_model.fit(dataset, nb_epoch=self.epochs)

    # Add the Atomic Convolution layers to fetches
    atomic_conv_cls = dc.models.tensorgraph.models.atomic_conv.AtomicConvolution
    layers_to_fetch = [
        layer for layer in self.atomic_conv_model.layers.values()
        if isinstance(layer, atomic_conv_cls)
    ]

    # Extract the atomic convolution features
    atomic_conv_features = list()
    feed_dict_generator = self.atomic_conv_model.default_generator(
        dataset=dataset, epochs=1)

    for feed_dict in self.atomic_conv_model._create_feed_dicts(
        feed_dict_generator, training=False):
      frag1_conv, frag2_conv, complex_conv = self.atomic_conv_model._run_graph(
          outputs=layers_to_fetch, feed_dict=feed_dict, training=False)
      concatenated = np.concatenate(
          [frag1_conv, frag2_conv, complex_conv], axis=1)
      atomic_conv_features.append(concatenated)

    if atomic_conv_features:
      # Join the per-batch outputs along the sample axis (assumed to be
      # axis 0 -- TODO confirm) and keep only the first len(features)
      # rows; any rows beyond that are assumed to be padding appended to
      # fill the last batch.
      atomic_conv_features = np.concatenate(atomic_conv_features, axis=0)
      atomic_conv_features = atomic_conv_features[:len(features)]
    else:
      atomic_conv_features = np.asarray(atomic_conv_features)
    atomic_conv_features = np.squeeze(atomic_conv_features)

    return atomic_conv_features, failures
+34 −4
Original line number Diff line number Diff line
@@ -10,11 +10,8 @@ __license__ = "MIT"

import unittest
import os
import sys
import numpy as np
from deepchem.feat.mol_graphs import ConvMol
from deepchem.feat.mol_graphs import MultiConvMol
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.graph_features import ConvMolFeaturizer, AtomicConvFeaturizer


class TestConvMolFeaturizer(unittest.TestCase):
@@ -95,3 +92,36 @@ class TestConvMolFeaturizer(unittest.TestCase):
    assert np.array_equal(deg_adj_lists[4], np.zeros([0, 4], dtype=np.int32))
    assert np.array_equal(deg_adj_lists[5], np.zeros([0, 5], dtype=np.int32))
    assert np.array_equal(deg_adj_lists[6], np.zeros([0, 6], dtype=np.int32))


class TestAtomicConvFeaturizer(unittest.TestCase):
  """Smoke test for AtomicConvFeaturizer on a single PDB complex."""

  def test_feature_generation(self):
    """Test if featurization works using AtomicConvFeaturizer."""
    dir_path = os.path.dirname(os.path.realpath(__file__))
    ligand_file = os.path.join(dir_path, "data/3zso_ligand_hyd.pdb")
    protein_file = os.path.join(dir_path, "data/3zso_protein.pdb")
    # Pulled from PDB files. For larger datasets with more PDBs, would use
    # max num atoms instead of exact.

    frag1_num_atoms = 44  # for ligand atoms
    frag2_num_atoms = 2336  # for protein atoms
    complex_num_atoms = 2380  # in total
    max_num_neighbors = 4
    # Cutoff in angstroms
    neighbor_cutoff = 4

    # One (dummy) label per complex passed to featurize_complexes below.
    labels = np.array([0, 0])

    featurizer = AtomicConvFeaturizer(
        labels=labels,
        batch_size=1,
        epochs=1,
        frag1_num_atoms=frag1_num_atoms,
        frag2_num_atoms=frag2_num_atoms,
        complex_num_atoms=complex_num_atoms,
        max_num_neighbors=max_num_neighbors,
        neighbor_cutoff=neighbor_cutoff)

    # The same complex is featurized twice so the dataset has two samples.
    features, failures = featurizer.featurize_complexes(
        [ligand_file, ligand_file], [protein_file, protein_file])

    # Both inputs are expected to load, so nothing should be reported as
    # failed and there should be one feature entry per label.
    assert len(failures) == 0
    assert features.shape[0] == len(labels)
+18 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ import logging
import tarfile
from deepchem.feat import rdkit_grid_featurizer as rgf
from deepchem.feat.atomic_coordinates import ComplexNeighborListFragmentAtomicCoordinates
from deepchem.feat.graph_features import AtomicConvFeaturizer

logger = logging.getLogger(__name__)

@@ -233,6 +234,22 @@ def load_pdbbind(featurizer="grid", split="random", subset="core", reload=True):
        frag1_num_atoms, frag2_num_atoms, complex_num_atoms, max_num_neighbors,
        neighbor_cutoff)

  elif featurizer == "atomic_conv":
    frag1_num_atoms = 70  # for ligand atoms
    frag2_num_atoms = 24000  # for protein atoms
    complex_num_atoms = 24070  # in total
    max_num_neighbors = 4
    # Cutoff in angstroms
    neighbor_cutoff = 4
    featurizer = AtomicConvFeaturizer(
        labels=labels,
        frag1_num_atoms=frag1_num_atoms,
        frag2_num_atoms=frag2_num_atoms,
        complex_num_atoms=complex_num_atoms,
        neighbor_cutoff=neighbor_cutoff,
        max_num_neighbors=max_num_neighbors,
        batch_size=64)

  else:
    raise ValueError("Featurizer not supported")
  print("Featurizing Complexes")
@@ -241,6 +258,7 @@ def load_pdbbind(featurizer="grid", split="random", subset="core", reload=True):
  # Delete labels for failing elements
  labels = np.delete(labels, failures)
  dataset = deepchem.data.DiskDataset.from_numpy(features, labels)
  print('Featurization complete.')
  # No transformations of data
  transformers = []
  if split == None: