Commit 7f6cf0e1 authored by ZHENQIN WU's avatar ZHENQIN WU
Browse files

classifier test overfit

parent bcde3826
Loading
Loading
Loading
Loading
+30 −0
Original line number Diff line number Diff line
@@ -392,6 +392,36 @@ class TestOverfit(test_util.TensorFlowTestCase):
    scores = model.evaluate(dataset, [classification_metric])
    assert scores[classification_metric.name] > .9

  def test_IRV_multitask_classification_overfit(self):
    """Test IRV classifier overfits tiny data."""
    n_tasks = 5
    n_samples = 10
    n_features = 128
    n_classes = 2
    
    # Generate dummy dataset
    np.random.seed(123)
    ids = np.arange(n_samples)
    X = np.random.randint(2, size=(n_samples, n_features))
    y = np.zeros((n_samples, n_tasks))
    w = np.ones((n_samples, n_tasks))
    dataset = dc.data.NumpyDataset(X, y, w, ids)
    IRV_transformer = dc.trans.IRVTransformer(5, n_tasks, dataset)
    dataset_trans = IRV_transformer.transform(dataset)
    classification_metric = dc.metrics.Metric(
        dc.metrics.accuracy_score, task_averager=np.mean)
    model = dc.models.TensorflowMultiTaskIRVClassifier(
        n_tasks, K=5, learning_rate=0.001, batch_size=n_samples)

    # Fit trained model
    model.fit(dataset_trans)
    model.save()

    # Eval model on train
    scores = model.evaluate(dataset_trans, [classification_metric])
    assert scores[classification_metric.name] > .9


  def test_sklearn_multitask_regression_overfit(self):
    """Test SKLearn singletask-to-multitask overfits tiny regression data."""
    n_tasks = 2
+2 −1
Original line number Diff line number Diff line
@@ -440,7 +440,8 @@ class TestTransformers(unittest.TestCase):
    sims = sorted(sims, reverse=True)
    IRV_transformer = dc.trans.IRVTransformer(10, n_tasks, dataset)
    test_dataset_trans = IRV_transformer.transform(test_dataset)
    dataset_trans = IRV_transformer.transform(dataset)
    assert test_dataset_trans.X.shape == (test_samples, 20*n_tasks)
    assert np.allclose(test_dataset_trans.X[0,:10], sims[:10])
    assert np.allclose(test_dataset_trans.X[0,10:20], [0]*10)
    assert not np.isclose(dataset_trans.X[0,0], 1.)