Commit af9644c7 authored by mufeili's avatar mufeili
Browse files

Update

parent 6c7dfac3
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -52,7 +52,7 @@ def test_attentivefp_classification():
      learning_rate=0.001)

  # overfit test
  model.fit(dataset, nb_epoch=70)
  model.fit(dataset, nb_epoch=100)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

+5 −3
Original line number Diff line number Diff line
@@ -31,10 +31,12 @@ def test_gat_regression():
      mode='regression',
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10)
      batch_size=10,
      learning_rate=0.001
  )

  # overfit test
  model.fit(dataset, nb_epoch=150)
  model.fit(dataset, nb_epoch=300)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5

@@ -81,7 +83,7 @@ def test_gat_reload():
      batch_size=10,
      learning_rate=0.001)

  model.fit(dataset, nb_epoch=60)
  model.fit(dataset, nb_epoch=80)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

+4 −4
Original line number Diff line number Diff line
@@ -32,7 +32,7 @@ def test_gcn_regression():
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10,
      learning_rate=0.02)
      learning_rate=0.003)

  # overfit test
  model.fit(dataset, nb_epoch=150)
@@ -55,10 +55,10 @@ def test_gcn_classification():
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10,
      learning_rate=0.001)
      learning_rate=0.0003)

  # overfit test
  model.fit(dataset, nb_epoch=50)
  model.fit(dataset, nb_epoch=70)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

@@ -92,7 +92,7 @@ def test_gcn_reload():
      number_atom_features=30,
      model_dir=model_dir,
      batch_size=10,
      learning_rate=0.001)
      learning_rate=0.0003)
  reloaded_model.restore()

  pred_mols = ["CCCC", "CCCCCO", "CCCCC"]