Commit 824faa2c authored by peastman's avatar peastman
Browse files

Continuing to convert layers

parent e9102d12
Loading
Loading
Loading
Loading
+50 −18
Original line number Diff line number Diff line
@@ -87,8 +87,9 @@ class Layer(object):
      return self.clone(in_layers)
    raise ValueError('%s does not implement shared()' % self.__class__.__name__)

  def __call__(self, *in_layers, training=False):
    return self.create_tensor(in_layers=in_layers, set_tensors=False, training=training)
  def __call__(self, *in_layers, training=False, **kwargs):
    return self.create_tensor(
        in_layers=in_layers, set_tensors=False, training=training, **kwargs)

  @property
  def shape(self):
@@ -576,8 +577,7 @@ class Dense(SharedVariableScope):
        kernel_initializer=self.weights_initializer(),
        bias_initializer=biases_initializer,
        _scope=self._get_scope_name(),
      _reuse=reuse
    )
        _reuse=reuse)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
@@ -964,6 +964,14 @@ class GRU(Layer):
  This layer expects its input to be of shape (batch_size, sequence_length, ...).
  It consists of a set of independent sequences (one for each element in the batch),
  that are each propagated independently through the GRU.

  When this layer is called in eager execution mode, it behaves slightly differently.
  It returns two tensors: the output of the recurrent layer, and the final state
  of the recurrent cell.  In addition, you can specify the initial_state option
  to tell it what to use as the initial state of the recurrent cell.  If that
  option is omitted, it defaults to all zeros for the initial state:

  outputs, final_state = gru_layer(input, initial_state=state)
  """

  def __init__(self, n_hidden, batch_size, **kwargs):
@@ -979,6 +987,10 @@ class GRU(Layer):
    self.n_hidden = n_hidden
    self.batch_size = batch_size
    super(GRU, self).__init__(**kwargs)
    if tfe.in_eager_mode():
      self._cell = tf.contrib.rnn.GRUCell(n_hidden)
      self.variables = self._cell.variables
      self._zero_state = self._cell.zero_state(batch_size, tf.float32)
    try:
      parent_shape = self.in_layers[0].shape
      self._shape = (batch_size, parent_shape[1], n_hidden)
@@ -990,10 +1002,16 @@ class GRU(Layer):
    if len(inputs) != 1:
      raise ValueError("Must have one parent")
    parent_tensor = inputs[0]
    if tfe.in_eager_mode():
      gru_cell = self._cell
      zero_state = self._zero_state
    else:
      gru_cell = tf.contrib.rnn.GRUCell(self.n_hidden)
      zero_state = gru_cell.zero_state(self.batch_size, tf.float32)
    if set_tensors:
      initial_state = tf.placeholder(tf.float32, zero_state.get_shape())
    elif 'initial_state' in kwargs:
      initial_state = kwargs['initial_state']
    else:
      initial_state = zero_state
    out_tensor, final_state = tf.nn.dynamic_rnn(
@@ -1007,6 +1025,9 @@ class GRU(Layer):
      self.out_tensors = [
          self.out_tensor, initial_state, final_state, zero_state
      ]
    if tfe.in_eager_mode():
      return (out_tensor, final_state)
    else:
      return out_tensor

  def none_tensors(self):
@@ -1100,16 +1121,22 @@ class TimeSeriesDense(Layer):
  def __init__(self, out_channels, **kwargs):
    self.out_channels = out_channels
    super(TimeSeriesDense, self).__init__(**kwargs)
    if tfe.in_eager_mode():
      self._layer = self._build_layer()

  def _build_layer(self):
    return tf.layers.Dense(self.out_channels, activation=tf.nn.sigmoid)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    if len(inputs) != 1:
      raise ValueError("Must have one parent")
    parent_tensor = inputs[0]
    dense_fn = lambda x: tf.contrib.layers.fully_connected(
      x, num_outputs=self.out_channels,
      activation_fn=tf.nn.sigmoid)
    out_tensor = tf.map_fn(dense_fn, parent_tensor)
    if tfe.in_eager_mode():
      layer = self._layer
    else:
      layer = self._build_layer()
    out_tensor = tf.map_fn(layer, parent_tensor)
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor
@@ -1386,8 +1413,13 @@ class Variable(Layer):
    self.dtype = dtype
    self._shape = tuple(initial_value.shape)
    super(Variable, self).__init__(**kwargs)
    if tfe.in_eager_mode():
      self.variables = [tfe.Variable(self.initial_value, dtype=self.dtype)]

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    if tfe.in_eager_mode():
      out_tensor = self.variables[0]
    else:
      out_tensor = tf.Variable(self.initial_value, dtype=self.dtype)
    if set_tensors:
      self._record_variable_scope(self.name)
+210 −9
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import tensorflow.contrib.eager as tfe
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util


class TestLayersEager(test_util.TensorFlowTestCase):
  """
  Test that layers function in eager mode.
@@ -19,7 +20,8 @@ class TestLayersEager(test_util.TensorFlowTestCase):
        filters = 3
        kernel_size = 2
        batch_size = 10
        input = np.random.rand(batch_size, width, in_channels).astype(np.float32)
        input = np.random.rand(batch_size, width, in_channels).astype(
            np.float32)
        layer = layers.Conv1D(filters, kernel_size)
        result = layer(input)
        self.assertEqual(result.shape[0], batch_size)
@@ -155,6 +157,148 @@ class TestLayersEager(test_util.TensorFlowTestCase):
        result = layers.Gather()(input, indices)
        assert np.array_equal(result, [input[1], input[3]])

  def test_gru(self):
    """Test invoking GRU in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        batch_size = 10
        n_hidden = 7
        in_channels = 4
        n_steps = 6
        input = np.random.rand(batch_size, n_steps, in_channels).astype(
            np.float32)
        layer = layers.GRU(n_hidden, batch_size)
        result, state = layer(input)
        assert result.shape == (batch_size, n_steps, n_hidden)

        # Creating a second layer should produce different results, since it has
        # different random weights.

        layer2 = layers.GRU(n_hidden, batch_size)
        result2, state2 = layer2(input)
        assert not np.allclose(result, result2)

        # But evaluating the first layer again should produce the same result as before.

        result3, state3 = layer(input)
        assert np.allclose(result, result3)

        # But if we specify a different starting state, that should produce a
        # different result.

        result4, state4 = layer(input, initial_state=state3)
        assert not np.allclose(result, result4)

  def test_time_series_dense(self):
    """Test invoking TimeSeriesDense in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        in_dim = 2
        out_dim = 3
        n_steps = 6
        batch_size = 10
        input = np.random.rand(batch_size, n_steps, in_dim).astype(np.float32)
        layer = layers.TimeSeriesDense(out_dim)
        result = layer(input)
        assert result.shape == (batch_size, n_steps, out_dim)

        # Creating a second layer should produce different results, since it has
        # different random weights.

        layer2 = layers.TimeSeriesDense(out_dim)
        result2 = layer2(input)
        assert not np.allclose(result, result2)

        # But evaluating the first layer again should produce the same result as before.

        result3 = layer(input)
        assert np.allclose(result, result3)

  def test_l1_loss(self):
    """Test invoking L1Loss in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        input1 = np.random.rand(5, 10).astype(np.float32)
        input2 = np.random.rand(5, 10).astype(np.float32)
        result = layers.L1Loss()(input1, input2)
        expected = np.mean(np.abs(input1 - input2), axis=1)
        assert np.allclose(result, expected)

  def test_l2_loss(self):
    """Test invoking L2Loss in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        input1 = np.random.rand(5, 10).astype(np.float32)
        input2 = np.random.rand(5, 10).astype(np.float32)
        result = layers.L2Loss()(input1, input2)
        expected = np.mean((input1 - input2)**2, axis=1)
        assert np.allclose(result, expected)

  def test_softmax(self):
    """Test invoking SoftMax in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        input = np.random.rand(5, 10).astype(np.float32)
        result = layers.SoftMax()(input)
        expected = tf.nn.softmax(input)
        assert np.allclose(result, expected)

  def test_sigmoid(self):
    """Test invoking Sigmoid in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        input = np.random.rand(5, 10).astype(np.float32)
        result = layers.Sigmoid()(input)
        expected = tf.nn.sigmoid(input)
        assert np.allclose(result, expected)

  def test_relu(self):
    """Test invoking ReLU in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        input = np.random.normal(size=(5, 10)).astype(np.float32)
        result = layers.ReLU()(input)
        expected = tf.nn.relu(input)
        assert np.allclose(result, expected)

  def test_concat(self):
    """Test invoking Concat in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        input1 = np.random.rand(5, 10).astype(np.float32)
        input2 = np.random.rand(5, 4).astype(np.float32)
        result = layers.Concat()(input1, input2)
        assert result.shape == (5, 14)
        assert np.array_equal(input1, result[:, :10])
        assert np.array_equal(input2, result[:, 10:])

  def test_stack(self):
    """Test invoking Stack in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        input1 = np.random.rand(5, 4).astype(np.float32)
        input2 = np.random.rand(5, 4).astype(np.float32)
        result = layers.Stack()(input1, input2)
        assert result.shape == (5, 2, 4)
        assert np.array_equal(input1, result[:, 0, :])
        assert np.array_equal(input2, result[:, 1, :])

  def test_constant(self):
    """Test invoking Constant in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        value = np.random.rand(5, 4).astype(np.float32)
        result = layers.Constant(value)()
        assert np.array_equal(result, value)

  def test_variable(self):
    """Test invoking Variable in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        value = np.random.rand(5, 4).astype(np.float32)
        result = layers.Variable(value)()
        assert np.array_equal(result.numpy(), value)

  def test_add(self):
    """Test invoking Add in eager mode."""
    with context.eager_mode():
@@ -189,3 +333,60 @@ class TestLayersEager(test_util.TensorFlowTestCase):
      with tfe.IsolateTest():
        result = layers.Exp()(2.5)
        assert np.allclose(result, np.exp(2.5))

  def test_interatomic_l2_distances(self):
    """Test invoking InteratomicL2Distances in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        atoms = 5
        neighbors = 2
        coords = np.random.rand(atoms, 3)
        neighbor_list = np.random.randint(0, atoms, size=(atoms, neighbors))
        layer = layers.InteratomicL2Distances(atoms, neighbors, 3)
        result = layer(coords, neighbor_list)
        assert result.shape == (atoms, neighbors)
        for atom in range(atoms):
          for neighbor in range(neighbors):
            delta = coords[atom] - coords[neighbor_list[atom, neighbor]]
            dist2 = np.dot(delta, delta)
            assert np.allclose(dist2, result[atom, neighbor])

  def test_sparse_softmax_cross_entropy(self):
    """Test invoking SparseSoftMaxCrossEntropy in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        batch_size = 10
        n_features = 5
        logits = np.random.rand(batch_size, n_features).astype(np.float32)
        labels = np.random.rand(batch_size).astype(np.int32)
        result = layers.SparseSoftMaxCrossEntropy()(labels, logits)
        expected = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        assert np.allclose(result, expected)

  def test_softmax_cross_entropy(self):
    """Test invoking SoftMaxCrossEntropy in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        batch_size = 10
        n_features = 5
        logits = np.random.rand(batch_size, n_features).astype(np.float32)
        labels = np.random.rand(batch_size, n_features).astype(np.float32)
        result = layers.SoftMaxCrossEntropy()(labels, logits)
        expected = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=labels, logits=logits)
        assert np.allclose(result, expected)

  def test_sigmoid_cross_entropy(self):
    """Test invoking SigmoidCrossEntropy in eager mode."""
    with context.eager_mode():
      with tfe.IsolateTest():
        batch_size = 10
        n_features = 5
        logits = np.random.rand(batch_size, n_features).astype(np.float32)
        labels = np.random.randint(0, 2, (batch_size, n_features)).astype(
            np.float32)
        result = layers.SigmoidCrossEntropy()(labels, logits)
        expected = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels, logits=logits)
        assert np.allclose(result, expected)