Commit 991c987e authored by Milosz Grabski's avatar Milosz Grabski
Browse files

unit tests

parent 41575193
Loading
Loading
Loading
Loading
+143 −144
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ import numpy as np
from deepchem.utils.typing import RDKitBond, RDKitMol, List
from deepchem.feat.base_classes import MolecularFeaturizer


logger = logging.getLogger(__name__)


@@ -24,8 +23,7 @@ class GraphMatrix:
      A molecule graph with some features.
    """

    def __init__(self, adjacency_matrix: np.ndarray,
                 node_features: np.ndarray):
  def __init__(self, adjacency_matrix: np.ndarray, node_features: np.ndarray):
    """Store the two matrices describing a molecule graph on the instance.

    Parameters
    ----------
    adjacency_matrix: np.ndarray
      Kept unchanged as `self.adjacency_matrix`.
    node_features: np.ndarray
      Kept unchanged as `self.node_features`.
    """
    self.adjacency_matrix = adjacency_matrix
    self.node_features = node_features

@@ -35,6 +33,7 @@ class MolGanFeaturizer(MolecularFeaturizer):
    The default representation is in form of GraphMatrix object.
    It is wrapper for two matrices containing atom and bond type information.
    The class also provides reverse capabilities """

  def __init__(
      self,
      max_atom_count: int = 9,
@@ -42,7 +41,6 @@ class MolGanFeaturizer(MolecularFeaturizer):
      bond_labels: List[RDKitBond] = None,
      atom_labels: List[int] = None,
  ):

    """
        Parameters
        ----------
@@ -118,12 +116,13 @@ class MolGanFeaturizer(MolecularFeaturizer):
    if self.kekulize:
      Chem.Kekulize(mol)

        A = np.zeros(shape=(self.max_atom_count, self.max_atom_count),
                     dtype=np.float32)
    A = np.zeros(
        shape=(self.max_atom_count, self.max_atom_count), dtype=np.float32)
    bonds = mol.GetBonds()

        begin, end = [b.GetBeginAtomIdx()
                      for b in bonds], [b.GetEndAtomIdx() for b in bonds]
    begin, end = [b.GetBeginAtomIdx() for b in bonds], [
        b.GetEndAtomIdx() for b in bonds
    ]
    bond_type = [self.bond_encoder[b.GetBondType()] for b in bonds]

    A[begin, end] = bond_type
@@ -131,10 +130,8 @@ class MolGanFeaturizer(MolecularFeaturizer):

    degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)
    X = np.array(
            [
                self.atom_encoder[atom.GetAtomicNum()]
                for atom in mol.GetAtoms()
            ] + [0] * (self.max_atom_count - mol.GetNumAtoms()),
        [self.atom_encoder[atom.GetAtomicNum()] for atom in mol.GetAtoms()] +
        [0] * (self.max_atom_count - mol.GetNumAtoms()),
        dtype=np.int32,
    )
    graph = GraphMatrix(A, X)
@@ -168,6 +165,9 @@ class MolGanFeaturizer(MolecularFeaturizer):
    except ModuleNotFoundError:
      raise ImportError("This method requires RDKit to be installed.")

    if not isinstance(graph_matrix, GraphMatrix):
      return None

    node_labels = graph_matrix.node_features
    edge_labels = graph_matrix.adjacency_matrix

@@ -178,8 +178,8 @@ class MolGanFeaturizer(MolecularFeaturizer):

    for start, end in zip(*np.nonzero(edge_labels)):
      if start > end:
                mol.AddBond(int(start), int(end),
                            self.bond_decoder[edge_labels[start, end]])
        mol.AddBond(
            int(start), int(end), self.bond_decoder[edge_labels[start, end]])

    if sanitize:
      try:
@@ -200,8 +200,7 @@ class MolGanFeaturizer(MolecularFeaturizer):

    return mol

    def defeaturize(self,
                    graphs: GraphMatrix,
  def defeaturize(self, graphs: GraphMatrix,
                  log_every_n: int = 1000) -> np.ndarray:
    """Calculates molecules from corresponding GraphMatrix objects.

+55 −2
Original line number Diff line number Diff line
@@ -6,6 +6,11 @@ from deepchem.feat.molecule_featurizers import GraphMatrix
class TestMolganFeaturizer(unittest.TestCase):

  def test_featurizer_smiles(self):
    try:
      from rdkit import Chem
    except ModuleNotFoundError:
      raise ImportError("This method requires RDKit to be installed.")

    smiles = [
        'C#C[C@@]1(C)[NH2+][CH+]N[C@@H]1C', '[NH-][CH+]Oc1nnon1',
        'Cn1ncc2c1C=CC2', 'O=C[C@@H]1[C@H]2[C@H]3[C@@H]1[C@H]1[C@@H]2[N@@H+]31',
@@ -16,9 +21,57 @@ class TestMolganFeaturizer(unittest.TestCase):

    featurizer = MolGanFeaturizer()
    data = featurizer.featurize(smiles)
    incorrect = list(filter(lambda x: not isinstance(x, GraphMatrix), data))

    # test featurization
    valid_graph = list(filter(lambda x: isinstance(x, GraphMatrix), data))
    invalid_graph = list(filter(lambda x: not isinstance(x, GraphMatrix), data))
    assert len(data) == len(smiles)
    assert len(incorrect) == 1
    assert len(valid_graph) == len(smiles) - 1
    assert len(invalid_graph) == 1

    # test defeaturization
    mols = featurizer.defeaturize(data)
    valid_mols = list(filter(lambda x: isinstance(x, Chem.rdchem.Mol), mols))
    invalid_mols = list(
        filter(lambda x: not isinstance(x, Chem.rdchem.Mol), mols))
    assert len(valid_mols) == len(valid_mols)
    assert len(invalid_mols) == len(invalid_graph)

    def test_featurizer_rdkit(self):
      """Round-trip RDKit molecules through MolGanFeaturizer.

      Molecules are built from SMILES strings with `Chem.MolFromSmiles`,
      featurized into `GraphMatrix` objects and then defeaturized back into
      RDKit molecules. The test expects exactly one entry of the input list
      to fail featurization, and the number of entries that do not come back
      as RDKit molecules to match the number of failed featurizations.

      Raises
      ------
      ImportError
        If RDKit is not installed.
      """
      # NOTE(review): this def is indented one level deeper than
      # `test_featurizer_smiles`; if that reflects the real file, this is a
      # nested function that unittest never collects -- confirm and dedent.
      try:
        from rdkit import Chem
      except ModuleNotFoundError:
        raise ImportError("This method requires RDKit to be installed.")

      smiles = [
          'C#C[C@@]1(C)[NH2+][CH+]N[C@@H]1C', '[NH-][CH+]Oc1nnon1',
          'Cn1ncc2c1C=CC2',
          'O=C[C@@H]1[C@H]2[C@H]3[C@@H]1[C@H]1[C@@H]2[N@@H+]31',
          'C#Cc1[nH]cnc1C#N', 'N#C[C@]1(N)CO[C@H]1CO',
          'C[C@@H]1C(NO)C[C@@H]2O[C@@H]21', 'OC1CC=CC1', 'Cn1c(O)ccc1CO',
          '[NH-]C1OC[C@@]2(C=O)N[C@@H]12', 'incorrect smiles'
      ]
      input_mols = [Chem.MolFromSmiles(s) for s in smiles]
      featurizer = MolGanFeaturizer()
      data = featurizer.featurize(input_mols)

      # test featurization
      valid_graph = [g for g in data if isinstance(g, GraphMatrix)]
      invalid_graph = [g for g in data if not isinstance(g, GraphMatrix)]
      assert len(data) == len(smiles)
      assert len(valid_graph) == len(smiles) - 1
      assert len(invalid_graph) == 1

      # test defeaturization
      output_mols = featurizer.defeaturize(data)
      valid_mols = [m for m in output_mols if isinstance(m, Chem.rdchem.Mol)]
      invalid_mols = [
          m for m in output_mols if not isinstance(m, Chem.rdchem.Mol)
      ]
      # Compare against the featurization result; the previous check compared
      # len(valid_mols) with itself and could never fail.
      assert len(valid_mols) == len(valid_graph)
      assert len(invalid_mols) == len(invalid_graph)


if __name__ == '__main__':