Commit aa9e28c9 authored by mufeili's avatar mufeili
Browse files

Update

parent bdc73970
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ def test_gcn_regression():
      learning_rate=0.003)

  # overfit test
  model.fit(dataset, nb_epoch=200)
  model.fit(dataset, nb_epoch=300)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5

+2 −2
Original line number Diff line number Diff line
@@ -30,7 +30,7 @@ def test_mpnn_regression():
  model = MPNNModel(mode='regression', n_tasks=n_tasks, batch_size=10)

  # overfit test
  model.fit(dataset, nb_epoch=100)
  model.fit(dataset, nb_epoch=200)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5

@@ -52,7 +52,7 @@ def test_mpnn_classification():
      learning_rate=0.001)

  # overfit test
  model.fit(dataset, nb_epoch=100)
  model.fit(dataset, nb_epoch=200)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

+2 −2
Original line number Diff line number Diff line
@@ -306,6 +306,6 @@ class MPNNModel(TorchModel):
        graph.to_dgl_graph(self_loop=self._self_loop) for graph in inputs[0]
    ]
    inputs = dgl.batch(dgl_graphs).to(self.device)
    _, labels, weights = super(MPNNModel, self)._prepare_batch(
        ([], labels, weights))
    _, labels, weights = super(MPNNModel, self)._prepare_batch(([], labels,
                                                                weights))
    return inputs, labels, weights