Unverified Commit a8ead748 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2508 from VIGNESHinZONE/pagtn

Wrapper function for Pagtn Model
parents 6ab2a69f c4ef2415
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -230,10 +230,10 @@ class PagtnMolGraphFeaturizer(MolecularFeaturizer):

  The featurization is based on `PAGTN model <https://arxiv.org/abs/1905.12712>`_. It is
  slightly more computationally intensive than default Graph Convolution Featuriser, but it
  builds a Molecular Graph connecting all tom pairs accounting for interactions of atom with
  builds a Molecular Graph connecting all atom pairs accounting for interactions of an atom with
  every other atom in the Molecule. According to the paper, interactions between two pairs
  of an atom are dependent on the relative distance between them and calculating the shortest
  path between them.
  of atom are dependent on the relative distance between them and and hence, the function needs
  to calculate the shortest path between them.

  The default node representation is constructed by concatenating the following values,
  and the feature length is 94.
@@ -247,9 +247,9 @@ class PagtnMolGraphFeaturizer(MolecularFeaturizer):
    include ``0 - 5``.
  - Aromaticity: Boolean representing if an atom is aromatic.

  The default edge representation are constructed by concatenating the following values,
  The default edge representation is constructed by concatenating the following values,
  and the feature length is 42. It builds a complete graph where each node is connected to
  every other node. The edge representations are calculated the shortest path between two nodes
  every other node. The edge representations are calculated based on the shortest path between two nodes
  (choose any one if multiple exist). Each bond encountered in the shortest path is used to
  calculate edge features.

+1 −0
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ try:
  from deepchem.models.torch_models import GAT, GATModel
  from deepchem.models.torch_models import GCN, GCNModel
  from deepchem.models.torch_models import LCNN, LCNNModel
  from deepchem.models.torch_models import Pagtn, PagtnModel
except ModuleNotFoundError:
  pass

+1 −2
Original line number Diff line number Diff line
@@ -112,8 +112,7 @@ class SquaredHingeLoss(Loss):
      output, labels = _make_pytorch_shapes_consistent(output, labels)
      return torch.mean(
          torch.pow(
              torch.maximum(1 - torch.multiply(labels, output),
                            torch.tensor(0)), 2),
              torch.max(1 - torch.mul(labels, output), torch.tensor(0.0)), 2),
          dim=-1)

    return loss
+106 −0
Original line number Diff line number Diff line
import unittest
import tempfile

import numpy as np

import deepchem as dc
from deepchem.feat import PagtnMolGraphFeaturizer
from deepchem.models import PagtnModel
from deepchem.models.tests.test_graph_models import get_dataset

try:
  import dgl
  import dgllife
  import torch
  has_torch_and_dgl = True
except:
  has_torch_and_dgl = False


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_pagtn_regression():
  # load datasets
  featurizer = PagtnMolGraphFeaturizer(max_length=5)
  tasks, dataset, transformers, metric = get_dataset(
      'regression', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
  model = PagtnModel(mode='regression', n_tasks=n_tasks, batch_size=16)

  # overfit test
  model.fit(dataset, nb_epoch=150)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.65

  # test on a small MoleculeNet dataset
  from deepchem.molnet import load_delaney

  tasks, all_dataset, transformers = load_delaney(featurizer=featurizer)
  train_set, _, _ = all_dataset
  model = PagtnModel(mode='regression', n_tasks=n_tasks, batch_size=16)
  model.fit(train_set, nb_epoch=1)


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_pagtn_classification():
  # load datasets
  featurizer = PagtnMolGraphFeaturizer(max_length=5)
  tasks, dataset, transformers, metric = get_dataset(
      'classification', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
  model = PagtnModel(mode='classification', n_tasks=n_tasks, batch_size=16)

  # overfit test
  model.fit(dataset, nb_epoch=100)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

  # test on a small MoleculeNet dataset
  from deepchem.molnet import load_bace_classification

  tasks, all_dataset, transformers = load_bace_classification(
      featurizer=featurizer)
  train_set, _, _ = all_dataset
  model = PagtnModel(mode='classification', n_tasks=len(tasks), batch_size=16)
  model.fit(train_set, nb_epoch=1)


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_pagtn_reload():
  # load datasets
  featurizer = PagtnMolGraphFeaturizer(max_length=5)
  tasks, dataset, transformers, metric = get_dataset(
      'classification', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
  model_dir = tempfile.mkdtemp()
  model = PagtnModel(
      mode='classification',
      n_tasks=n_tasks,
      model_dir=model_dir,
      batch_size=16)

  model.fit(dataset, nb_epoch=100)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

  reloaded_model = PagtnModel(
      mode='classification',
      n_tasks=n_tasks,
      model_dir=model_dir,
      batch_size=16)
  reloaded_model.restore()

  pred_mols = ["CCCC", "CCCCCO", "CCCCC"]
  X_pred = featurizer(pred_mols)
  random_dataset = dc.data.NumpyDataset(X_pred)
  original_pred = model.predict(random_dataset)
  reload_pred = reloaded_model.predict(random_dataset)
  assert np.all(original_pred == reload_pred)
+2 −1
Original line number Diff line number Diff line
@@ -6,3 +6,4 @@ from deepchem.models.torch_models.gat import GAT, GATModel
from deepchem.models.torch_models.gcn import GCN, GCNModel
from deepchem.models.torch_models.mpnn import MPNN, MPNNModel
from deepchem.models.torch_models.lcnn import LCNN, LCNNModel
from deepchem.models.torch_models.pagtn import Pagtn, PagtnModel
 No newline at end of file
Loading