Commit 15eaed00 authored by mufeili's avatar mufeili
Browse files

Update

parent 0b7c0c64
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -150,8 +150,8 @@ class GraphData:
      src = np.concatenate([src, np.arange(self.num_nodes)])
      dst = np.concatenate([dst, np.arange(self.num_nodes)])

    g = dgl.graph((torch.from_numpy(src).long(),
                   torch.from_numpy(dst).long()),
    g = dgl.graph(
        (torch.from_numpy(src).long(), torch.from_numpy(dst).long()),
        num_nodes=self.num_nodes)
    g.ndata['x'] = torch.from_numpy(self.node_features).float()

+76 −69
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ try:
except:
  has_torch_and_dgl = False


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gcn_regression():
@@ -26,13 +27,18 @@ def test_gcn_regression():

  # initialize models
  n_tasks = len(tasks)
    model = GCNModel(mode='regression', n_tasks=n_tasks, number_atom_features=30, batch_size=10)
  model = GCNModel(
      mode='regression',
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10)

  # overfit test
  model.fit(dataset, nb_epoch=100)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gcn_classification():
@@ -55,6 +61,7 @@ def test_gcn_classification():
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gcn_reload():
+134 −125
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import torch.nn.functional as F
from deepchem.models.losses import Loss, L2Loss, SparseSoftmaxCrossEntropy
from deepchem.models.torch_models.torch_model import TorchModel


class GCN(nn.Module):
  """Model for Graph Property Prediction Based on Graph Convolution Networks (GCN).

@@ -65,6 +66,7 @@ class GCN(nn.Module):
    * There are various minor differences in using dropout, skip connection and batch
      normalization.
    """

  def __init__(self,
               n_tasks: int,
               graph_conv_layers: list = None,
@@ -143,7 +145,8 @@ class GCN(nn.Module):
    if activation is not None:
      activation = [activation] * num_gnn_layers

        self.model = DGLGCNPredictor(in_feats=number_atom_features,
    self.model = DGLGCNPredictor(
        in_feats=number_atom_features,
        hidden_feats=graph_conv_layers,
        activation=activation,
        residual=[residual] * num_gnn_layers,
@@ -192,6 +195,7 @@ class GCN(nn.Module):
    else:
      return out


class GCNModel(TorchModel):
  """Model for Graph Property Prediction Based on Graph Convolution Networks (GCN).

@@ -243,6 +247,7 @@ class GCNModel(TorchModel):
    * There are various minor differences in using dropout, skip connection and batch
      normalization.
    """

  def __init__(self,
               n_tasks: int,
               graph_conv_layers: list = None,
@@ -296,7 +301,8 @@ class GCNModel(TorchModel):
        kwargs
            This can include any keyword argument of TorchModel.
        """
        model = GCN(graph_conv_layers=graph_conv_layers,
    model = GCN(
        graph_conv_layers=graph_conv_layers,
        activation=activation,
        residual=residual,
        batchnorm=batchnorm,
@@ -345,7 +351,10 @@ class GCNModel(TorchModel):
      raise ImportError('This class requires dgl.')

    inputs, labels, weights = batch
        dgl_graphs = [graph.to_dgl_graph(self_loop=self._self_loop) for graph in inputs[0]]
    dgl_graphs = [
        graph.to_dgl_graph(self_loop=self._self_loop) for graph in inputs[0]
    ]
    inputs = dgl.batch(dgl_graphs).to(self.device)
        _, labels, weights = super(GCNModel, self)._prepare_batch(([], labels, weights))
    _, labels, weights = super(GCNModel, self)._prepare_batch(([], labels,
                                                               weights))
    return inputs, labels, weights