Commit 74c238a7 authored by Milosz Grabski's avatar Milosz Grabski
Browse files

Code clean-up

parent 63a1a829
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -25,7 +25,7 @@ from deepchem.feat.molecule_featurizers import MACCSKeysFingerprint
from deepchem.feat.molecule_featurizers import MordredDescriptors
from deepchem.feat.molecule_featurizers import Mol2VecFingerprint
from deepchem.feat.molecule_featurizers import MolGraphConvFeaturizer
from deepchem.feat.molecule_featurizers import molgan_featurizer
from deepchem.feat.molecule_featurizers import MolGanFeaturizer
from deepchem.feat.molecule_featurizers import OneHotFeaturizer
from deepchem.feat.molecule_featurizers import PubChemFingerprint
from deepchem.feat.molecule_featurizers import RawFeaturizer
+1 −0
Original line number Diff line number Diff line
@@ -14,3 +14,4 @@ from deepchem.feat.molecule_featurizers.rdkit_descriptors import RDKitDescriptor
from deepchem.feat.molecule_featurizers.smiles_to_image import SmilesToImage
from deepchem.feat.molecule_featurizers.smiles_to_seq import SmilesToSeq, create_char_to_idx
from deepchem.feat.molecule_featurizers.mol_graph_conv_featurizer import MolGraphConvFeaturizer
from deepchem.feat.molecule_featurizers.molgan_featurizer import MolGanFeaturizer
+50 −25
Original line number Diff line number Diff line
import logging
import numpy as np
import rdkit.Chem as Chem

from deepchem.utils.typing import RDKitBond, RDKitMol, List
from deepchem.feat.base_classes import MolecularFeaturizer


logger = logging.getLogger(__name__)


@@ -24,19 +23,17 @@ class GraphMatrix:
    graph: GraphMatrix
      A molecule graph with some features.
    """

    def __init__(self, adjacency_matrix: np.ndarray, node_features: np.ndarray):
    def __init__(self, adjacency_matrix: np.ndarray,
                 node_features: np.ndarray):
        self.adjacency_matrix = adjacency_matrix
        self.node_features = node_features


class MolGanFeaturizer(MolecularFeaturizer):
    """This class implements featurizer used with MolGAN de-novo molecular generation based on:
    `MolGAN: An implicit generative model for small molecular graphs`<https://arxiv.org/abs/1805.11973>`_.
    """Featurizer for MolGAN de-novo molecular generation [1]_.
    The default representation is in form of GraphMatrix object.
    It is wrapper for two matrices containing atom and bond type information.
    The class also provides reverse capabilities"""

    def __init__(
        self,
        max_atom_count: int = 9,
@@ -49,17 +46,29 @@ class MolGanFeaturizer(MolecularFeaturizer):
        ----------
        max_atom_count: int, default 9
            Maximum number of atoms used for creation of adjacency matrix.
            Molecules cannot have more atoms than this number; implicit hydrogens do not count.
            Molecules cannot have more atoms than this number
            Implicit hydrogens do not count.
        kekulize: bool, default True
            Should molecules be kekulized; solves number of issues with defeaturization when used.
            Should molecules be kekulized.
            Solves number of issues with defeaturization when used.
        bond_labels: List[RDKitBond]
            List containing types of bond used for generation of adjacency matrix
            List of types of bond used for generation of adjacency matrix
        atom_labels: List[int]
            List of atomic numbers used for generation of node features

        References
        ---------
        .. [1] Nicola De Cao et al. "MolGAN: An implicit generative model
        for small molecular graphs`<https://arxiv.org/abs/1805.11973>`"
        """
        self.max_atom_count = max_atom_count
        self.kekulize = kekulize

        try:
            from rdkit import Chem
        except ModuleNotFoundError:
            raise ImportError("This class requires RDKit to be installed.")

        # bond labels
        if bond_labels is None:
            self.bond_labels = [
@@ -97,15 +106,21 @@ class MolGanFeaturizer(MolecularFeaturizer):
        graph: GraphMatrix
          A molecule graph with some features.
        """

        try:
            from rdkit import Chem
        except ModuleNotFoundError:
            raise ImportError("This method requires RDKit to be installed.")

        if self.kekulize:
            Chem.Kekulize(mol)

        A = np.zeros(shape=(self.max_atom_count, self.max_atom_count), dtype=np.float32)
        A = np.zeros(shape=(self.max_atom_count, self.max_atom_count),
                     dtype=np.float32)
        bonds = mol.GetBonds()

        begin, end = [b.GetBeginAtomIdx() for b in bonds], [
            b.GetEndAtomIdx() for b in bonds
        ]
        begin, end = [b.GetBeginAtomIdx()
                      for b in bonds], [b.GetEndAtomIdx() for b in bonds]
        bond_type = [self.bond_encoder[b.GetBondType()] for b in bonds]

        A[begin, end] = bond_type
@@ -113,18 +128,22 @@ class MolGanFeaturizer(MolecularFeaturizer):

        degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)
        X = np.array(
            [self.atom_encoder[atom.GetAtomicNum()] for atom in mol.GetAtoms()]
            + [0] * (self.max_atom_count - mol.GetNumAtoms()),
            [
                self.atom_encoder[atom.GetAtomicNum()]
                for atom in mol.GetAtoms()
            ] + [0] * (self.max_atom_count - mol.GetNumAtoms()),
            dtype=np.int32,
        )
        graph = GraphMatrix(A, X)

        return graph if (degree > 0).all() else None

    def _defeaturize(
        self, graph_matrix: GraphMatrix, sanitize: bool = True, cleanup=True
    ) -> RDKitMol:
        """Recreate RDKitMol from GraphMatrix object. Same object needs to be used for featurization and defeaturization.
    def _defeaturize(self,
                     graph_matrix: GraphMatrix,
                     sanitize: bool = True,
                     cleanup: bool = True) -> RDKitMol:
        """Recreate RDKitMol from GraphMatrix object.
        Same object needs to be used for featurization and defeaturization.

        Parameters
        ----------
@@ -141,6 +160,11 @@ class MolGanFeaturizer(MolecularFeaturizer):
            RDKitMol object representing molecule.
        """

        try:
            from rdkit import Chem
        except ModuleNotFoundError:
            raise ImportError("This method requires RDKit to be installed.")

        node_labels = graph_matrix.node_features
        edge_labels = graph_matrix.adjacency_matrix

@@ -151,9 +175,8 @@ class MolGanFeaturizer(MolecularFeaturizer):

        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(
                    int(start), int(end), self.bond_decoder[edge_labels[start, end]]
                )
                mol.AddBond(int(start), int(end),
                            self.bond_decoder[edge_labels[start, end]])

        if sanitize:
            try:
@@ -174,7 +197,9 @@ class MolGanFeaturizer(MolecularFeaturizer):

        return mol

    def defeaturize(self, graphs, log_every_n=1000) -> np.ndarray:
    def defeaturize(self,
                    graphs: GraphMatrix,
                    log_every_n: int = 1000) -> np.ndarray:
        """Calculates molecules from correspoing GraphMatrix objects.
        Parameters
        ----------