Commit 2612ceb8 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #567 from peastman/reshape

Layers try to automatically reshape their inputs
parents 782feb6d c1bff4fc
Loading
Loading
Loading
Loading
+2 −5
Original line number Diff line number Diff line
@@ -98,9 +98,7 @@ class TensorGraphMultiTaskClassifier(TensorGraph):
    self.add_output(output)
    labels = Label(shape=(None, n_tasks, n_classes))
    weights = Weights(shape=(None, n_tasks))
    loss = Reshape(
        shape=(-1, n_tasks),
        in_layers=[SoftMaxCrossEntropy(in_layers=[labels, output])])
    loss = SoftMaxCrossEntropy(in_layers=[labels, output])
    weighted_loss = WeightedError(in_layers=[loss, weights])
    if weight_decay_penalty != 0.0:
      weighted_loss = WeightDecay(
@@ -211,8 +209,7 @@ class TensorGraphMultiTaskRegressor(TensorGraph):
    self.add_output(output)
    labels = Label(shape=(None, n_tasks, 1))
    weights = Weights(shape=(None, n_tasks))
    loss = Reshape(
        shape=(-1, n_tasks), in_layers=[L2Loss(in_layers=[labels, output])])
    loss = L2Loss(in_layers=[labels, output])
    weighted_loss = WeightedError(in_layers=[loss, weights])
    if weight_decay_penalty != 0.0:
      weighted_loss = WeightDecay(
+155 −194

File changed.

Preview size limit exceeded, changes collapsed.

+10 −0
Original line number Diff line number Diff line
@@ -511,3 +511,13 @@ class TestLayers(test_util.TensorFlowTestCase):
      sess.run(tf.global_variables_initializer())
      out_tensor = out_tensor.eval()
      assert out_tensor.shape == (batch_size, n_features)

  def test_reshape_inputs(self):
    """Test that layers can automatically reshape inconsistent inputs."""
    value1 = np.random.uniform(size=(2, 3)).astype(np.float32)
    value2 = np.random.uniform(size=(1, 6, 1)).astype(np.float32)
    with self.test_session() as sess:
      out_tensor = Add()(tf.constant(value1), tf.constant(value2))
      result = out_tensor.eval()
      assert result.shape == (1, 6, 1)
      assert np.array_equal(value1.reshape((1, 6, 1)) + value2, result)