Unverified Commit bca869d8 authored by Daiki Nishikawa's avatar Daiki Nishikawa Committed by GitHub
Browse files

Merge pull request #2184 from nd-02110114/fix-cgcnn

Improve featurization speed for MolGraphConvFeaturizer
parents 59e02291 4c2e0101
Loading
Loading
Loading
Loading
+8 −2
Original line number Diff line number Diff line
@@ -267,6 +267,7 @@ class MolecularFeaturizer(Featurizer):
    for i, mol in enumerate(molecules):
      if i % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % i)

      try:
        if isinstance(mol, str):
          # mol must be a RDKit Mol object, so parse a SMILES
@@ -274,10 +275,15 @@ class MolecularFeaturizer(Featurizer):
          # SMILES is unique, so set a canonical order of atoms
          new_order = rdmolfiles.CanonicalRankAtoms(mol)
          mol = rdmolops.RenumberAtoms(mol, new_order)

        features.append(self._featurize(mol))
      except:
      except Exception as e:
        if isinstance(mol, Chem.rdchem.Mol):
          mol = Chem.MolToSmiles(mol)
        logger.warning(
            "Failed to featurize datapoint %d. Appending empty array", i)
            "Failed to featurize datapoint %d, %s. Appending empty array", i,
            mol)
        logger.warning("Exception message: {}".format(e))
        features.append(np.array([]))

    features = np.asarray(features)
+86 −64
Original line number Diff line number Diff line
from typing import List, Sequence, Tuple
from typing import List, Tuple
import numpy as np

from deepchem.utils.typing import RDKitAtom, RDKitBond, RDKitMol
from deepchem.feat.graph_data import GraphData
from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.utils.molecule_feature_utils import get_atom_type_one_hot, \
  construct_hydrogen_bonding_info, get_atom_hydrogen_bonding_one_hot, \
  get_atom_is_in_aromatic_one_hot, get_atom_hybridization_one_hot, \
  get_atom_total_num_Hs_one_hot, get_atom_chirality_one_hot, get_atom_formal_charge, \
  get_atom_partial_charge, get_atom_ring_size_one_hot, get_atom_total_degree_one_hot, \
  get_bond_type_one_hot, get_bond_is_in_same_ring_one_hot, get_bond_is_conjugated_one_hot, \
  get_bond_stereo_one_hot


def _construct_atom_feature(atom: RDKitAtom,
                            h_bond_infos: List[Tuple[int, str]],
                            sssr: List[Sequence]) -> List[float]:
from deepchem.utils.molecule_feature_utils import get_atom_type_one_hot
from deepchem.utils.molecule_feature_utils import construct_hydrogen_bonding_info
from deepchem.utils.molecule_feature_utils import get_atom_hydrogen_bonding_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_hybridization_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_total_num_Hs_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_is_in_aromatic_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_chirality_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_formal_charge
from deepchem.utils.molecule_feature_utils import get_atom_partial_charge
from deepchem.utils.molecule_feature_utils import get_atom_total_degree_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_type_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_in_same_ring_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_conjugated_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_stereo_one_hot


def _construct_atom_feature(
    atom: RDKitAtom, h_bond_infos: List[Tuple[int, str]], use_chirality: bool,
    use_partial_charge: bool) -> np.ndarray:
  """Construct an atom feature from a RDKit atom object.

  Parameters
@@ -27,30 +34,39 @@ def _construct_atom_feature(atom: RDKitAtom,
    Basically, it is expected that this value is the return value of
    `construct_hydrogen_bonding_info`. The `hydrogen_bonding_type`
    value is "Acceptor" or "Donor".
  sssr: List[Sequence]
    The return value of `Chem.GetSymmSSSR(mol)`.
    The value is a sequence of rings.
  use_chirality: bool
    Whether to use chirality information or not.
  use_partial_charge: bool
    Whether to use partial charge data or not.

  Returns
  -------
  List[float]
  np.ndarray
    A one-hot vector of the atom feature.
  """
  atom_type = get_atom_type_one_hot(atom)
  chirality = get_atom_chirality_one_hot(atom)
  formal_charge = get_atom_formal_charge(atom)
  partial_charge = get_atom_partial_charge(atom)
  ring_size = get_atom_ring_size_one_hot(atom, sssr)
  hybridization = get_atom_hybridization_one_hot(atom)
  acceptor_donor = get_atom_hydrogen_bonding_one_hot(atom, h_bond_infos)
  aromatic = get_atom_is_in_aromatic_one_hot(atom)
  degree = get_atom_total_degree_one_hot(atom)
  total_num = get_atom_total_num_Hs_one_hot(atom)
  return atom_type + chirality + formal_charge + partial_charge + \
    ring_size + hybridization + acceptor_donor + aromatic + degree + total_num
  total_num_Hs = get_atom_total_num_Hs_one_hot(atom)
  atom_feat = np.concatenate([
      atom_type, formal_charge, hybridization, acceptor_donor, aromatic, degree,
      total_num_Hs
  ])

  if use_chirality:
    chirality = get_atom_chirality_one_hot(atom)
    atom_feat = np.concatenate([atom_feat, chirality])

  if use_partial_charge:
    partial_charge = get_atom_partial_charge(atom)
    atom_feat = np.concatenate([atom_feat, partial_charge])
  return atom_feat


def _construct_bond_feature(bond: RDKitBond) -> List[float]:
def _construct_bond_feature(bond: RDKitBond) -> np.ndarray:
  """Construct a bond feature from a RDKit bond object.

  Parameters
@@ -60,14 +76,14 @@ def _construct_bond_feature(bond: RDKitBond) -> List[float]:

  Returns
  -------
  List[float]
  np.ndarray
    A one-hot vector of the bond feature.
  """
  bond_type = get_bond_type_one_hot(bond)
  same_ring = get_bond_is_in_same_ring_one_hot(bond)
  conjugated = get_bond_is_conjugated_one_hot(bond)
  stereo = get_bond_stereo_one_hot(bond)
  return bond_type + same_ring + conjugated + stereo
  return np.concatenate([bond_type, same_ring, conjugated, stereo])


class MolGraphConvFeaturizer(MolecularFeaturizer):
@@ -79,18 +95,17 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
  to modify return values of `construct_atom_feature` or `construct_bond_feature`.

  The default node representation is constructed by concatenating the following values,
  and the feature length is 39.
  and the feature length is 30.

  - Atom type: A one-hot vector of this atom, "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "other atoms".
  - Chirality: A one-hot vector of the chirality, "R" or "S".
  - Formal charge: Integer electronic charge.
  - Partial charge: Calculated partial charge.
  - Ring sizes: A one-hot vector of the size (3-8) of rings that include this atom.
  - Hybridization: A one-hot vector of "sp", "sp2", "sp3".
  - Hydrogen bonding: A one-hot vector of whether this atom is a hydrogen bond donor or acceptor.
  - Aromatic: A one-hot vector of whether the atom belongs to an aromatic ring.
  - Degree: A one-hot vector of the degree (0-5) of this atom.
  - Number of Hydrogens: A one-hot vector of the number of hydrogens (0-4) connected to this atom.
  - Chirality: A one-hot vector of the chirality, "R" or "S". (Optional)
  - Partial charge: Calculated partial charge. (Optional)

  The default edge representation is constructed by concatenating the following values,
  and the feature length is 11.
@@ -106,12 +121,12 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
  Examples
  --------
  >>> smiles = ["C1CCC1", "C1=CC=CN=C1"]
  >>> featurizer = MolGraphConvFeaturizer()
  >>> featurizer = MolGraphConvFeaturizer(use_edges=True)
  >>> out = featurizer.featurize(smiles)
  >>> type(out[0])
  <class 'deepchem.feat.graph_data.GraphData'>
  >>> out[0].num_node_features
  39
  30
  >>> out[0].num_edge_features
  11

@@ -125,21 +140,32 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
  This class requires RDKit to be installed.
  """

  def __init__(self, add_self_edges: bool = False):
  def __init__(self,
               use_edges: bool = False,
               use_chirality: bool = False,
               use_partial_charge: bool = False):
    """
    Parameters
    ----------
    add_self_edges: bool, default False
      Whether to add self-connected edges or not. If you want to use DGL,
      you sometimes need to add explicit self-connected edges.
    use_edges: bool, default False
      Whether to use edge features or not.
    use_chirality: bool, default False
      Whether to use chirality information or not.
      If True, featurization becomes slow.
    use_partial_charge: bool, default False
      Whether to use partial charge data or not.
      If True, this featurizer computes gasteiger charges.
      Therefore, featurization may fail for some molecules
      and becomes slower.
    """
    try:
      from rdkit import Chem  # noqa
      from rdkit.Chem import AllChem  # noqa
    except ModuleNotFoundError:
      raise ValueError("This method requires RDKit to be installed.")

    self.add_self_edges = add_self_edges
    self.use_edges = use_edges
    self.use_partial_charge = use_partial_charge
    self.use_chirality = use_chirality

  def _featurize(self, mol: RDKitMol) -> GraphData:
    """Calculate molecule graph features from RDKit mol object.
@@ -154,46 +180,42 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
    graph: GraphData
      A molecule graph with some features.
    """
    from rdkit import Chem
    from rdkit.Chem import AllChem

    # construct atom and bond features
    if self.use_partial_charge:
      try:
        mol.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
      except:
        # If partial charges were not computed
        from rdkit.Chem import AllChem
        AllChem.ComputeGasteigerCharges(mol)

    h_bond_infos = construct_hydrogen_bonding_info(mol)
    sssr = Chem.GetSymmSSSR(mol)

    # construct atom (node) feature
    atom_features = np.array(
    h_bond_infos = construct_hydrogen_bonding_info(mol)
    atom_features = np.asarray(
        [
            _construct_atom_feature(atom, h_bond_infos, sssr)
            _construct_atom_feature(atom, h_bond_infos, self.use_chirality,
                                    self.use_partial_charge)
            for atom in mol.GetAtoms()
        ],
        dtype=np.float,
    )

    # construct edge (bond) information
    src, dest, bond_features = [], [], []
    # construct edge (bond) index
    src, dest = [], []
    for bond in mol.GetBonds():
      # add edge list considering a directed graph
      start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
      src += [start, end]
      dest += [end, start]
      bond_features += 2 * [_construct_bond_feature(bond)]

    if self.add_self_edges:
      num_atoms = mol.GetNumAtoms()
      src += [i for i in range(num_atoms)]
      dest += [i for i in range(num_atoms)]
      # add dummy edge features
      bond_fea_length = len(bond_features[0])
      bond_features += num_atoms * [[0 for _ in range(bond_fea_length)]]
    # construct edge (bond) feature
    bond_features = None  # default None
    if self.use_edges:
      bond_features = []
      for bond in mol.GetBonds():
        bond_features += 2 * [_construct_bond_feature(bond)]
      bond_features = np.asarray(bond_features, dtype=np.float)

    return GraphData(
        node_features=atom_features,
        edge_index=np.array([src, dest], dtype=np.int),
        edge_features=np.array(bond_features, dtype=np.float))
        edge_index=np.asarray([src, dest], dtype=np.int),
        edge_features=bond_features)
+40 −10
Original line number Diff line number Diff line
@@ -13,30 +13,60 @@ class TestMolGraphConvFeaturizer(unittest.TestCase):

    # assert "C1=CC=CN=C1"
    assert graph_feat[0].num_nodes == 6
    assert graph_feat[0].num_node_features == 39
    assert graph_feat[0].num_node_features == 30
    assert graph_feat[0].num_edges == 12
    assert graph_feat[0].num_edge_features == 11

    # assert "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"
    assert graph_feat[1].num_nodes == 22
    assert graph_feat[1].num_node_features == 39
    assert graph_feat[1].num_node_features == 30
    assert graph_feat[1].num_edges == 44
    assert graph_feat[1].num_edge_features == 11

  def test_featurizer_with_self_loop(self):
  def test_featurizer_with_use_edge(self):
    smiles = ["C1=CC=CN=C1", "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"]
    featurizer = MolGraphConvFeaturizer(add_self_edges=True)
    featurizer = MolGraphConvFeaturizer(use_edges=True)
    graph_feat = featurizer.featurize(smiles)
    assert len(graph_feat) == 2

    # assert "C1=CC=CN=C1"
    assert graph_feat[0].num_nodes == 6
    assert graph_feat[0].num_node_features == 39
    assert graph_feat[0].num_edges == 12 + 6
    assert graph_feat[0].num_node_features == 30
    assert graph_feat[0].num_edges == 12
    assert graph_feat[0].num_edge_features == 11

    # assert "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"
    assert graph_feat[1].num_nodes == 22
    assert graph_feat[1].num_node_features == 39
    assert graph_feat[1].num_edges == 44 + 22
    assert graph_feat[1].num_node_features == 30
    assert graph_feat[1].num_edges == 44
    assert graph_feat[1].num_edge_features == 11

  def test_featurizer_with_use_chirality(self):
    """Check graph sizes when the chirality node features are enabled."""
    smiles_list = ["C1=CC=CN=C1", "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"]
    graphs = MolGraphConvFeaturizer(use_chirality=True).featurize(smiles_list)
    assert len(graphs) == 2

    # Expected (num_nodes, num_node_features, num_edges), in the same order
    # as the SMILES strings above.
    expected_sizes = [(6, 32, 12), (22, 32, 44)]
    for idx, (n_nodes, n_node_feats, n_edges) in enumerate(expected_sizes):
      graph = graphs[idx]
      assert graph.num_nodes == n_nodes
      assert graph.num_node_features == n_node_feats
      assert graph.num_edges == n_edges

  def test_featurizer_with_use_partial_charge(self):
    """Check graph sizes when the partial-charge node feature is enabled."""
    smiles_list = ["C1=CC=CN=C1", "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"]
    graphs = MolGraphConvFeaturizer(
        use_partial_charge=True).featurize(smiles_list)
    assert len(graphs) == 2

    # Expected (num_nodes, num_node_features, num_edges), in the same order
    # as the SMILES strings above.
    expected_sizes = [(6, 31, 12), (22, 31, 44)]
    for idx, (n_nodes, n_node_feats, n_edges) in enumerate(expected_sizes):
      graph = graphs[idx]
      assert graph.num_nodes == n_nodes
      assert graph.num_node_features == n_node_feats
      assert graph.num_edges == n_edges
+6 −6
Original line number Diff line number Diff line
@@ -51,7 +51,7 @@ class GAT(nn.Module):

  def __init__(
      self,
      in_node_dim: int = 39,
      in_node_dim: int = 30,
      hidden_node_dim: int = 32,
      heads: int = 1,
      dropout: float = 0.0,
@@ -64,8 +64,8 @@ class GAT(nn.Module):
    """
    Parameters
    ----------
    in_node_dim: int, default 39
      The length of the initial node feature vectors. The 39 is
    in_node_dim: int, default 30
      The length of the initial node feature vectors. The 30 is
      based on `MolGraphConvFeaturizer`.
    hidden_node_dim: int, default 32
      The length of the hidden node feature vectors.
@@ -178,7 +178,7 @@ class GATModel(TorchModel):
  """

  def __init__(self,
               in_node_dim: int = 39,
               in_node_dim: int = 30,
               hidden_node_dim: int = 32,
               heads: int = 1,
               dropout: float = 0.0,
@@ -193,8 +193,8 @@ class GATModel(TorchModel):

    Parameters
    ----------
    in_node_dim: int, default 39
      The length of the initial node feature vectors. The 39 is
    in_node_dim: int, default 30
      The length of the initial node feature vectors. The 30 is
      based on `MolGraphConvFeaturizer`.
    hidden_node_dim: int, default 32
      The length of the hidden node feature vectors.
+0 −1
Original line number Diff line number Diff line
@@ -71,7 +71,6 @@ from deepchem.utils.molecule_feature_utils import get_atom_total_num_Hs_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_chirality_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_formal_charge
from deepchem.utils.molecule_feature_utils import get_atom_partial_charge
from deepchem.utils.molecule_feature_utils import get_atom_ring_size_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_total_degree_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_type_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_in_same_ring_one_hot
Loading