Commit 9b9ae903 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Forgot to add

parent 32cb3dac
Loading
Loading
Loading
Loading
+7 −7
Original line number Diff line number Diff line
@@ -97,7 +97,7 @@ def compute_neighbor_list(coords, nbr_cutoff, N, M, n_cells, ndim=3, k=5):
                   in zip(split_nbr_inds, split_closest_nbr_locs)]

  # Shape (N, M)
  neighbor_list = tf.pack(neighbor_list)
  neighbor_list = tf.stack(neighbor_list)

  return neighbor_list   

@@ -138,7 +138,7 @@ def get_cells_for_atoms(coords, cells, N, n_cells, ndim=3):
                  for norm in coords_norm]

  # TODO(rbharath): tf.stack for tf 1.0
  return tf.pack(closest_inds)
  return tf.stack(closest_inds)
    

def compute_closest_neighbors(coords, cells, atoms_in_cells, neighbor_cells, N,
@@ -158,7 +158,7 @@ def compute_closest_neighbors(coords, cells, atoms_in_cells, neighbor_cells, N,
  """
  n_cells = len(atoms_in_cells)
  # Tensor of shape (n_cells, k, ndim)
  #atoms_in_cells = tf.pack(atoms_in_cells)
  #atoms_in_cells = tf.stack(atoms_in_cells)

  cells_for_atoms = get_cells_for_atoms(coords, cells, N, n_cells, ndim)
  all_closest = []
@@ -191,7 +191,7 @@ def get_cells(start, stop, nbr_cutoff, ndim=3):
  cells: tf.Tensor
    (box_size**ndim, ndim) shape.
  """
  return tf.reshape(tf.transpose(tf.pack(tf.meshgrid(
  return tf.reshape(tf.transpose(tf.stack(tf.meshgrid(
      *[tf.range(start, stop, nbr_cutoff) for _ in range(ndim)]))), (-1, ndim))
     
def put_atoms_in_cells(coords, cells, N, n_cells, ndim, k=5):
@@ -237,9 +237,9 @@ def put_atoms_in_cells(coords, cells, N, n_cells, ndim, k=5):
  # Get indices of k atoms closest to each cell point
  closest_inds = [tf.nn.top_k(norm, k=k)[1] for norm in coords_norm]
  # n_cells tensors of shape (k, ndim)
  closest_atoms = tf.pack([tf.gather(coords, inds) for inds in closest_inds])
  closest_atoms = tf.stack([tf.gather(coords, inds) for inds in closest_inds])
  # Tensor of shape (n_cells, k)
  closest_inds = tf.pack(closest_inds)
  closest_inds = tf.stack(closest_inds)

  return closest_inds, closest_atoms

@@ -293,7 +293,7 @@ def compute_neighbor_cells(cells, ndim, n_cells):
  # Lists of length n_cells
  # Get indices of k atoms closest to each cell point
  # n_cells tensors of shape (26,)
  closest_inds = tf.pack([tf.nn.top_k(norm, k=k)[1] for norm in coords_norm])
  closest_inds = tf.stack([tf.nn.top_k(norm, k=k)[1] for norm in coords_norm])

  return closest_inds