Commit 97d7f88d authored by mufeili's avatar mufeili Committed by Ubuntu
Browse files

Update

parent 5dbf3b0a
Loading
Loading
Loading
Loading
+56 −51
Original line number Diff line number Diff line
@@ -9,15 +9,16 @@ from deepchem.models import GATModel
from deepchem.models.tests.test_graph_models import get_dataset

try:
  import torch  # noqa
  import torch_geometric  # noqa
  has_pytorch_and_pyg = True
  import dgl
  import dgllife
  import torch
  has_torch_and_dgl = True
except:
  has_pytorch_and_pyg = False
  has_torch_and_dgl = False


@unittest.skipIf(not has_pytorch_and_pyg,
                 'PyTorch and PyTorch Geometric are not installed')
@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gat_regression():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
@@ -26,17 +27,20 @@ def test_gat_regression():

  # initialize models
  n_tasks = len(tasks)
  model = GATModel(mode='regression', n_tasks=n_tasks, batch_size=10)
  model = GATModel(
      mode='regression',
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10)

  # overfit test
  # GAT's convergence is a little slow
  model.fit(dataset, nb_epoch=300)
  model.fit(dataset, nb_epoch=100)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.75
  assert scores['mean_absolute_error'] < 0.5


@unittest.skipIf(not has_pytorch_and_pyg,
                 'PyTorch and PyTorch Geometric are not installed')
@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gat_classification():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
@@ -48,18 +52,17 @@ def test_gat_classification():
  model = GATModel(
      mode='classification',
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10,
      learning_rate=0.001)

  # overfit test
  # GAT's convergence is a little slow
  model.fit(dataset, nb_epoch=150)
  model.fit(dataset, nb_epoch=50)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.70

  assert scores['mean-roc_auc_score'] >= 0.85

@unittest.skipIf(not has_pytorch_and_pyg,
                 'PyTorch and PyTorch Geometric are not installed')
@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gat_reload():
    # load datasets
    featurizer = MolGraphConvFeaturizer()
@@ -72,17 +75,19 @@ def test_gat_reload():
    model = GATModel(
        mode='classification',
        n_tasks=n_tasks,
        number_atom_features=30,
        model_dir=model_dir,
        batch_size=10,
        learning_rate=0.001)

  model.fit(dataset, nb_epoch=150)
    model.fit(dataset, nb_epoch=50)
    scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.70
    assert scores['mean-roc_auc_score'] >= 0.85

    reloaded_model = GATModel(
        mode='classification',
        n_tasks=n_tasks,
        number_atom_features=30,
        model_dir=model_dir,
        batch_size=10,
        learning_rate=0.001)
+324 −209

File changed.

Preview size limit exceeded, changes collapsed.

+1 −1
Original line number Diff line number Diff line
@@ -302,6 +302,7 @@ class GCNModel(TorchModel):
            This can include any keyword argument of TorchModel.
        """
    model = GCN(
        n_tasks=n_tasks,
        graph_conv_layers=graph_conv_layers,
        activation=activation,
        residual=residual,
@@ -309,7 +310,6 @@ class GCNModel(TorchModel):
        dropout=dropout,
        predictor_hidden_feats=predictor_hidden_feats,
        predictor_dropout=predictor_dropout,
        n_tasks=n_tasks,
        mode=mode,
        number_atom_features=number_atom_features,
        n_classes=n_classes,