Commit f9bb6081 authored by Ubuntu's avatar Ubuntu
Browse files

Update

parent 97d7f88d
Loading
Loading
Loading
Loading
+36 −35
Original line number Diff line number Diff line
@@ -61,6 +61,7 @@ def test_gat_classification():
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gat_reload():
+4 −2
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import torch.nn.functional as F
from deepchem.models.losses import Loss, L2Loss, SparseSoftmaxCrossEntropy
from deepchem.models.torch_models.torch_model import TorchModel


class GAT(nn.Module):
  """Model for Graph Property Prediction Based on Graph Attention Networks (GAT).

@@ -50,6 +51,7 @@ class GAT(nn.Module):
    This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci
    (https://github.com/awslabs/dgl-lifesci) to be installed.
    """

  def __init__(self,
               n_tasks: int,
               graph_attention_layers: list = None,
@@ -168,8 +170,7 @@ class GAT(nn.Module):
        activations=activation,
        n_tasks=out_size,
        predictor_hidden_feats=predictor_hidden_feats,
        predictor_dropout=predictor_dropout
    )
        predictor_dropout=predictor_dropout)

  def forward(self, g):
    """Predict graph labels
@@ -247,6 +248,7 @@ class GATModel(TorchModel):
    This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci
    (https://github.com/awslabs/dgl-lifesci) to be installed.
    """

  def __init__(self,
               n_tasks: int,
               graph_attention_layers: list = None,