Commit 885d0d5f authored by peastman's avatar peastman
Browse files

Improved test cases

parent dfdcc270
Loading
Loading
Loading
Loading
+6 −6
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ class TestEstimators(unittest.TestCase):
    """Test creating an Estimator from a MultiTaskClassifier."""
    n_samples = 10
    n_features = 3
    n_tasks = 1
    n_tasks = 2

    # Create a dataset and an input function for processing it.

@@ -35,7 +35,7 @@ class TestEstimators(unittest.TestCase):
    # Create an estimator from it.

    x_col = tf.feature_column.numeric_column('x', shape=(n_features,))
    weight_col = tf.feature_column.numeric_column('weights')
    weight_col = tf.feature_column.numeric_column('weights', shape=(n_tasks,))

    def accuracy(labels, predictions, weights):
      return tf.metrics.accuracy(labels, tf.round(predictions), weights)
@@ -46,7 +46,7 @@ class TestEstimators(unittest.TestCase):

    # Train the model.

    estimator.train(input_fn=lambda: input_fn(100), steps=100)
    estimator.train(input_fn=lambda: input_fn(100))

    # Evaluate the model.

@@ -58,7 +58,7 @@ class TestEstimators(unittest.TestCase):
    """Test creating an Estimator from a MultiTaskRegressor."""
    n_samples = 10
    n_features = 3
    n_tasks = 1
    n_tasks = 2

    # Create a dataset and an input function for processing it.

@@ -79,14 +79,14 @@ class TestEstimators(unittest.TestCase):
    # Create an estimator from it.

    x_col = tf.feature_column.numeric_column('x', shape=(n_features,))
    weight_col = tf.feature_column.numeric_column('weights')
    weight_col = tf.feature_column.numeric_column('weights', shape=(n_tasks,))
    metrics = {'error': tf.metrics.mean_absolute_error}
    estimator = model.make_estimator(
        feature_columns=[x_col], weight_column=weight_col, metrics=metrics)

    # Train the model.

    estimator.train(input_fn=lambda: input_fn(100), steps=100)
    estimator.train(input_fn=lambda: input_fn(100))

    # Evaluate the model.