Commit db14587e authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #247 from peastman/neighbors

Use mdtraj to build neighbor lists
parents f379159a dff9b181
Loading
Loading
Loading
Loading
+26 −142
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ __copyright__ = "Copyright 2016, Stanford University"
__license__ = "LGPL v2.1+"

import numpy as np
import mdtraj
from deepchem.featurizers import Featurizer
from deepchem.featurizers import ComplexFeaturizer
from deepchem.featurizers.grid_featurizer import load_molecule 
@@ -48,147 +49,29 @@ class AtomicCoordinates(Featurizer):
    coords = [coords]
    return coords

def compute_neighbor_list(coords, neighbor_cutoff, max_num_neighbors):
def compute_neighbor_list(coords, neighbor_cutoff, max_num_neighbors, periodic_box_size):
  """Computes a neighbor list from atom coordinates."""
  N = coords.shape[0]
  x_bins, y_bins, z_bins = get_cells(coords, neighbor_cutoff)

  # Associate each atom with cell it belongs to. O(N)
  cell_to_atoms, atom_to_cell = put_atoms_in_cells(
      coords, x_bins, y_bins, z_bins)

  # Associate each cell with its neighbor cells. Assumes periodic boundary
  # conditions, so does wrapround. O(constant)
  N_x, N_y, N_z = len(x_bins), len(y_bins), len(z_bins)
  neighbor_cell_map = compute_neighbor_cell_map(N_x, N_y, N_z)

  # For each atom, loop through all atoms in its cell and neighboring cells.
  # Accept as neighbors only those within threshold. This computation should be
  # O(Nm), where m is the number of atoms within a set of neighboring-cells.
  traj = mdtraj.Trajectory(coords.reshape((1, N, 3)), None)
  box_size = None
  if periodic_box_size is not None:
    box_size = np.array(periodic_box_size)
    traj.unitcell_vectors = np.array([[[box_size[0], 0, 0], [0, box_size[1], 0], [0, 0, box_size[2]]]], dtype=np.float32)
  neighbors = mdtraj.geometry.compute_neighborlist(traj, neighbor_cutoff)
  neighbor_list = {}
  for atom in range(N):
    cell = atom_to_cell[atom]
    neighbor_cells = neighbor_cell_map[cell]
    # For smaller systems especially, the periodic boundary conditions can
    # result in neighboring cells being seen multiple times. Use a set() to
    # make sure duplicate neighbors are ignored. Convert back to list before
    # returning. 
    neighbor_list[atom] = set()
    for neighbor_cell in neighbor_cells:
      atoms_in_cell = cell_to_atoms[neighbor_cell]
      for neighbor_atom in atoms_in_cell:
        if neighbor_atom == atom:
          continue
        # TODO(rbharath): How does distance need to be modified here to
        # account for periodic boundary conditions?
        dist = np.linalg.norm(coords[atom] - coords[neighbor_atom])
        if dist < neighbor_cutoff:
          neighbor_list[atom].add((neighbor_atom, dist))
        
    # Sort neighbors by distance
    closest_neighbors = sorted(
        list(neighbor_list[atom]), key=lambda elt: elt[1])
    closest_neighbors = [nbr for (nbr, dist) in closest_neighbors]
    # Pick up to max_num_neighbors
    closest_neighbors = closest_neighbors[:max_num_neighbors]
    neighbor_list[atom] = closest_neighbors
  for i in range(N):
    if max_num_neighbors is not None and len(neighbors[i]) > max_num_neighbors:
      delta = coords[i]-coords.take(neighbors[i], axis=0)
      if box_size is not None:
        delta -= np.round(delta/box_size)*box_size
      dist = np.linalg.norm(delta, axis=1)
      sorted_neighbors = list(zip(dist, neighbors[i]))
      sorted_neighbors.sort()
      neighbor_list[i] = [sorted_neighbors[j][1] for j in range(max_num_neighbors)]
    else:
      neighbor_list[i] = list(neighbors[i])
  return neighbor_list

def get_cells(coords, neighbor_cutoff):
  """Computes cells given molecular coordinates."""
  x_max, x_min = np.amax(coords[:, 0]), np.amin(coords[:, 0])
  y_max, y_min = np.amax(coords[:, 1]), np.amin(coords[:, 1])
  z_max, z_min = np.amax(coords[:, 2]), np.amin(coords[:, 2])

  # Compute cells for this molecule. O(constant)
  x_bins, y_bins, z_bins = [], [], []
  x_current, y_current, z_current = x_min, y_min, z_min
  while x_current < x_max:
    x_bins.append((x_current, x_current+neighbor_cutoff))
    x_current += neighbor_cutoff
  while y_current < y_max:
    y_bins.append((y_current, y_current+neighbor_cutoff))
    y_current += neighbor_cutoff
  while z_current < z_max:
    z_bins.append((z_current, z_current+neighbor_cutoff))
    z_current += neighbor_cutoff
  return x_bins, y_bins, z_bins

def put_atoms_in_cells(coords, x_bins, y_bins, z_bins):
  """Place each atom into cells. O(N) runtime.
  
  Parameters
  ----------
  coords: np.ndarray
    (N, 3) array where N is number of atoms
  x_bins: list
    List of (cell_start, cell_end) for x-coordinate
  y_bins: list
    List of (cell_start, cell_end) for y-coordinate
  z_bins: list
    List of (cell_start, cell_end) for z-coordinate
  """
  N = coords.shape[0]
  cell_to_atoms = {}
  atom_to_cell = {}
  for x_ind in range(len(x_bins)):
    for y_ind in range(len(y_bins)):
      for z_ind in range(len(z_bins)):
        cell_to_atoms[(x_ind, y_ind, z_ind)] = []
    
  for atom in range(N):
    x_coord, y_coord, z_coord = coords[atom]
    x_ind, y_ind, z_ind = None, None, None
    for ind, (x_cell_min, x_cell_max) in enumerate(x_bins):
      if x_coord >= x_cell_min and x_coord <= x_cell_max:
        x_ind = ind
        break
    if x_ind is None:
      raise ValueError("No x-cell found!")
    for ind, (y_cell_min, y_cell_max) in enumerate(y_bins):
      if y_coord >= y_cell_min and y_coord <= y_cell_max:
        y_ind = ind
        break
    if y_ind is None:
      raise ValueError("No y-cell found!")
    for ind, (z_cell_min, z_cell_max) in enumerate(z_bins):
      if z_coord >= z_cell_min and z_coord <= z_cell_max:
        z_ind = ind
        break
    if z_ind is None:
      raise ValueError("No z-cell found!")
    cell_to_atoms[(x_ind, y_ind, z_ind)].append(atom)
    atom_to_cell[atom] = (x_ind, y_ind, z_ind)
  return cell_to_atoms, atom_to_cell

def compute_neighbor_cell_map(N_x, N_y, N_z):
  """Compute neighbors of cells in grid.
  
  Parameters
  ----------
  N_x: int
    Number of grid cells in x-dimension.
  N_y: int
    Number of grid cells in y-dimension.
  N_z: int
    Number of grid cells in z-dimension.
  """
  neighbor_cell_map = {} 
  for x_ind in range(N_x):
    for y_ind in range(N_y):
      for z_ind in range(N_z):
        neighbors = []
        offsets = [-1, 0, +1]
        # Note neighbors contains self!
        for x_offset in offsets:
          for y_offset in offsets:
            for z_offset in offsets:
              neighbors.append(((x_ind+x_offset) % N_x,
                                (y_ind+y_offset) % N_y,
                                (z_ind+z_offset) % N_z))
        neighbor_cell_map[(x_ind, y_ind, z_ind)] = neighbors
  return neighbor_cell_map

def get_coords(mol):
  """
  Gets coordinates in Angstrom for RDKit mol.
@@ -214,11 +97,13 @@ class NeighborListAtomicCoordinates(Featurizer):

  Parameters
  ----------
  neighbor_cutoff: int
  neighbor_cutoff: float
    Threshold distance [Angstroms] for counting neighbors.
  periodic_box_size: 3 element array
    Dimensions of the periodic box in Angstroms, or None to not use periodic boundary conditions
  """ 

  def __init__(self, max_num_neighbors=None, neighbor_cutoff=4):
  def __init__(self, max_num_neighbors=None, neighbor_cutoff=4, periodic_box_size=None):
    if neighbor_cutoff <= 0:
      raise ValueError("neighbor_cutoff must be positive value.")
    if max_num_neighbors is not None:
@@ -226,6 +111,7 @@ class NeighborListAtomicCoordinates(Featurizer):
        raise ValueError("max_num_neighbors must be positive integer.")
    self.max_num_neighbors = max_num_neighbors
    self.neighbor_cutoff = neighbor_cutoff
    self.periodic_box_size = periodic_box_size
    # Type of data created by this featurizer
    self.dtype = object
    self.coordinates_featurizer = AtomicCoordinates()
@@ -243,9 +129,7 @@ class NeighborListAtomicCoordinates(Featurizer):
    # TODO(rbharath): Should this return a list?
    bohr_coords = self.coordinates_featurizer._featurize(mol)[0]
    coords = get_coords(mol)
    neighbor_list = compute_neighbor_list(
        coords, self.neighbor_cutoff, self.max_num_neighbors)
        
    neighbor_list = compute_neighbor_list(coords, self.neighbor_cutoff, self.max_num_neighbors, self.periodic_box_size)
    return (bohr_coords, neighbor_list)

class NeighborListComplexAtomicCoordinates(ComplexFeaturizer):
@@ -283,6 +167,6 @@ class NeighborListComplexAtomicCoordinates(ComplexFeaturizer):
        mol_coords, ob_mol, protein_coords, protein_mol)
    
    system_neighbor_list = compute_neighbor_list(
        system_coords, self.neighbor_cutoff, self.max_num_neighbors)
        system_coords, self.neighbor_cutoff, self.max_num_neighbors, None)

    return (system_coords, system_neighbor_list)
+19 −115
Original line number Diff line number Diff line
@@ -6,10 +6,7 @@ import numpy as np
import unittest
from rdkit import Chem
from deepchem.utils import conformers
from deepchem.featurizers.atomic_coordinates import get_cells
from deepchem.featurizers.atomic_coordinates import get_coords
from deepchem.featurizers.atomic_coordinates import put_atoms_in_cells
from deepchem.featurizers.atomic_coordinates import compute_neighbor_cell_map
from deepchem.featurizers.atomic_coordinates import AtomicCoordinates
from deepchem.featurizers.atomic_coordinates import NeighborListAtomicCoordinates
from deepchem.featurizers.atomic_coordinates import NeighborListComplexAtomicCoordinates
@@ -40,117 +37,6 @@ class TestAtomicCoordinates(unittest.TestCase):
    assert isinstance(coords, np.ndarray)
    assert coords.shape == (N, 3)

  def test_get_cells(self):
    """
    Test that coordinates are split into cell appropriately.
    """
    # The coordinates span the cube of side-length 2 set on (-1, 1)
    coords = np.array(
        [[1., 1., 1.],
         [-1., -1., -1.]])
    # Set cell size (neighbor_cutoff) at 1 angstrom.
    neighbor_cutoff = 1
    # We should get 2 bins in each dimension
    x_bins, y_bins, z_bins = get_cells(coords, neighbor_cutoff)

    # Check bins are lists
    assert isinstance(x_bins, list)
    assert isinstance(y_bins, list)
    assert isinstance(z_bins, list)

    assert len(x_bins) == 2
    assert x_bins ==[(-1.0, 0.0), (0.0, 1.0)] 
    assert len(y_bins) == 2
    assert y_bins ==[(-1.0, 0.0), (0.0, 1.0)] 
    assert len(z_bins) == 2
    assert z_bins ==[(-1.0, 0.0), (0.0, 1.0)] 


  def test_put_atoms_in_cells(self):
    """
    Test that atoms are placed into correct cells.
    """
    # As in previous example, coordinates span size-2 cube on (-1, 1)
    coords = np.array(
        [[1., 1., 1.],
         [-1., -1., -1.]])
    # Set cell size (neighbor_cutoff) at 1 angstrom.
    neighbor_cutoff = 1
    # We should get 2 bins in each dimension
    x_bins, y_bins, z_bins = get_cells(coords, neighbor_cutoff)

    cell_to_atoms, atom_to_cell = put_atoms_in_cells(
        coords, x_bins, y_bins, z_bins)

    # Both cell_to_atoms and atom_to_cell are dictionaries
    assert isinstance(cell_to_atoms, dict)
    assert isinstance(atom_to_cell, dict)

    # atom_to_cell should be of len 2 since 2 atoms
    assert len(atom_to_cell) == 2

    # cell_to_atoms should be of len 8 since 8 cells total.
    assert len(cell_to_atoms) == 8

    # We have two atoms. The first is in highest corner (1,1,1)
    # Second atom should be in lowest corner (0, 0, 0)
    assert atom_to_cell[0] == (1, 1, 1)
    assert atom_to_cell[1] == (0, 0, 0)

    # (1,1,1) should contain atom 0. (0, 0, 0) should contain atom 1.
    # Everything else should be an empty list
    for cell, atoms in cell_to_atoms.items():
      if cell == (1, 1, 1):
        assert atoms == [0]
      elif cell == (0, 0, 0):
        assert atoms == [1]
      else:
        assert atoms == []

  def test_compute_neighbor_cell_map(self):
    """
    Tests that computed neighbors for grid are meaningful.
    """
    # For a 1x1x1 grid, the neighbor cell map should return [(0,0,0)] * 27
    # since the periodic boundary conditions mean wrap-around happens in all
    # directions.
    neighbor_cell_map = compute_neighbor_cell_map(1, 1, 1)
    assert isinstance(neighbor_cell_map, dict)
    assert len(neighbor_cell_map) == 1
    assert neighbor_cell_map[(0,0,0)] == [(0,0,0)] * 27

    neighbor_cell_map = compute_neighbor_cell_map(5, 5, 5)
    assert isinstance(neighbor_cell_map, dict)
    assert len(neighbor_cell_map) == 125 
    assert sorted(neighbor_cell_map[(2,2, 2)]) == [
        (1, 1, 1),
        (1, 1, 2),
        (1, 1, 3),
        (1, 2, 1),
        (1, 2, 2),
        (1, 2, 3),
        (1, 3, 1),
        (1, 3, 2),
        (1, 3, 3),
        (2, 1, 1),
        (2, 1, 2),
        (2, 1, 3),
        (2, 2, 1),
        (2, 2, 2),
        (2, 2, 3),
        (2, 3, 1),
        (2, 3, 2),
        (2, 3, 3),
        (3, 1, 1),
        (3, 1, 2),
        (3, 1, 3),
        (3, 2, 1),
        (3, 2, 2),
        (3, 2, 3),
        (3, 3, 1),
        (3, 3, 2),
        (3, 3, 3)]

  def test_neighbor_list_shape(self):
    """
    Simple test that Neighbor Lists have right shape.
@@ -158,7 +44,6 @@ class TestAtomicCoordinates(unittest.TestCase):
    nblist_featurizer = NeighborListAtomicCoordinates()
    N = self.mol.GetNumAtoms()
    coords = get_coords(self.mol)
    x_bins, y_bins, z_bins = get_cells(coords, nblist_featurizer.neighbor_cutoff)

    nblist_featurizer = NeighborListAtomicCoordinates()
    nblist = nblist_featurizer._featurize(self.mol)[1]
@@ -231,6 +116,25 @@ class TestAtomicCoordinates(unittest.TestCase):
      else:
        assert nblist[i] == []

  def test_neighbor_list_periodic(self):
    """Test building a neighbor list with periodic boundary conditions."""
    cutoff = 4.0
    box_size = np.array([10.0, 8.0, 9.0])
    N = self.mol.GetNumAtoms()
    coords = get_coords(self.mol)
    featurizer = NeighborListAtomicCoordinates(neighbor_cutoff=cutoff, periodic_box_size=box_size)
    neighborlist = featurizer._featurize(self.mol)[1]
    expected_neighbors = [set() for i in range(N)]
    for i in range(N):
        for j in range(i):
            delta = coords[i]-coords[j]
            delta -= np.round(delta/box_size)*box_size
            if np.linalg.norm(delta) < cutoff:
                expected_neighbors[i].add(j)
                expected_neighbors[j].add(i)
    for i in range(N):
        assert(set(neighborlist[i]) == expected_neighbors[i])

  def test_complex_featurization_simple(self):
    """Test Neighbor List computation on protein-ligand complex."""
    dir_path = os.path.dirname(os.path.realpath(__file__))