Commit 5a17008c authored by peastman's avatar peastman
Browse files

Added test case for restore argument

parent 84ef3959
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -73,3 +73,10 @@ class TestA3C(unittest.TestCase):
    new_a3c.restore()
    action_prob2, value2 = new_a3c.predict([[0]])
    assert value2 == value

    # Do the same thing, only using the "restore" argument to fit().

    new_a3c = dc.rl.A3C(env, TestPolicy(), model_dir=a3c._graph.model_dir)
    new_a3c.fit(0, restore=True)
    action_prob2, value2 = new_a3c.predict([[0]])
    assert value2 == value