Commit dad2e84c authored by miaecle's avatar miaecle
Browse files

fix test failure

parent 0a718810
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -116,6 +116,7 @@ class ANIRegression(TensorGraph):
               n_tasks,
               max_atoms,
               exp_loss=False,
               activation_fn='ani',
               layer_structures=[128, 64],
               atom_number_cases=[1, 6, 7, 8, 16],
               dropout_prob=0.,
@@ -134,6 +135,7 @@ class ANIRegression(TensorGraph):
    self.n_tasks = n_tasks
    self.max_atoms = max_atoms
    self.exp_loss = exp_loss
    self.activation_fn = activation_fn
    self.layer_structures = layer_structures
    self.atom_number_cases = atom_number_cases
    self.dropout_prob = dropout_prob
@@ -320,7 +322,7 @@ class ANIRegression(TensorGraph):
          self.max_atoms,
          n_hidden,
          self.atom_number_cases,
          activation='ani',
          activation=self.activation_fn,
          in_layers=[previous_layer, self.atom_numbers])
      dropout = Dropout(self.dropout_prob, in_layers=[Hidden])
      Hiddens.append(dropout)
+2 −1
Original line number Diff line number Diff line
@@ -43,7 +43,8 @@ class TestANIRegression(unittest.TestCase):
        "learning_rate": 0.001,
        "use_queue": False,
        "mode": "regression",
        "model_dir": self.model_dir
        "model_dir": self.model_dir,
        "activation_fn": "relu"
    }

    model = ANIRegression(**self.kwargs)