Commit 0566d2c9 authored by Yutong Zhao's avatar Yutong Zhao
Browse files

Change to central finite difference

parent 8623479b
Loading
Loading
Loading
Loading
+17 −11
Original line number Diff line number Diff line
@@ -35,8 +35,6 @@ class TestANIRegression(unittest.TestCase):
      use_queue=False,
      mode="regression")

    print(X.shape, y.shape)

    train_dataset = dc.data.NumpyDataset(X, y, n_tasks=1)

    model.fit(train_dataset, nb_epoch=2, checkpoint_interval=100)
@@ -49,18 +47,26 @@ class TestANIRegression(unittest.TestCase):

    new_atomic_nums = np.array([1,1,6])

    grad_approx = scipy.optimize.approx_fprime(
      new_x,
      model.pred_one,
      1e-4,
      new_atomic_nums)
    delta = 1e-2

    grad_exact = model.grad_one(new_x, new_atomic_nums)
    # use central difference since forward difference has a pretty high
    # approximation error

    grad_approx = []

    print(grad_approx)
    print(grad_exact)
    for idx in range(new_x.shape[0]):
      d_new_x_plus = np.array(new_x)
      d_new_x_plus[idx] += delta
      d_new_x_minus = np.array(new_x)
      d_new_x_minus[idx] -= delta      
      dydx = (model.pred_one(d_new_x_plus, new_atomic_nums)-model.pred_one(d_new_x_minus, new_atomic_nums))/(2*delta)
      grad_approx.append(dydx[0])

    grad_approx = np.array(grad_approx)

    grad_exact = model.grad_one(new_x, new_atomic_nums)

    np.testing.assert_array_almost_equal(grad_approx, grad_exact)
    np.testing.assert_array_almost_equal(grad_approx, grad_exact, decimal=3)

if __name__ == '__main__':
  unittest.main()
 No newline at end of file