Commit 4270c3da authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Adds NeighborListAtomicCoordinates and tests

parent 5dcf29e7
Loading
Loading
Loading
Loading
+39 −11
Original line number Diff line number Diff line
@@ -140,6 +140,20 @@ def compute_neighbor_cell_map(N_x, N_y, N_z):
        neighbor_cell_map[(x_ind, y_ind, z_ind)] = neighbors
  return neighbor_cell_map

def get_coords(mol):
  """
  Gets coordinates in Angstrom for RDKit mol.
  """
  N = mol.GetNumAtoms()
  coords = np.zeros((N,3))

  coords_raw = [mol.GetConformer(0).GetAtomPosition(i) for i in range(N)]
  for atom in range(N):
    coords[atom,0] = coords_raw[atom].x
    coords[atom,1] = coords_raw[atom].y
    coords[atom,2] = coords_raw[atom].z
  return coords

class NeighborListAtomicCoordinates(Featurizer):
  """
  Adjacency List of neighbors in 3-space
@@ -156,7 +170,9 @@ class NeighborListAtomicCoordinates(Featurizer):
  """ 

  def __init__(self, neighbor_cutoff=4):
    self.neighbor_cutoff = 4
    if neighbor_cutoff <= 0:
      raise ValueError("neighbor_cutoff must be positive value.")
    self.neighbor_cutoff = neighbor_cutoff

  def _featurize(self, mol):
    """
@@ -166,13 +182,7 @@ class NeighborListAtomicCoordinates(Featurizer):
    ----------
    """
    N = mol.GetNumAtoms()
    coords = np.zeros((N,3))

    coords_raw = [mol.GetConformer(0).GetAtomPosition(i) for i in range(N)]
    for atom in range(N):
      coords[atom,0] = coords_raw[atom].x
      coords[atom,1] = coords_raw[atom].y
      coords[atom,2] = coords_raw[atom].z
    coords = get_coords(mol)

    x_bins, y_bins, z_bins = get_cells(coords, self.neighbor_cutoff)

@@ -192,13 +202,31 @@ class NeighborListAtomicCoordinates(Featurizer):
    for atom in range(N):
      cell = atom_to_cell[atom]
      neighbor_cells = neighbor_cell_map[cell]
      neighbor_list[atom] = []
      # For smaller systems especially, the periodic boundary conditions can
      # result in neighboring cells being seen multiple times. Use a set() to
      # make sure duplicate neighbors are ignored. Convert back to list before
      # returning. 
      neighbor_list[atom] = set()
      ####################################################### DEBUG
      all_nbrs = set()
      print("self.neighbor_cutoff")
      print(self.neighbor_cutoff)
      ####################################################### DEBUG
      for neighbor_cell in neighbor_cells:
        atoms_in_cell = cell_to_atoms[neighbor_cell]
        for neighbor_atom in atoms_in_cell:
          if neighbor_atom == atom:
            continue
          if np.linalg.norm(coords[atom] - coords[atoms_in_cell]) < self.neighbor_cutoff:
            neighbor_list[atom].append(neighbor_atom)
          # TODO(rbharath): How does distance need to be modified here to
          # account for periodic boundary conditions?
          if np.linalg.norm(coords[atom] - coords[neighbor_atom]) < self.neighbor_cutoff:
            neighbor_list[atom].add(neighbor_atom)
          ########################################################### DEBUG
          all_nbrs.add(neighbor_atom)
          ########################################################### DEBUG
          
      ########################################################### DEBUG
      print("All neighbor-cell atoms for %d = %s" % (atom, str(all_nbrs)))
      neighbor_list[atom] = list(neighbor_list[atom])
        
    return neighbor_list
+72 −4
Original line number Diff line number Diff line
@@ -6,7 +6,9 @@ import unittest
from rdkit import Chem
from deepchem.utils import conformers
from deepchem.featurizers.atomic_coordinates import get_cells
from deepchem.featurizers.atomic_coordinates import get_coords
from deepchem.featurizers.atomic_coordinates import put_atoms_in_cells
from deepchem.featurizers.atomic_coordinates import compute_neighbor_cell_map
from deepchem.featurizers.atomic_coordinates import AtomicCoordinates
from deepchem.featurizers.atomic_coordinates import NeighborListAtomicCoordinates

@@ -103,27 +105,93 @@ class TestAtomicCoordinates(unittest.TestCase):
      else:
        assert atoms == []

  def test_compute_neighbor_cell_map(self):
    """
    Tests that computed neighbors for grid are meaningful.
    """
    # For a 1x1x1 grid, the neighbor cell map should return [(0,0,0)] * 27
    # since the periodic boundary conditions mean wrap-around happens in all
    # directions.
    neighbor_cell_map = compute_neighbor_cell_map(1, 1, 1)
    assert isinstance(neighbor_cell_map, dict)
    assert len(neighbor_cell_map) == 1
    assert neighbor_cell_map[(0,0,0)] == [(0,0,0)] * 27

    neighbor_cell_map = compute_neighbor_cell_map(5, 5, 5)
    assert isinstance(neighbor_cell_map, dict)
    assert len(neighbor_cell_map) == 125 
    assert sorted(neighbor_cell_map[(2,2, 2)]) == [
        (1, 1, 1),
        (1, 1, 2),
        (1, 1, 3),
        (1, 2, 1),
        (1, 2, 2),
        (1, 2, 3),
        (1, 3, 1),
        (1, 3, 2),
        (1, 3, 3),
        (2, 1, 1),
        (2, 1, 2),
        (2, 1, 3),
        (2, 2, 1),
        (2, 2, 2),
        (2, 2, 3),
        (2, 3, 1),
        (2, 3, 2),
        (2, 3, 3),
        (3, 1, 1),
        (3, 1, 2),
        (3, 1, 3),
        (3, 2, 1),
        (3, 2, 2),
        (3, 2, 3),
        (3, 3, 1),
        (3, 3, 2),
        (3, 3, 3)]

  def test_neighbor_list_shape(self):
    """
    Simple test that Neighbor Lists have right shape.
    """
    nblist_featurizer = NeighborListAtomicCoordinates()
    N = self.mol.GetNumAtoms()
    coords = get_coords(self.mol)
    x_bins, y_bins, z_bins = get_cells(coords, nblist_featurizer.neighbor_cutoff)

    nblist_featurizer = NeighborListAtomicCoordinates()
    nblist = nblist_featurizer._featurize(self.mol)
    assert isinstance(nblist, dict)
    assert len(nblist.keys()) == N
    print("nblist")
    print(nblist)
    for (atom, neighbors) in nblist.items():
      assert isinstance(atom, int)
      assert isinstance(neighbors, list)
      assert len(neighbors) <= N

    # Do a manual distance computation and make 
    for i in range(N):
      for j in range(N):
        dist = np.linalg.norm(coords[i] - coords[j])
        print("Distance(%d, %d) = %f" % (i, j, dist))
        if dist < nblist_featurizer.neighbor_cutoff and i != j:
          assert j in nblist[i]
        else:
          assert j not in nblist[i]

  def test_neighbor_list_extremes(self):
    """
    Test Neighbor Lists with large/small boxes.
    """
    pass
    N = self.mol.GetNumAtoms()

    # Test with cutoff 0 angstroms. There should be no neighbors in this case.
    nblist_featurizer = NeighborListAtomicCoordinates(neighbor_cutoff=.1)
    nblist = nblist_featurizer._featurize(self.mol)
    for atom in range(N):
      assert len(nblist[atom]) == 0

    # Test with cutoff 100 angstroms. Everything should be neighbors now.
    nblist_featurizer = NeighborListAtomicCoordinates(neighbor_cutoff=100)
    nblist = nblist_featurizer._featurize(self.mol)
    for atom in range(N):
      assert len(nblist[atom]) == N-1