Unverified Commit b5f7242d authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2371 from MiloszGrabski/molgan-featurizer

Added MolGAN featurizer
parents 1d760359 65d29d65
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ from deepchem.feat.molecule_featurizers import MACCSKeysFingerprint
from deepchem.feat.molecule_featurizers import MordredDescriptors
from deepchem.feat.molecule_featurizers import Mol2VecFingerprint
from deepchem.feat.molecule_featurizers import MolGraphConvFeaturizer
from deepchem.feat.molecule_featurizers import MolGanFeaturizer
from deepchem.feat.molecule_featurizers import OneHotFeaturizer
from deepchem.feat.molecule_featurizers import PubChemFingerprint
from deepchem.feat.molecule_featurizers import RawFeaturizer
@@ -36,7 +37,8 @@ from deepchem.feat.molecule_featurizers import SmilesToSeq, create_char_to_idx
from deepchem.feat.complex_featurizers import RdkitGridFeaturizer
from deepchem.feat.complex_featurizers import NeighborListAtomicCoordinates
from deepchem.feat.complex_featurizers import NeighborListComplexAtomicCoordinates
from deepchem.feat.complex_featurizers import ComplexNeighborListFragmentAtomicCoordinates
from deepchem.feat.complex_featurizers import (
    ComplexNeighborListFragmentAtomicCoordinates,)
from deepchem.feat.complex_featurizers import ContactCircularFingerprint
from deepchem.feat.complex_featurizers import ContactCircularVoxelizer
from deepchem.feat.complex_featurizers import SplifFingerprint
@@ -65,3 +67,6 @@ try:
  from deepchem.feat.smiles_tokenizer import BasicSmilesTokenizer
except ModuleNotFoundError:
  pass

# support classes
from deepchem.feat.molecule_featurizers import GraphMatrix
 No newline at end of file
+2 −0
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ from deepchem.feat.molecule_featurizers.bp_symmetry_function_input import BPSymm
from deepchem.feat.molecule_featurizers.circular_fingerprint import CircularFingerprint
from deepchem.feat.molecule_featurizers.coulomb_matrices import CoulombMatrix
from deepchem.feat.molecule_featurizers.coulomb_matrices import CoulombMatrixEig
from deepchem.feat.molecule_featurizers.molgan_featurizer import GraphMatrix
from deepchem.feat.molecule_featurizers.maccs_keys_fingerprint import MACCSKeysFingerprint
from deepchem.feat.molecule_featurizers.mordred_descriptors import MordredDescriptors
from deepchem.feat.molecule_featurizers.mol2vec_fingerprint import Mol2VecFingerprint
@@ -14,3 +15,4 @@ from deepchem.feat.molecule_featurizers.rdkit_descriptors import RDKitDescriptor
from deepchem.feat.molecule_featurizers.smiles_to_image import SmilesToImage
from deepchem.feat.molecule_featurizers.smiles_to_seq import SmilesToSeq, create_char_to_idx
from deepchem.feat.molecule_featurizers.mol_graph_conv_featurizer import MolGraphConvFeaturizer
from deepchem.feat.molecule_featurizers.molgan_featurizer import MolGanFeaturizer
+254 −0
Original line number Diff line number Diff line
import logging
import numpy as np
from deepchem.utils.typing import RDKitBond, RDKitMol, List
from deepchem.feat.base_classes import MolecularFeaturizer

logger = logging.getLogger(__name__)


class GraphMatrix:
  """
  Plain container holding the two arrays that describe one molecule
  for MolGAN neural networks.

  No copying or validation is performed; the arrays are stored exactly
  as they are passed in.

  Parameters
  ----------
  adjacency_matrix: np.ndarray
    Edge (bond) information. MolGanFeaturizer in this module produces a
    square float32 matrix with one row/column per atom slot.
  node_features: np.ndarray
    Node (atom) information. MolGanFeaturizer in this module produces a
    1-D int32 vector with one encoded atom label per atom slot.
  """

  def __init__(self, adjacency_matrix: np.ndarray, node_features: np.ndarray):
    # Keep references to the caller's arrays (no defensive copy).
    self.adjacency_matrix, self.node_features = (adjacency_matrix,
                                                 node_features)


class MolGanFeaturizer(MolecularFeaturizer):
  """
  Featurizer for MolGAN de-novo molecular generation [1]_.

  The default representation is a GraphMatrix object, a wrapper around
  two arrays: an adjacency matrix holding encoded bond types and a vector
  holding encoded atom types. The class also provides the reverse
  operation (see `defeaturize`).

  References
  ----------
  .. [1] Nicola De Cao et al. "MolGAN: An implicit generative model
     for small molecular graphs" https://arxiv.org/abs/1805.11973
  """

  def __init__(
      self,
      max_atom_count: int = 9,
      kekulize: bool = True,
      bond_labels: List[RDKitBond] = None,
      atom_labels: List[int] = None,
  ):
    """
    Parameters
    ----------
    max_atom_count: int, default 9
      Maximum number of atoms used for creation of adjacency matrix.
      Molecules cannot have more atoms than this number.
      Implicit hydrogens do not count.
    kekulize: bool, default True
      Should molecules be kekulized.
      Solves number of issues with defeaturization when used.
    bond_labels: List[RDKitBond], optional
      List of bond types used for generation of the adjacency matrix.
      Each bond type is encoded as its position in this list. Position 0
      is also the value of empty adjacency cells, so the first entry is
      treated as "no bond" when defeaturizing. Defaults to
      ZERO, SINGLE, DOUBLE, TRIPLE, AROMATIC.
    atom_labels: List[int], optional
      List of atomic numbers used for generation of node features.
      Each atomic number is encoded as its position in this list. Position
      0 is also used to pad unused atom slots. Defaults to
      [0, 6, 7, 8, 9].

    Raises
    ------
    ImportError
      If RDKit is not installed.
    """

    self.max_atom_count = max_atom_count
    self.kekulize = kekulize

    try:
      from rdkit import Chem
    except ModuleNotFoundError:
      raise ImportError("This class requires RDKit to be installed.")

    # bond labels
    if bond_labels is None:
      self.bond_labels = [
          Chem.rdchem.BondType.ZERO,
          Chem.rdchem.BondType.SINGLE,
          Chem.rdchem.BondType.DOUBLE,
          Chem.rdchem.BondType.TRIPLE,
          Chem.rdchem.BondType.AROMATIC,
      ]
    else:
      self.bond_labels = bond_labels

    # atom labels
    if atom_labels is None:
      self.atom_labels = [0, 6, 7, 8, 9]  # padding (0), C, N, O, F
    else:
      self.atom_labels = atom_labels

    # create bond encoders (label -> index) and decoders (index -> label)
    self.bond_encoder = {l: i for i, l in enumerate(self.bond_labels)}
    self.bond_decoder = {i: l for i, l in enumerate(self.bond_labels)}
    # create atom encoders (label -> index) and decoders (index -> label)
    self.atom_encoder = {l: i for i, l in enumerate(self.atom_labels)}
    self.atom_decoder = {i: l for i, l in enumerate(self.atom_labels)}

  def _featurize(self, mol: RDKitMol) -> GraphMatrix:
    """
    Calculate adjacency matrix and nodes features for RDKitMol.
    It strips any chirality and charges.

    Parameters
    ----------
    mol: rdkit.Chem.rdchem.Mol
      RDKit mol object. It is not modified; kekulization is applied to
      a copy.

    Returns
    -------
    graph: GraphMatrix or None
      A molecule graph with some features, or None when at least one atom
      of the molecule has no encoded (non-zero) bond.

    Raises
    ------
    ImportError
      If RDKit is not installed.
    ValueError
      If the molecule has more atoms than `max_atom_count`.
    KeyError
      If an atom or bond type is not present in the configured labels.
    """

    try:
      from rdkit import Chem
    except ModuleNotFoundError:
      raise ImportError("This method requires RDKit to be installed.")

    num_atoms = mol.GetNumAtoms()
    # Fail loudly instead of producing a node vector longer than the
    # adjacency matrix (possible when the surplus atoms have no bonds).
    if num_atoms > self.max_atom_count:
      raise ValueError(
          "Molecule has {} atoms but max_atom_count is {}.".format(
              num_atoms, self.max_atom_count))

    if self.kekulize:
      # Kekulize works in place; operate on a copy so the caller's
      # molecule keeps its original bond types.
      mol = Chem.Mol(mol)
      Chem.Kekulize(mol)

    A = np.zeros(
        shape=(self.max_atom_count, self.max_atom_count), dtype=np.float32)
    bonds = mol.GetBonds()

    begin = [b.GetBeginAtomIdx() for b in bonds]
    end = [b.GetEndAtomIdx() for b in bonds]
    bond_type = [self.bond_encoder[b.GetBondType()] for b in bonds]

    # Fill both triangles so the matrix is symmetric.
    A[begin, end] = bond_type
    A[end, begin] = bond_type

    # Sum of encoded bond values per real atom; 0 means "no bonds".
    degree = np.sum(A[:num_atoms, :num_atoms], axis=-1)
    # Encoded atom labels, padded with index 0 up to max_atom_count.
    X = np.array(
        [self.atom_encoder[atom.GetAtomicNum()] for atom in mol.GetAtoms()] +
        [0] * (self.max_atom_count - num_atoms),
        dtype=np.int32,
    )
    graph = GraphMatrix(A, X)

    return graph if (degree > 0).all() else None

  def _defeaturize(self,
                   graph_matrix: GraphMatrix,
                   sanitize: bool = True,
                   cleanup: bool = True) -> RDKitMol:
    """
    Recreate RDKitMol from GraphMatrix object.
    Same featurizer need to be used for featurization and defeaturization.
    It only recreates bond and atom types, any kind of additional features
    like chirality or charge are not included.
    Therefore, any checks of type: original_smiles == defeaturized_smiles
    will fail on chiral or charged compounds.

    Parameters
    ----------
    graph_matrix: GraphMatrix
      GraphMatrix object.
    sanitize: bool, default True
      Should RDKit sanitization be included in the process.
    cleanup: bool, default True
      Keeps only the longest "."-separated SMILES fragment and rejects it
      if it contains a "*" atom.

    Returns
    -------
    mol: RDKitMol object or None
      RDKitMol object representing molecule. None is returned when the
      input is not a GraphMatrix, when sanitization fails, or when the
      cleanup step fails or rejects the molecule.

    Raises
    ------
    ImportError
      If RDKit is not installed.
    KeyError
      If a node or edge value has no entry in the decoders.
    """

    try:
      from rdkit import Chem
    except ModuleNotFoundError:
      raise ImportError("This method requires RDKit to be installed.")

    if not isinstance(graph_matrix, GraphMatrix):
      return None

    node_labels = graph_matrix.node_features
    edge_labels = graph_matrix.adjacency_matrix

    mol = Chem.RWMol()

    for node_label in node_labels:
      mol.AddAtom(Chem.Atom(self.atom_decoder[node_label]))

    # The matrix is symmetric; visit only one triangle so every bond is
    # added exactly once.
    for start, end in zip(*np.nonzero(edge_labels)):
      if start > end:
        mol.AddBond(
            int(start), int(end), self.bond_decoder[edge_labels[start, end]])

    if sanitize:
      try:
        Chem.SanitizeMol(mol)
      except Exception:
        # An unsanitizable molecule can never be recovered by cleanup,
        # so stop here instead of relying on a later failure.
        return None

    if cleanup:
      try:
        smiles = Chem.MolToSmiles(mol)
        # Keep the largest fragment only.
        smiles = max(smiles.split("."), key=len)
        if "*" not in smiles:
          mol = Chem.MolFromSmiles(smiles)
        else:
          mol = None
      except Exception:
        mol = None

    return mol

  def defeaturize(self, graphs: GraphMatrix,
                  log_every_n: int = 1000) -> np.ndarray:
    """
    Calculates molecules from corresponding GraphMatrix objects.

    Parameters
    ----------
    graphs: GraphMatrix / iterable
      GraphMatrix object or corresponding iterable
    log_every_n: int, default 1000
      Logging messages reported every `log_every_n` samples.

    Returns
    -------
    features: np.ndarray
      A 1-D object array with one entry per input graph: an RDKitMol
      object, None when the graph could not be turned into a molecule,
      or an empty array when defeaturization raised an exception.
    """

    # Special case handling of single molecule
    if isinstance(graphs, GraphMatrix):
      graphs = [graphs]
    else:
      # Convert iterables to list
      graphs = list(graphs)

    molecules = []
    for i, gr in enumerate(graphs):
      if i % log_every_n == 0:
        logger.info("Defeaturizing datapoint %i", i)

      try:
        molecules.append(self._defeaturize(gr))
      except Exception as e:
        logger.warning(
            "Failed to defeaturize datapoint %d, %s. Appending empty array",
            i,
            gr,
        )
        logger.warning("Exception message: %s", e)
        molecules.append(np.array([]))

    # Fill an object array element by element: the list may mix molecules,
    # None and empty arrays, which np.asarray cannot reliably convert.
    result = np.empty(len(molecules), dtype=object)
    for i, molecule in enumerate(molecules):
      result[i] = molecule
    return result
+26 −0
Original line number Diff line number Diff line
import unittest
import numpy as np
from deepchem.feat.molecule_featurizers import GraphMatrix


class TestGraphMatrix(unittest.TestCase):
  """Checks that GraphMatrix stores the arrays it is constructed with."""

  def test_graph_matrix(self):
    """Types, dtypes, shapes and values must survive construction."""

    atom_array = [7, 7, 7, 8, 8, 8, 9, 6]
    # One adjacency row/column per node label keeps the fixture consistent.
    max_atom_count = len(atom_array)

    A = np.zeros(shape=(max_atom_count, max_atom_count), dtype=np.float32)
    X = np.array(atom_array, dtype=np.int32)

    graph_matrix = GraphMatrix(adjacency_matrix=A, node_features=X)
    assert isinstance(graph_matrix.adjacency_matrix, np.ndarray)
    assert isinstance(graph_matrix.node_features, np.ndarray)
    assert graph_matrix.adjacency_matrix.dtype == np.float32
    assert graph_matrix.node_features.dtype == np.int32
    assert graph_matrix.adjacency_matrix.shape == A.shape
    assert graph_matrix.node_features.shape == X.shape
    # The stored content must be exactly what was passed in.
    assert np.array_equal(graph_matrix.adjacency_matrix, A)
    assert np.array_equal(graph_matrix.node_features, X)


# Run the tests of this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()
+101 −0
Original line number Diff line number Diff line
import unittest
from deepchem.feat.molecule_featurizers import MolGanFeaturizer
from deepchem.feat.molecule_featurizers import GraphMatrix


class TestMolganFeaturizer(unittest.TestCase):
  """Round-trip tests (featurize, then defeaturize) for MolGanFeaturizer."""

  def test_featurizer_smiles(self):
    """Round trip starting from SMILES strings."""
    try:
      from rdkit import Chem
    except ModuleNotFoundError:
      raise ImportError("This method requires RDKit to be installed.")

    smiles = [
        'Cc1ccccc1CO', 'CC1CCC(C)C(N)C1', 'CCC(N)=O', 'Fc1cccc(F)c1', 'CC(C)F',
        'C1COC2NCCC2C1', 'C1=NCc2ccccc21'
    ]
    invalid_smiles = ['axa', 'xyz', 'inv']

    featurizer = MolGanFeaturizer()
    valid_data = featurizer.featurize(smiles)
    invalid_data = featurizer.featurize(invalid_smiles)

    # featurization: valid inputs give GraphMatrix objects, invalid do not
    valid_graphs = [d for d in valid_data if isinstance(d, GraphMatrix)]
    invalid_graphs = [
        d for d in invalid_data if not isinstance(d, GraphMatrix)
    ]
    assert len(valid_graphs) == len(smiles)
    assert len(invalid_graphs) == len(invalid_smiles)

    # defeaturization: valid graphs give RDKit molecules, invalid do not
    valid_output = featurizer.defeaturize(valid_graphs)
    invalid_output = featurizer.defeaturize(invalid_graphs)
    valid_mols = [m for m in valid_output if isinstance(m, Chem.rdchem.Mol)]
    invalid_mols = [
        m for m in invalid_output if not isinstance(m, Chem.rdchem.Mol)
    ]
    assert len(valid_graphs) == len(valid_mols)
    assert len(invalid_graphs) == len(invalid_mols)

    # sanity check; the fixture SMILES must already be in RDKit's own form
    mols = [Chem.MolFromSmiles(s) for s in smiles]
    redone_smiles = [Chem.MolToSmiles(m) for m in mols]
    assert redone_smiles == smiles

    # check if original smiles match defeaturized smiles
    defe_smiles = [Chem.MolToSmiles(m) for m in valid_mols]
    assert defe_smiles == smiles

  def test_featurizer_rdkit(self):
    """Round trip starting from RDKit molecule objects."""

    try:
      from rdkit import Chem
    except ModuleNotFoundError:
      raise ImportError("This method requires RDKit to be installed.")

    smiles = [
        'Cc1ccccc1CO', 'CC1CCC(C)C(N)C1', 'CCC(N)=O', 'Fc1cccc(F)c1', 'CC(C)F',
        'C1COC2NCCC2C1', 'C1=NCc2ccccc21'
    ]
    invalid_smiles = ['axa', 'xyz', 'inv']

    valid_molecules = [Chem.MolFromSmiles(s) for s in smiles]
    invalid_molecules = [Chem.MolFromSmiles(s) for s in invalid_smiles]

    # sanity check; the fixture SMILES must already be in RDKit's own form
    redone_smiles = [Chem.MolToSmiles(m) for m in valid_molecules]
    assert redone_smiles == smiles

    featurizer = MolGanFeaturizer()
    valid_data = featurizer.featurize(valid_molecules)
    invalid_data = featurizer.featurize(invalid_molecules)

    # featurization: valid inputs give GraphMatrix objects, invalid do not
    valid_graphs = [d for d in valid_data if isinstance(d, GraphMatrix)]
    invalid_graphs = [
        d for d in invalid_data if not isinstance(d, GraphMatrix)
    ]
    assert len(valid_graphs) == len(valid_molecules)
    assert len(invalid_graphs) == len(invalid_molecules)

    # defeaturization: valid graphs give RDKit molecules, invalid do not
    valid_output = featurizer.defeaturize(valid_graphs)
    invalid_output = featurizer.defeaturize(invalid_graphs)
    valid_mols = [m for m in valid_output if isinstance(m, Chem.rdchem.Mol)]
    invalid_mols = [
        m for m in invalid_output if not isinstance(m, Chem.rdchem.Mol)
    ]
    assert len(valid_mols) == len(valid_graphs)
    assert len(invalid_mols) == len(invalid_graphs)

    # check if original smiles match defeaturized smiles
    defe_smiles = [Chem.MolToSmiles(m) for m in valid_mols]
    assert defe_smiles == smiles


# Run the tests of this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()
Loading