Commit 494cce1f authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #217 from rbharath/nblist

[WIP] Neighbor-List Featurizer
parents c186fd04 ad768009
Loading
Loading
Loading
Loading
+179 −4
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ from deepchem.featurizers import Featurizer

class AtomicCoordinates(Featurizer):
  """
  Nx3 matrix of Cartestian coordinates [Angstrom]
  Nx3 matrix of Cartesian coordinates [Angstrom]
  """
  name = ['atomic_coordinates']

@@ -45,3 +45,178 @@ class AtomicCoordinates(Featurizer):
    coords = [coords]
    return coords

def _bins_along_axis(axis_min, axis_max, neighbor_cutoff):
  """Tiles [axis_min, axis_max] with consecutive cells of width neighbor_cutoff.

  At least one cell is always returned, so an axis along which every atom
  shares the same coordinate (axis_min == axis_max) still gets a cell.
  """
  bins = [(axis_min, axis_min + neighbor_cutoff)]
  current = axis_min + neighbor_cutoff
  while current < axis_max:
    bins.append((current, current + neighbor_cutoff))
    current += neighbor_cutoff
  return bins

def get_cells(coords, neighbor_cutoff):
  """Computes cells given molecular coordinates.

  The bounding box of the coordinates is tiled, independently along each
  axis, with cells of side-length neighbor_cutoff starting at the minimum
  coordinate. The last cell of an axis may extend past the maximum
  coordinate.

  Parameters
  ----------
  coords: np.ndarray
    (N, 3) array where N is number of atoms
  neighbor_cutoff: float
    Side-length of each cell. Must be positive.

  Returns
  -------
  x_bins, y_bins, z_bins: list
    One list per axis of (cell_start, cell_end) tuples. Each list holds at
    least one cell, even when all atoms share a coordinate along that axis
    (e.g. planar molecules or a single atom).

  Raises
  ------
  ValueError
    If neighbor_cutoff is not positive (the tiling loop could not terminate).
  """
  if neighbor_cutoff <= 0:
    raise ValueError("neighbor_cutoff must be positive value.")

  x_max, x_min = np.amax(coords[:, 0]), np.amin(coords[:, 0])
  y_max, y_min = np.amax(coords[:, 1]), np.amin(coords[:, 1])
  z_max, z_min = np.amax(coords[:, 2]), np.amin(coords[:, 2])

  # Number of cells depends only on the bounding box and the cutoff, not on
  # the number of atoms.
  x_bins = _bins_along_axis(x_min, x_max, neighbor_cutoff)
  y_bins = _bins_along_axis(y_min, y_max, neighbor_cutoff)
  z_bins = _bins_along_axis(z_min, z_max, neighbor_cutoff)
  return x_bins, y_bins, z_bins

def put_atoms_in_cells(coords, x_bins, y_bins, z_bins):
  """Assign every atom to the grid cell containing it.

  For each axis the bins are scanned in order and the first bin whose closed
  interval [cell_start, cell_end] contains the coordinate is chosen, so an
  atom sitting exactly on a shared boundary lands in the lower-index cell.

  Parameters
  ----------
  coords: np.ndarray
    (N, 3) array where N is number of atoms
  x_bins: list
    List of (cell_start, cell_end) for x-coordinate
  y_bins: list
    List of (cell_start, cell_end) for y-coordinate
  z_bins: list
    List of (cell_start, cell_end) for z-coordinate

  Returns
  -------
  cell_to_atoms: dict
    Maps every (x_ind, y_ind, z_ind) cell index to the list of atom indices
    it holds (an empty list for unoccupied cells).
  atom_to_cell: dict
    Maps each atom index to the (x_ind, y_ind, z_ind) of its cell.

  Raises
  ------
  ValueError
    If a coordinate is not covered by any bin along its axis.
  """
  def first_containing_bin(value, bins, not_found_message):
    # Linear scan; returns on the first closed interval that holds value.
    for position, (lower, upper) in enumerate(bins):
      if lower <= value <= upper:
        return position
    raise ValueError(not_found_message)

  # Pre-populate so that empty cells are present in the mapping too.
  cell_to_atoms = {
      (i, j, k): []
      for i in range(len(x_bins))
      for j in range(len(y_bins))
      for k in range(len(z_bins))}
  atom_to_cell = {}

  for atom_index in range(coords.shape[0]):
    x_value, y_value, z_value = coords[atom_index]
    cell = (
        first_containing_bin(x_value, x_bins, "No x-cell found!"),
        first_containing_bin(y_value, y_bins, "No y-cell found!"),
        first_containing_bin(z_value, z_bins, "No z-cell found!"))
    cell_to_atoms[cell].append(atom_index)
    atom_to_cell[atom_index] = cell
  return cell_to_atoms, atom_to_cell

def compute_neighbor_cell_map(N_x, N_y, N_z):
  """Build the map from each grid cell to its 27 surrounding cells.

  Neighbor indices wrap around the grid edges (modulo the grid size along
  each axis). Every entry has exactly 27 items: the cell itself is included
  and, on grids narrower than 3 cells along an axis, the wrap-around makes
  the same cell appear more than once.

  Parameters
  ----------
  N_x: int
    Number of grid cells in x-dimension.
  N_y: int
    Number of grid cells in y-dimension.
  N_z: int
    Number of grid cells in z-dimension.

  Returns
  -------
  dict
    Maps each (x_ind, y_ind, z_ind) to a list of 27 neighbor cell indices,
    ordered with the x-offset varying slowest and the z-offset fastest.
  """
  shifts = (-1, 0, 1)
  cell_neighbors = {}
  for i in range(N_x):
    for j in range(N_y):
      for k in range(N_z):
        # Self is part of the neighborhood (all-zero shift).
        cell_neighbors[(i, j, k)] = [
            ((i + di) % N_x, (j + dj) % N_y, (k + dk) % N_z)
            for di in shifts
            for dj in shifts
            for dk in shifts]
  return cell_neighbors

def get_coords(mol):
  """
  Gets coordinates in Angstrom for RDKit mol.

  Reads the position of every atom from the molecule's conformer with id 0
  and returns them as an (N, 3) float array, one row of (x, y, z) per atom.
  """
  num_atoms = mol.GetNumAtoms()
  positions = np.zeros((num_atoms, 3))
  for index in range(num_atoms):
    point = mol.GetConformer(0).GetAtomPosition(index)
    positions[index] = (point.x, point.y, point.z)
  return positions

class NeighborListAtomicCoordinates(Featurizer):
  """
  Adjacency List of neighbors in 3-space

  Neighbors determined by user-defined distance cutoff [in Angstrom]. Uses a
  cell list: atoms are binned into cubic cells of side-length
  neighbor_cutoff, so only atoms in the same or adjacent cells need to be
  compared.

  https://en.wikipedia.org/wiki/Cell_list
  Ref: http://www.cs.cornell.edu/ron/references/1989/Calculations%20of%20a%20List%20of%20Neighbors%20in%20Molecular%20Dynamics%20Si.pdf

  Parameters
  ----------
  neighbor_cutoff: float
    Threshold distance [Angstroms] for counting neighbors. Must be positive.
  """

  def __init__(self, neighbor_cutoff=4):
    if neighbor_cutoff <= 0:
      raise ValueError("neighbor_cutoff must be positive value.")
    self.neighbor_cutoff = neighbor_cutoff

  def _featurize(self, mol):
    """
    Compute neighbor list.

    Parameters
    ----------
    mol: RDKit Mol
      Molecule whose atom count and coordinates (via get_coords) are used.

    Returns
    -------
    dict
      Maps each atom index to a sorted list of the indices of all other
      atoms whose Euclidean distance to it is strictly less than
      neighbor_cutoff.
    """
    N = mol.GetNumAtoms()
    coords = get_coords(mol)

    x_bins, y_bins, z_bins = get_cells(coords, self.neighbor_cutoff)

    # Associate each atom with cell it belongs to.
    cell_to_atoms, atom_to_cell = put_atoms_in_cells(
        coords, x_bins, y_bins, z_bins)

    # Associate each cell with its neighbor cells. The map wraps around the
    # grid edges (periodic-style indexing).
    N_x, N_y, N_z = len(x_bins), len(y_bins), len(z_bins)
    neighbor_cell_map = compute_neighbor_cell_map(N_x, N_y, N_z)

    # For each atom, loop through all atoms in its cell and neighboring cells.
    # Accept as neighbors only those within threshold.
    neighbor_list = {}
    for atom in range(N):
      # On small grids the wrap-around lists the same cell several times
      # (up to 27x on a 1x1x1 grid). Visit each distinct cell once so no
      # pair distance is recomputed.
      candidate_cells = set(neighbor_cell_map[atom_to_cell[atom]])
      neighbors = set()
      for neighbor_cell in candidate_cells:
        for neighbor_atom in cell_to_atoms[neighbor_cell]:
          if neighbor_atom == atom:
            continue
          # TODO(rbharath): How does distance need to be modified here to
          # account for periodic boundary conditions?
          if np.linalg.norm(coords[atom] - coords[neighbor_atom]) < self.neighbor_cutoff:
            neighbors.add(neighbor_atom)

      # Sorted so the output does not depend on set iteration order.
      neighbor_list[atom] = sorted(neighbors)

    return neighbor_list
+197 −0
Original line number Diff line number Diff line
"""
Test atomic coordinates and neighbor lists.
"""
import numpy as np
import unittest
from rdkit import Chem
from deepchem.utils import conformers
from deepchem.featurizers.atomic_coordinates import get_cells
from deepchem.featurizers.atomic_coordinates import get_coords
from deepchem.featurizers.atomic_coordinates import put_atoms_in_cells
from deepchem.featurizers.atomic_coordinates import compute_neighbor_cell_map
from deepchem.featurizers.atomic_coordinates import AtomicCoordinates
from deepchem.featurizers.atomic_coordinates import NeighborListAtomicCoordinates

class TestAtomicCoordinates(unittest.TestCase):
  """
  Test AtomicCoordinates and the neighbor-list helpers/featurizer.
  """
  def setUp(self):
    """
    Set up tests: build a molecule from SMILES and generate one conformer.
    """
    smiles = 'CC(=O)OC1=CC=CC=C1C(=O)O'
    mol = Chem.MolFromSmiles(smiles)
    engine = conformers.ConformerGenerator(max_conformers=1)
    self.mol = engine.generate_conformers(mol)
    assert self.mol.GetNumConformers() > 0

  def test_atomic_coordinates(self):
    """
    Simple test that atomic coordinates returns ndarray of right shape.
    """
    N = self.mol.GetNumAtoms()
    atomic_coords_featurizer = AtomicCoordinates()
    # TODO(rbharath, joegomes): Why does AtomicCoordinates return a list? Is
    # this expected behavior? Need to think about API.
    coords = atomic_coords_featurizer._featurize(self.mol)[0]
    assert isinstance(coords, np.ndarray)
    assert coords.shape == (N, 3)

  def test_get_cells(self):
    """
    Test that coordinates are split into cells appropriately.
    """
    # The coordinates span the cube of side-length 2 set on (-1, 1)
    coords = np.array(
        [[1., 1., 1.],
         [-1., -1., -1.]])
    # Set cell size (neighbor_cutoff) at 1 angstrom.
    neighbor_cutoff = 1
    # We should get 2 bins in each dimension
    x_bins, y_bins, z_bins = get_cells(coords, neighbor_cutoff)

    # Check bins are lists
    assert isinstance(x_bins, list)
    assert isinstance(y_bins, list)
    assert isinstance(z_bins, list)

    assert len(x_bins) == 2
    assert x_bins == [(-1.0, 0.0), (0.0, 1.0)]
    assert len(y_bins) == 2
    assert y_bins == [(-1.0, 0.0), (0.0, 1.0)]
    assert len(z_bins) == 2
    assert z_bins == [(-1.0, 0.0), (0.0, 1.0)]

  def test_put_atoms_in_cells(self):
    """
    Test that atoms are placed into correct cells.
    """
    # As in previous example, coordinates span size-2 cube on (-1, 1)
    coords = np.array(
        [[1., 1., 1.],
         [-1., -1., -1.]])
    # Set cell size (neighbor_cutoff) at 1 angstrom.
    neighbor_cutoff = 1
    # We should get 2 bins in each dimension
    x_bins, y_bins, z_bins = get_cells(coords, neighbor_cutoff)

    cell_to_atoms, atom_to_cell = put_atoms_in_cells(
        coords, x_bins, y_bins, z_bins)

    # Both cell_to_atoms and atom_to_cell are dictionaries
    assert isinstance(cell_to_atoms, dict)
    assert isinstance(atom_to_cell, dict)

    # atom_to_cell should be of len 2 since 2 atoms
    assert len(atom_to_cell) == 2

    # cell_to_atoms should be of len 8 since 8 cells total.
    assert len(cell_to_atoms) == 8

    # We have two atoms. The first is in highest corner (1,1,1)
    # Second atom should be in lowest corner (0, 0, 0)
    assert atom_to_cell[0] == (1, 1, 1)
    assert atom_to_cell[1] == (0, 0, 0)

    # (1,1,1) should contain atom 0. (0, 0, 0) should contain atom 1.
    # Everything else should be an empty list
    for cell, atoms in cell_to_atoms.items():
      if cell == (1, 1, 1):
        assert atoms == [0]
      elif cell == (0, 0, 0):
        assert atoms == [1]
      else:
        assert atoms == []

  def test_compute_neighbor_cell_map(self):
    """
    Tests that computed neighbors for grid are meaningful.
    """
    # For a 1x1x1 grid, the neighbor cell map should return [(0,0,0)] * 27
    # since the periodic boundary conditions mean wrap-around happens in all
    # directions.
    neighbor_cell_map = compute_neighbor_cell_map(1, 1, 1)
    assert isinstance(neighbor_cell_map, dict)
    assert len(neighbor_cell_map) == 1
    assert neighbor_cell_map[(0, 0, 0)] == [(0, 0, 0)] * 27

    # In a 5x5x5 grid the center cell (2,2,2) is far from every edge, so its
    # neighborhood is exactly the 3x3x3 block of cells with indices in 1..3.
    # The comprehension below enumerates that block in sorted order.
    neighbor_cell_map = compute_neighbor_cell_map(5, 5, 5)
    assert isinstance(neighbor_cell_map, dict)
    assert len(neighbor_cell_map) == 125
    assert sorted(neighbor_cell_map[(2, 2, 2)]) == [
        (x_ind, y_ind, z_ind)
        for x_ind in (1, 2, 3)
        for y_ind in (1, 2, 3)
        for z_ind in (1, 2, 3)]

  def test_neighbor_list_shape(self):
    """
    Test that Neighbor Lists have right shape and agree with brute force.
    """
    N = self.mol.GetNumAtoms()
    coords = get_coords(self.mol)

    nblist_featurizer = NeighborListAtomicCoordinates()
    nblist = nblist_featurizer._featurize(self.mol)
    assert isinstance(nblist, dict)
    assert len(nblist.keys()) == N
    for (atom, neighbors) in nblist.items():
      assert isinstance(atom, int)
      assert isinstance(neighbors, list)
      assert len(neighbors) <= N

    # Do a manual all-pairs distance computation and make sure the neighbor
    # list contains exactly the pairs that fall within the cutoff.
    for i in range(N):
      for j in range(N):
        dist = np.linalg.norm(coords[i] - coords[j])
        if dist < nblist_featurizer.neighbor_cutoff and i != j:
          assert j in nblist[i]
        else:
          assert j not in nblist[i]

  def test_neighbor_list_extremes(self):
    """
    Test Neighbor Lists with large/small boxes.
    """
    N = self.mol.GetNumAtoms()

    # Test with a tiny cutoff of 0.1 angstroms (a cutoff of exactly 0 is
    # rejected by the featurizer). There should be no neighbors in this case.
    nblist_featurizer = NeighborListAtomicCoordinates(neighbor_cutoff=.1)
    nblist = nblist_featurizer._featurize(self.mol)
    for atom in range(N):
      assert len(nblist[atom]) == 0

    # Test with cutoff 100 angstroms. Everything should be neighbors now.
    nblist_featurizer = NeighborListAtomicCoordinates(neighbor_cutoff=100)
    nblist = nblist_featurizer._featurize(self.mol)
    for atom in range(N):
      assert len(nblist[atom]) == N-1