Commit b26de50a authored by nd-02110114's avatar nd-02110114
Browse files

add gat models

parent 3892a54a
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -166,7 +166,12 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):

    # Weave style
    # compute partial charges
    try:
      mol.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
      pass
    except:
      AllChem.ComputeGasteigerCharges(mol)

    dist_matrix = Chem.GetDistanceMatrix(mol)
    chiral_center = Chem.FindMolChiralCenters(mol)
    sssr = Chem.GetSymmSSSR(mol)
+1 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ from deepchem.models.chemnet_models import Smiles2Vec, ChemCeption
try:
  from deepchem.models.torch_models import TorchModel
  from deepchem.models.torch_models import CGCNN, CGCNNModel
  from deepchem.models.torch_models import GAT, GATModel
except ModuleNotFoundError:
  pass

+37 −0
Original line number Diff line number Diff line
import unittest

from deepchem.feat import MolGraphConvFeaturizer
from deepchem.models import GATModel, losses
from deepchem.models.tests.test_graph_models import get_dataset

try:
  import torch  # noqa
  import torch_geometric  # noqa
  has_pytorch_and_pyg = True
except:
  has_pytorch_and_pyg = False


@unittest.skipIf(not has_pytorch_and_pyg, 'PyTorch and PyTorch Geometric are not installed')
def test_gat_classification():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
  tasks, dataset, transformers, metric = get_dataset('regression', featurizer=featurizer)
  n_tasks = len(tasks)

  # initialize models
  model = GATModel(
      in_node_dim=25,
      hidden_node_dim=64,
      heads=1,
      num_conv=3,
      predicator_hidden_feats=32,
      n_tasks=n_tasks,
      loss=losses.L2Loss(),
      batch_size=10,
      learning_rate=0.001)

  # overfit test
  model.fit(dataset, nb_epoch=100)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.1
+1 −0
Original line number Diff line number Diff line
# flake8:noqa
from deepchem.models.torch_models.torch_model import TorchModel
from deepchem.models.torch_models.cgcnn import CGCNN, CGCNNModel
from deepchem.models.torch_models.gat import GAT, GATModel
+9 −6
Original line number Diff line number Diff line
"""
This is a sample implementation for working DGL with DeepChem!
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -161,6 +164,10 @@ class CGCNN(nn.Module):
    n_tasks: int, default 1
      Number of the output size, default to 1.
    """
    try:
      import dgl
    except:
      raise ValueError("This class requires DGL to be installed.")
    super(CGCNN, self).__init__()
    self.embedding = nn.Linear(in_node_dim, hidden_node_dim)
    self.conv_layers = nn.ModuleList([
@@ -169,6 +176,7 @@ class CGCNN(nn.Module):
            edge_dim=in_edge_dim,
            batch_norm=True) for _ in range(num_conv)
    ])
    self.pooling = dgl.mean_nodes
    self.fc = nn.Linear(hidden_node_dim, predicator_hidden_feats)
    self.out = nn.Linear(predicator_hidden_feats, n_tasks)

@@ -186,11 +194,6 @@ class CGCNN(nn.Module):
    out: torch.Tensor
      The output value, the shape is `(batch_size, n_tasks)`.
    """
    try:
      import dgl
    except:
      raise ValueError("This class requires DGL to be installed.")

    graph = dgl_graph
    # embedding node features
    graph.ndata['x'] = self.embedding(graph.ndata['x'])
@@ -200,7 +203,7 @@ class CGCNN(nn.Module):
      graph = conv(graph)

    # pooling
    graph_feat = dgl.mean_nodes(graph, 'x')
    graph_feat = self.pooling(graph, 'x')
    graph_feat = self.fc(graph_feat)
    out = self.out(graph_feat)
    return out
Loading