Commit d99c86c6 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #799 from peastman/variables

Fixed bug in get_layer_variables()
parents 93e1e761 1c559cd1
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -601,6 +601,8 @@ class TensorGraph(Model):
    if not self.built:
      self.build()
    with self._get_tf("Graph").as_default():
      if layer.variable_scope == '':
        return []
      return tf.get_collection(
          tf.GraphKeys.TRAINABLE_VARIABLES, scope=layer.variable_scope)

+1 −0
Original line number Diff line number Diff line
@@ -74,6 +74,7 @@ class TestNbrList(test_util.TensorFlowTestCase):
    tg.set_loss(loss)
    tg.fit_generator(databag.iterbatches(epochs=1))
    assert len(tg.get_layer_variables(combo)) >= 2
    assert len(tg.get_layer_variables(out)) == 0

  def test_neighbor_list_shape(self):
    """Test that NeighborList works."""