Commit be07198b authored by Yutong Zhao's avatar Yutong Zhao
Browse files

Fix numerical problems with gradient by adding a very small delta

parent ef5bea10
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -3294,9 +3294,12 @@ class ANIFeat(Layer):
    tensor2 = tf.stack([coordinates] * max_atoms, axis=2)

    # Calculate pairwise distance
    # d = tf.sqrt(
        # tf.nn.relu(
            # tf.reduce_sum(tf.squared_difference(tensor1, tensor2), axis=3)))
    d = tf.sqrt(
        tf.nn.relu(
            tf.reduce_sum(tf.squared_difference(tensor1, tensor2), axis=3)))
            tf.reduce_sum(tf.squared_difference(tensor1, tensor2), axis=3)+1e-7)
    # d = tf.reduce_sum(tf.squared_difference(tensor1, tensor2), axis=3)
    # Masking for valid atom index
    d = d * flags
    return d
+19 −9
Original line number Diff line number Diff line
@@ -13,14 +13,24 @@ class TestANIRegression(unittest.TestCase):

  def setUp(self):

    max_atoms = 3

    X = np.array([[1, 5.0, 3.2, 1.1], [6, 1.0, 3.4, -1.1], [1, 2.3, 3.4, 2.2]])

    X = X.reshape((1, X.shape[0], X.shape[1]))
    max_atoms = 4

    X = np.array([
      [
        [1, 5.0, 3.2, 1.1],
        [6, 1.0, 3.4, -1.1],
        [1, 2.3, 3.4, 2.2],
        [0, 0,   0,   0],
      ],
      [
        [8, 2.0, -1.4, -1.1],
        [7, 6.3, 2.4, 3.2],
        [0, 0,   0,   0],
        [0, 0,   0,   0],
      ]
      ])

    y = np.array([2.0])
    y = y.reshape((1, 1))
    y = np.array([2.0, 1.1])

    layer_structures = [128, 128, 64]
    atom_number_cases = [1, 6, 7, 8]
@@ -32,7 +42,7 @@ class TestANIRegression(unittest.TestCase):
        "max_atoms": max_atoms,
        "layer_structures": layer_structures,
        "atom_number_cases": atom_number_cases,
        "batch_size": 1,
        "batch_size": 2,
        "learning_rate": 0.001,
        "use_queue": False,
        "mode": "regression",
+4 −4
Original line number Diff line number Diff line
@@ -60,10 +60,10 @@ def load_roiterberg_ANI(mode="atomization"):
  hdf5files = [
      'ani_gdb_s01.h5',
      'ani_gdb_s02.h5',
      'ani_gdb_s03.h5',
      'ani_gdb_s04.h5',
      'ani_gdb_s05.h5',
      'ani_gdb_s06.h5',
      # 'ani_gdb_s03.h5',
      # 'ani_gdb_s04.h5',
      # 'ani_gdb_s05.h5',
      # 'ani_gdb_s06.h5',
      # 'ani_gdb_s07.h5',
      # 'ani_gdb_s08.h5'
  ]