Unverified Commit 3d257a0c authored by Daiki Nishikawa's avatar Daiki Nishikawa Committed by GitHub
Browse files

Merge pull request #2109 from nd-02110114/gat-pyg-2

Implement sample GAT model for working PyG with DeepChem
parents 96de1d14 b7b56fa3
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
"""
Making it easy to import in classes.
"""
# flake8: noqa

# base classes for featurizers
from deepchem.feat.base_classes import Featurizer
from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.feat.base_classes import MaterialStructureFeaturizer
from deepchem.feat.base_classes import MaterialCompositionFeaturizer
from deepchem.feat.base_classes import ComplexFeaturizer
from deepchem.feat.base_classes import UserDefinedFeaturizer

from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.graph_features import WeaveFeaturizer
from deepchem.feat.fingerprints import CircularFingerprint
@@ -22,6 +26,11 @@ from deepchem.feat.atomic_coordinates import AtomicCoordinates
from deepchem.feat.atomic_coordinates import NeighborListComplexAtomicCoordinates
from deepchem.feat.adjacency_fingerprints import AdjacencyFingerprint
from deepchem.feat.smiles_featurizers import SmilesToSeq, SmilesToImage

# molecule featurizers
from deepchem.feat.molecule_featurizers import MolGraphConvFeaturizer

# material featurizers
from deepchem.feat.material_featurizers import ElementPropertyFingerprint
from deepchem.feat.material_featurizers import SineCoulombMatrix
from deepchem.feat.material_featurizers import CGCNNFeaturizer
+55 −0
Original line number Diff line number Diff line
"""
Feature calculations.
"""
import inspect
import logging
import numpy as np
import multiprocessing
@@ -75,6 +76,60 @@ class Featurizer(object):
    """
    raise NotImplementedError('Featurizer is not defined.')

  def __repr__(self) -> str:
    """Convert self to repr representation.

    Returns
    -------
    str
      The string represents the class.

    Examples
    --------
    >>> import deepchem as dc
    >>> dc.feat.CircularFingerprint(size=1024, radius=4)
    CircularFingerprint[radius=4, size=1024, chiral=False, bonds=True, features=False, sparse=False, smiles=False]
    >>> dc.feat.CGCNNFeaturizer()
    CGCNNFeaturizer[radius=8.0, max_neighbors=8, step=0.2]
    """
    args_spec = inspect.getfullargspec(self.__init__)  # type: ignore
    args_names = [arg for arg in args_spec.args if arg != 'self']
    args_info = ''
    for arg_name in args_names:
      args_info += arg_name + '=' + str(self.__dict__[arg_name]) + ', '
    return self.__class__.__name__ + '[' + args_info[:-2] + ']'

  def __str__(self) -> str:
    """Convert self to str representation.

    Returns
    -------
    str
      The string represents the class.

    Examples
    --------
    >>> import deepchem as dc
    >>> str(dc.feat.CircularFingerprint(size=1024, radius=4))
    'CircularFingerprint_radius_4_size_1024'
    >>> str(dc.feat.CGCNNFeaturizer())
    'CGCNNFeaturizer'
    """
    args_spec = inspect.getfullargspec(self.__init__)  # type: ignore
    args_names = [arg for arg in args_spec.args if arg != 'self']
    args_num = len(args_names)
    args_default_values = [None for _ in range(args_num)]
    if args_spec.defaults is not None:
      defaults = list(args_spec.defaults)
      args_default_values[-len(defaults):] = defaults

    override_args_info = ''
    for arg_name, default in zip(args_names, args_default_values):
      arg_value = self.__dict__[arg_name]
      if default != arg_value:
        override_args_info += '_' + arg_name + '_' + str(arg_value)
    return self.__class__.__name__ + override_args_info


class ComplexFeaturizer(object):
  """"
+2 −0
Original line number Diff line number Diff line
# flake8: noqa
from deepchem.feat.molecule_featurizers.mol_graph_conv_featurizer import MolGraphConvFeaturizer
+196 −0
Original line number Diff line number Diff line
from typing import List, Sequence, Tuple
import numpy as np

from deepchem.utils.typing import RDKitAtom, RDKitBond, RDKitMol
from deepchem.feat.graph_data import GraphData
from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.utils.molecule_feature_utils import get_atom_type_one_hot, \
  construct_hydrogen_bonding_info, get_atom_hydrogen_bonding_one_hot, \
  get_atom_is_in_aromatic_one_hot, get_atom_hybridization_one_hot, \
  get_atom_total_num_Hs_one_hot, get_atom_chirality_one_hot, get_atom_formal_charge, \
  get_atom_partial_charge, get_atom_ring_size_one_hot, get_atom_total_degree_one_hot, \
  get_bond_type_one_hot, get_bond_is_in_same_ring_one_hot, get_bond_is_conjugated_one_hot, \
  get_bond_stereo_one_hot


def _construct_atom_feature(atom: RDKitAtom,
                            h_bond_infos: List[Tuple[int, str]],
                            sssr: List[Sequence]) -> List[float]:
  """Construct an atom feature from a RDKit atom object.

  Parameters
  ----------
  atom: rdkit.Chem.rdchem.Atom
    RDKit atom object
  h_bond_infos: List[Tuple[int, str]]
    A list of tuple `(atom_index, hydrogen_bonding_type)`.
    Basically, it is expected that this value is the return value of
    `construct_hydrogen_bonding_info`. The `hydrogen_bonding_type`
    value is "Acceptor" or "Donor".
  sssr: List[Sequence]
    The return value of `Chem.GetSymmSSSR(mol)`.
    The value is a sequence of rings.

  Returns
  -------
  List[float]
    A one-hot vector of the atom feature.
  """
  atom_type = get_atom_type_one_hot(atom)
  chirality = get_atom_chirality_one_hot(atom)
  formal_charge = get_atom_formal_charge(atom)
  partial_charge = get_atom_partial_charge(atom)
  ring_size = get_atom_ring_size_one_hot(atom, sssr)
  hybridization = get_atom_hybridization_one_hot(atom)
  acceptor_donor = get_atom_hydrogen_bonding_one_hot(atom, h_bond_infos)
  aromatic = get_atom_is_in_aromatic_one_hot(atom)
  degree = get_atom_total_degree_one_hot(atom)
  total_num = get_atom_total_num_Hs_one_hot(atom)
  return atom_type + chirality + formal_charge + partial_charge + \
    ring_size + hybridization + acceptor_donor + aromatic + degree + total_num


def _construct_bond_feature(bond: RDKitBond) -> List[float]:
  """Construct a bond feature from a RDKit bond object.

  Parameters
  ---------
  bond: rdkit.Chem.rdchem.Bond
    RDKit bond object

  Returns
  -------
  List[float]
    A one-hot vector of the bond feature.
  """
  bond_type = get_bond_type_one_hot(bond)
  same_ring = get_bond_is_in_same_ring_one_hot(bond)
  conjugated = get_bond_is_conjugated_one_hot(bond)
  stereo = get_bond_stereo_one_hot(bond)
  return bond_type + same_ring + conjugated + stereo


class MolGraphConvFeaturizer(MolecularFeaturizer):
  """This class is a featurizer of general graph convolution networks for molecules.

  The default node(atom) and edge(bond) representations are based on
  `WeaveNet paper <https://arxiv.org/abs/1603.00856>`_. If you want to use your own representations,
  you could use this class as a guide to define your original Featurizer. In many cases, it's enough
  to modify return values of `construct_atom_feature` or `construct_bond_feature`.

  The default node representation are constructed by concatenating the following values,
  and the feature length is 39.

  - Atom type: A one-hot vector of this atom, "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "other atoms".
  - Chirality: A one-hot vector of the chirality, "R" or "S".
  - Formal charge: Integer electronic charge.
  - Partial charge: Calculated partial charge.
  - Ring sizes: A one-hot vector of the size (3-8) of rings that include this atom.
  - Hybridization: A one-hot vector of "sp", "sp2", "sp3".
  - Hydrogen bonding: A one-hot vector of whether this atom is a hydrogen bond donor or acceptor.
  - Aromatic: A one-hot vector of whether the atom belongs to an aromatic ring.
  - Degree: A one-hot vector of the degree (0-5) of this atom.
  - Number of Hydrogens: A one-hot vector of the number of hydrogens (0-4) that this atom connected.

  The default edge representation are constructed by concatenating the following values,
  and the feature length is 11.

  - Bond type: A one-hot vector of the bond type, "single", "double", "triple", or "aromatic".
  - Same ring: A one-hot vector of whether the atoms in the pair are in the same ring.
  - Conjugated: A one-hot vector of whether this bond is conjugated or not.
  - Stereo: A one-hot vector of the stereo configuration of a bond.

  If you want to know more details about features, please check the paper [1]_ and
  utilities in deepchem.utils.molecule_feature_utils.py.

  Examples
  --------
  >>> smiles = ["C1CCC1", "C1=CC=CN=C1"]
  >>> featurizer = MolGraphConvFeaturizer()
  >>> out = featurizer.featurize(smiles)
  >>> type(out[0])
  <class 'deepchem.feat.graph_data.GraphData'>
  >>> out[0].num_node_features
  39
  >>> out[0].num_edge_features
  11

  References
  ----------
  .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints."
     Journal of computer-aided molecular design 30.8 (2016):595-608.

  Notes
  -----
  This class requires RDKit to be installed.
  """

  def __init__(self, add_self_edges: bool = False):
    """
    Parameters
    ----------
    add_self_edges: bool, default False
      Whether to add self-connected edges or not. If you want to use DGL,
      you sometimes need to add explict self-connected edges.
    """
    self.add_self_edges = add_self_edges

  def _featurize(self, mol: RDKitMol) -> GraphData:
    """Calculate molecule graph features from RDKit mol object.

    Parameters
    ----------
    mol: rdkit.Chem.rdchem.Mol
      RDKit mol object.

    Returns
    -------
    graph: GraphData
      A molecule graph with some features.
    """
    try:
      from rdkit import Chem
      from rdkit.Chem import AllChem
    except ModuleNotFoundError:
      raise ValueError("This method requires RDKit to be installed.")

    # construct atom and bond features
    try:
      mol.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
    except:
      # If partial charges were not computed
      AllChem.ComputeGasteigerCharges(mol)

    h_bond_infos = construct_hydrogen_bonding_info(mol)
    sssr = Chem.GetSymmSSSR(mol)

    # construct atom (node) feature
    atom_features = np.array(
        [
            _construct_atom_feature(atom, h_bond_infos, sssr)
            for atom in mol.GetAtoms()
        ],
        dtype=np.float,
    )

    # construct edge (bond) information
    src, dest, bond_features = [], [], []
    for bond in mol.GetBonds():
      # add edge list considering a directed graph
      start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
      src += [start, end]
      dest += [end, start]
      bond_features += 2 * [_construct_bond_feature(bond)]

    if self.add_self_edges:
      num_atoms = mol.GetNumAtoms()
      src += [i for i in range(num_atoms)]
      dest += [i for i in range(num_atoms)]
      # add dummy edge features
      bond_fea_length = len(bond_features[0])
      bond_features += num_atoms * [[0 for _ in range(bond_fea_length)]]

    return GraphData(
        node_features=atom_features,
        edge_index=np.array([src, dest], dtype=np.int),
        edge_features=np.array(bond_features, dtype=np.float))
+4 −5
Original line number Diff line number Diff line
import unittest
import pytest
import numpy as np
from deepchem.feat.graph_data import GraphData, BatchGraphData

@@ -38,7 +37,7 @@ class TestGraph(unittest.TestCase):
    assert isinstance(dgl_graph, DGLGraph)

  def test_invalid_graph_data(self):
    with pytest.raises(ValueError):
    with self.assertRaises(ValueError):
      invalid_node_features_type = list(np.random.random_sample((5, 32)))
      edge_index = np.array([
          [0, 1, 2, 2, 3, 4],
@@ -49,7 +48,7 @@ class TestGraph(unittest.TestCase):
          edge_index=edge_index,
      )

    with pytest.raises(ValueError):
    with self.assertRaises(ValueError):
      node_features = np.random.random_sample((5, 32))
      invalid_edge_index_shape = np.array([
          [0, 1, 2, 2, 3, 4],
@@ -60,7 +59,7 @@ class TestGraph(unittest.TestCase):
          edge_index=invalid_edge_index_shape,
      )

    with pytest.raises(ValueError):
    with self.assertRaises(ValueError):
      node_features = np.random.random_sample((5, 5))
      invalid_edge_index_shape = np.array([
          [0, 1, 2, 2, 3, 4],
@@ -72,7 +71,7 @@ class TestGraph(unittest.TestCase):
          edge_index=invalid_edge_index_shape,
      )

    with pytest.raises(TypeError):
    with self.assertRaises(TypeError):
      node_features = np.random.random_sample((5, 32))
      _ = GraphData(node_features=node_features)

Loading