Commit f1df7356 authored by peastman
Browse files

Layers look up variables rather than storing them

parent fdff1b95
Loading
Loading
Loading
Loading
+38 −36
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ class Layer(object):
      in_layers = [in_layers]
    self.in_layers = in_layers
    self.op_type = "gpu"
    self.variables = []
    self.variable_scope = ''

  def _get_layer_number(self):
    class_name = self.__class__.__name__
@@ -35,10 +35,6 @@ class Layer(object):
    Layer.layer_number_dict[class_name] += 1
    return "%s" % Layer.layer_number_dict[class_name]

  def get_variables(self):
    """Return the variables recorded on this layer.

    The value held in self.variables is handed back as-is; no copy is made.
    """
    recorded = self.variables
    return recorded

  def none_tensors(self):
    out_tensor = self.out_tensor
    self.out_tensor = None
@@ -104,6 +100,17 @@ class Layer(object):
          tensors[i] = tf.reshape(tensors[i], shape)
    return tensors

  def _record_variable_scope(self, local_scope):
    """Store the full scope name under which this layer creates variables.

    Intended to be invoked from create_tensor().  The stored name makes it
    possible to look up the variables belonging to this layer afterwards.

    local_scope is the scope name used by the layer itself.  When the current
    TensorFlow variable scope has a non-empty name, that name is prepended,
    separated by a slash; otherwise local_scope is stored unchanged.
    """
    enclosing = tf.get_variable_scope().name
    if not enclosing:
      # No enclosing variable scope: the local name is already the full name.
      self.variable_scope = local_scope
      return
    self.variable_scope = '%s/%s' % (enclosing, local_scope)


class TensorWrapper(Layer):
  """Used to wrap a tensorflow tensor."""
@@ -114,7 +121,7 @@ class TensorWrapper(Layer):

  def create_tensor(self, in_layers=None, **kwargs):
    """Take no actions.

    No tensor is constructed in this body.  The in_layers and kwargs arguments
    are accepted but never read here.
    """
    # NOTE(review): as shown, the only statement with an effect is this reset
    # of the variable list; the trailing `pass` is a no-op.  Confirm which of
    # the two lines is meant to remain.
    self.variables = []
    pass


def convert_to_layers(in_layers):
@@ -154,7 +161,7 @@ class Conv1D(Layer):
    t = tf.nn.bias_add(t, b)
    out_tensor = tf.nn.relu(t)
    if set_tensors:
      self.variables = [f, b]
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
    return out_tensor

@@ -209,29 +216,29 @@ class Dense(Layer):
      biases_initializer = None
    else:
      biases_initializer = self.biases_initializer()
    if not self.time_series:
      out_tensor = tf.contrib.layers.fully_connected(
          parent,
          num_outputs=self.out_channels,
          activation_fn=self.activation_fn,
          biases_initializer=biases_initializer,
          weights_initializer=self.weights_initializer(),
          scope=self._get_scope_name(),
          reuse=self._reuse,
          trainable=True)
    else:
    for reuse in (self._reuse, False):
      dense_fn = lambda x: tf.contrib.layers.fully_connected(x,
                                                             num_outputs=self.out_channels,
                                                             activation_fn=self.activation_fn,
                                                             biases_initializer=biases_initializer,
                                                             weights_initializer=self.weights_initializer(),
                                                             scope=self._get_scope_name(),
                                                             reuse=self._reuse,
                                                             reuse=reuse,
                                                             trainable=True)
      try:
        if self.time_series:
          out_tensor = tf.map_fn(dense_fn, parent)
        else:
          out_tensor = dense_fn(parent)
        break
      except ValueError:
        if reuse:
          # This probably means the variable hasn't been created yet, so try again
          # with reuse set to false.
          continue
        raise
    if set_tensors:
      self.variables = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope=self._get_scope_name())
      self._record_variable_scope(self._get_scope_name())
      self.out_tensor = out_tensor
    return out_tensor

@@ -733,8 +740,7 @@ class Conv2D(Layer):
        scope=self.scope_name)
    out_tensor = out_tensor
    if set_tensors:
      self.variables = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope_name)
      self._record_variable_scope(self.scope_name)
      self.out_tensor = out_tensor
    return out_tensor

@@ -888,7 +894,7 @@ class GraphConv(Layer):

    out_tensor = atom_features
    if set_tensors:
      self.variables = self.W_list + self.b_list
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
    return out_tensor

@@ -1147,7 +1153,7 @@ class VinaFreeEnergy(Layer):

    out_tensor = free_energy
    if set_tensors:
      self.variables = [weight] + weighted_combo.variables
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
    return out_tensor

@@ -1163,18 +1169,16 @@ class WeightedLinearCombo(Layer):
    inputs = self._get_input_tensors(in_layers, True)
    weights = []
    out_tensor = None
    variables = []
    for in_tensor in inputs:
      w = tf.Variable(tf.random_normal([
          1,
      ], stddev=self.std))
      variables.append(w)
      if out_tensor is None:
        out_tensor = w * in_tensor
      else:
        out_tensor += w * in_tensor
    if set_tensors:
      self.variables = variables
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
    return out_tensor

@@ -1270,8 +1274,8 @@ class NeighborList(Layer):
    # List of length N_atoms each of shape (M_nbrs)
    padded_dists = [
        tf.reduce_sum((atom_coord - padded_nbr_coord)**2, axis=1)
        for (atom_coord, padded_nbr_coord
            ) in zip(atom_coords, padded_nbr_coords)
        for (atom_coord,
             padded_nbr_coord) in zip(atom_coords, padded_nbr_coords)
    ]

    padded_closest_nbrs = [
@@ -1282,8 +1286,8 @@ class NeighborList(Layer):
    # N_atoms elts of size (M_nbrs,) each
    padded_neighbor_list = [
        tf.gather(padded_atom_nbrs, padded_closest_nbr)
        for (padded_atom_nbrs, padded_closest_nbr
            ) in zip(padded_nbrs, padded_closest_nbrs)
        for (padded_atom_nbrs,
             padded_closest_nbr) in zip(padded_nbrs, padded_closest_nbrs)
    ]

    neighbor_list = tf.stack(padded_neighbor_list)
@@ -1574,13 +1578,11 @@ class AtomicConvolution(Layer):
    R = self.distance_matrix(D)
    sym = []
    rsf_zeros = tf.zeros((B, N, M))
    variables = []
    for param in self.radial_params:

      # We apply the radial pooling filter before atom type conv
      # to reduce computation
      param_variables, rsf = self.radial_symmetry_function(R, *param)
      variables += param_variables

      if not self.atom_types:
        cond = tf.not_equal(Nbrs_Z, 0.0)
@@ -1595,7 +1597,7 @@ class AtomicConvolution(Layer):
    m, v = tf.nn.moments(layer, axes=[0])
    out_tensor = tf.nn.batch_normalization(layer, m, v, None, None, 1e-3)
    if set_tensors:
      self.variables = variables
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
    return out_tensor

+12 −4
Original line number Diff line number Diff line
@@ -166,12 +166,12 @@ class TensorGraph(Model):
          if self.global_step % checkpoint_interval == checkpoint_interval - 1:
            saver.save(sess, self.save_file, global_step=self.global_step)
            avg_loss = float(avg_loss) / n_batches
            print('Ending global_step %d: Average loss %g' %
                  (self.global_step, avg_loss))
            print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                              avg_loss))
            avg_loss, n_batches = 0.0, 0.0
        avg_loss = float(avg_loss) / n_batches
        print('Ending global_step %d: Average loss %g' %
              (self.global_step, avg_loss))
        print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                          avg_loss))
        saver.save(sess, self.save_file, global_step=self.global_step)
        self.last_checkpoint = saver.last_checkpoints[-1]
      ############################################################## TIMING
@@ -435,6 +435,14 @@ class TensorGraph(Model):
          metrics, per_task_metrics=per_task_metrics)
      return scores, per_task_scores

  def get_layer_variables(self, layer):
    """Get the list of variables belonging to a layer of the graph.

    The graph is built first if it has not been built yet.

    Parameters
    ----------
    layer: Layer
      the layer to query.  Its variable_scope attribute selects the variables.

    Returns
    -------
    A list of the entries of the GLOBAL_VARIABLES collection whose names match
    layer.variable_scope.  If the layer has an empty variable_scope (it never
    recorded one), an empty list is returned.
    """
    if not self.built:
      self.build()
    scope = layer.variable_scope
    if not scope:
      # tf.get_collection() filters names with re.match(), and an empty
      # pattern matches every name.  Without this guard a layer that owns no
      # variables would be reported as owning every variable in the graph.
      return []
    with self._get_tf("Graph").as_default():
      # NOTE: the scope acts as a name prefix, so a scope such as "Dense_1"
      # also matches variables under "Dense_10".
      return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)

  def _get_tf(self, obj):
    """
    TODO(LESWING) REALLY NEED TO DOCUMENT THIS
+1 −0
Original line number Diff line number Diff line
@@ -73,6 +73,7 @@ class TestNbrList(test_util.TensorFlowTestCase):
    tg = dc.models.TensorGraph(learning_rate=0.1, use_queue=False)
    tg.set_loss(loss)
    tg.fit_generator(databag.iterbatches(epochs=1))
    assert len(tg.get_layer_variables(combo)) >= 2

  def test_neighbor_list_shape(self):
    """Test that NeighborList works."""