Commit 29c11f55 authored by Bharath Ramsundar
Browse files

Vina loss now defined (minor issues remain, punting on those)

parent 0789b79e
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -173,11 +173,15 @@ class TestVinaModel(test_util.TensorFlowTestCase):
      assert cells_for_atoms.shape == (N, 1)

  def test_vina_construct_graph(self):
    """Test that vina model graph can be constructed."""
    # There are no assertions here on purpose: the test passes as long as
    # instantiating VinaModel does not raise.
    VinaModel()

  def test_vina_generate_conformers(self):
    """Test that Vina Model can generate conformers"""
    max_protein_atoms = 3500 
    max_ligand_atoms = 100

@@ -192,5 +196,3 @@ class TestVinaModel(test_util.TensorFlowTestCase):
        np.array([atom.GetAtomicNum() for atom in ligand_mol.GetAtoms()]),
        max_ligand_atoms)
    vina_model = VinaModel()
+50 −38
Original line number Diff line number Diff line
@@ -298,9 +298,9 @@ def compute_neighbor_cells(cells, ndim, n_cells):
  return closest_inds


def cutoff(d, x):
  """Truncates interactions that are too far away.

  Elementwise: wherever `d < 8` the corresponding entry of `x` is kept,
  everywhere else the entry is replaced by zero.

  Args:
    d: Tensor of distances used for the cutoff test. Presumably in Angstrom
      and the same shape as `x` -- confirm against the caller.
    x: Tensor of interaction values to truncate.

  Returns:
    Tensor shaped like `x` with entries zeroed where `d >= 8`.
  """
  # tf.select picks per element; tf.cond would need a scalar predicate and
  # callable branches, so it cannot be used for this masking.
  # NOTE(review): tf.select was renamed tf.where in TensorFlow 1.0 -- confirm
  # the TensorFlow version this file targets before upgrading.
  return tf.select(d < 8, x, tf.zeros_like(x))

def gauss_1(d):
  """Computes first Gaussian interaction term.
@@ -314,25 +314,25 @@ def gauss_2(d):

  Note that d must be in Angstrom.
  """
  return tf.exp(-((d-3)/2)^2)

  return tf.exp(-((d-3)/2)**2)

def repulsion(d):
  """Computes repulsion interaction term.

  Elementwise: `d**2` wherever `d < 0`, zero everywhere else.

  Args:
    d: Tensor of distances. Presumably surface distances in Angstrom (so
      negative values mean overlap) -- confirm against the caller.

  Returns:
    Tensor with the same shape as `d`.
  """
  # Per-element selection; both branches are full tensors shaped like d.
  return tf.select(d < 0, d**2, tf.zeros_like(d))

def hydrophobic(d):
  """Compute hydrophobic interaction term.

  Elementwise piecewise-linear function of `d`:
    * 1        where d < 0.5
    * 1.5 - d  where 0.5 <= d < 1.5
    * 0        where d >= 1.5

  Args:
    d: Tensor of distances. Presumably surface distances in Angstrom --
      confirm against the caller.

  Returns:
    Tensor with the same shape as `d`.
  """
  # The inner select handles the ramp and the zero tail; the outer select
  # overrides it with the constant plateau for small d.
  ramp_or_zero = tf.select(d < 1.5, 1.5 - d, tf.zeros_like(d))
  return tf.select(d < 0.5, tf.ones_like(d), ramp_or_zero)

def hbond(d):
  """Computes hydrogen bond term.

  Elementwise piecewise-linear function of `d`:
    * 1                  where d < -0.7
    * (1.0/0.7)*(0 - d)  where -0.7 <= d < 0
    * 0                  where d >= 0

  Args:
    d: Tensor of distances. Presumably surface distances in Angstrom --
      confirm against the caller.

  Returns:
    Tensor with the same shape as `d`.
  """
  # The slope must be multiplied explicitly: `(1.0/0.7)(0-d)` would try to
  # call a float.
  ramp_or_zero = tf.select(d < 0, (1.0/0.7)*(0-d), tf.zeros_like(d))
  return tf.select(d < -0.7, tf.ones_like(d), ramp_or_zero)

def g(c, Nrot):
  """Nonlinear function mapping interactions to free energy.

  Computes `c / (1 + w * Nrot)`, where `w` is a weight variable created
  inside this function.

  Args:
    c: Tensor of interaction values.
    Nrot: Number of rotatable bonds (the visible caller passes a Python int).

  Returns:
    Tensor `c` scaled by `1 / (1 + w * Nrot)`.
  """
  # Each call creates a fresh shape-[1] variable initialised from a normal
  # distribution with stddev 0.3, so separate calls do not share `w`.
  w = tf.Variable(tf.random_normal([1,], stddev=.3))
  return c/(1 + w*Nrot)
  
def h(d):
@@ -462,35 +462,47 @@ class VinaModel(Model):
      nbr_list = compute_neighbor_list(coords, nbr_cutoff, N_protein+N_ligand, M,
                                       n_cells, ndim=ndim, k=k)
      all_interactions = []
      for atom in range(N_protein+N_ligand):
        # Shape (3,)
        atom_coords = tf.gather(coords, atom)
        # Shape (1,)
        atom_Z = tf.gather(Z, [atom])

        # Shape (M,)
        nbrs = tf.squeeze(tf.gather(nbr_list, [atom]))
        # Shape (M, 3)

      # Shape (N_protein+N_ligand,)
      all_atoms = tf.range(N_protein+N_ligand)
      # Shape (N_protein+N_ligand, 3)
      atom_coords = tf.gather(coords, all_atoms)
      # Shape (N_protein+N_ligand,)
      atom_Z = tf.gather(Z, all_atoms)
      # Shape (N_protein+N_ligand, M)
      nbrs = tf.squeeze(tf.gather(nbr_list, all_atoms))
      # Shape (N_protein+N_ligand, M, 3)
      nbr_coords = tf.gather(coords, nbrs)
        # Shape (M,)

      # Shape (N_protein+N_ligand, M)
      nbr_Z = tf.gather(Z, nbrs)
      # Shape (N_protein+N_ligand, M, 3)
      tiled_atom_coords = tf.tile(
          tf.reshape(atom_coords, (N_protein+N_ligand, 1, 3)), (1, M, 1))

        # Shape (M, 3)
        tiled_atom_coords = tf.tile(tf.reshape(atom_coords, (1, 3)), (M, 1))
        # Shape (M,)
        dists = tf.reduce_sum((tiled_atom_coords - nbr_coords)**2, axis=1)
      # Shape (N_protein+N_ligand, M)
      dists = tf.reduce_sum((tiled_atom_coords - nbr_coords)**2, axis=2)
    
      # TODO(rbharath): Need to subtract out Van-der-Waals radii from dists

      # Shape (N_protein+N_ligand, M)
      atom_interactions = h(dists)
        all_interactions.append(atom_interactions)
      all_interactions = tf.pack(all_interactions)
      energy = tf.reduce_sum(all_interactions)
      loss = tf.mul(0.5 * tf.square(energy - label_placeholder), weights)
      ############################################## DEBUG
      print("dists")
      print(dists)
      assert 0 == 1
      ############################################## DEBUG
      # Shape (N_protein+N_ligand, M)
      cutoff_interactions = cutoff(dists, atom_interactions)
  
      # TODO(rbharath): Use RDKit to compute number of rotatable bonds in ligand.
      Nrot = 1
  
      # TODO(rbharath): Autodock Vina only uses protein-ligand interactions in 
      # computing free-energy. This implementation currently uses all interaction
      # terms. Not sure if this makes a difference.

      # Shape (N_protein+N_ligand, M)
      free_energy = g(cutoff_interactions, Nrot)
      # Shape () -- scalar
      energy = tf.reduce_sum(atom_interactions)

      loss = 0.5 * (energy - label_placeholder)**2
        
    return (graph, (protein_coords_placeholder, protein_Z_placeholder,
                    ligand_coords_placeholder, ligand_Z_placeholder), label_placeholder)