Commit ec5364b6 authored by peastman's avatar peastman
Browse files

Fixed test case

parent 6f20ecfe
Loading
Loading
Loading
Loading
+8 −5
Original line number Diff line number Diff line
@@ -540,11 +540,14 @@ class TestTensorGraph(unittest.TestCase):
    """Test computing a saliency map."""
    n_tasks = 3
    n_features = 5
    model = dc.models.MultitaskRegressor(
        n_tasks,
        n_features, [20],
        activation_fns=tf.tanh,
        weight_init_stddevs=1.0)
    features = Feature(shape=(None, n_features))
    dense = Dense(
        out_channels=n_tasks, in_layers=[features], activation_fn=tf.tanh)
    label = Label(shape=(None, n_tasks))
    loss = ReduceSquareDifference(in_layers=[dense, label])
    model = dc.models.TensorGraph()
    model.add_output(dense)
    model.set_loss(loss)
    x = np.random.random(n_features)
    s = model.compute_saliency(x)
    assert s.shape[0] == n_tasks