Commit 6f6a6741 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Fixing reshape handling

parent bd29cbcc
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ def test_compute_model_performance_multitask_classifier():
  y = np.stack([y1, y2], axis=1)
  dataset = NumpyDataset(X, y)

  features = layers.Input(shape=(n_data_points // 2, n_features))
  features = layers.Input(shape=(n_features))
  dense = layers.Dense(n_tasks * n_classes)(features)
  logits = layers.Reshape((n_tasks, n_classes))(dense)
  output = layers.Softmax()(logits)