Commit 4b262d78 authored by Carlos Hernandez's avatar Carlos Hernandez
Browse files

add test

parent d03c6339
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -217,6 +217,12 @@ class TestAPI(unittest.TestCase):

    model = dc.models.TensorGraphMultiTaskClassifier(len(tasks), n_features)

    # Test Parameter getting and setting
    param, value = 'weight_decay_penalty_type', 'l2'
    assert model.get_params()[param] is None
    model.set_params(**{param: value})
    assert model.get_params()[param] == value

    # Fit trained model
    model.fit(train_dataset)
    model.save()