Commit b558d6e4 authored by Milosz Grabski
Browse files

unittest/docstrings corrections

parent 991c987e
Loading
Loading
Loading
Loading
+77 −72
Original line number Diff line number Diff line
@@ -32,7 +32,7 @@ class MolGanFeaturizer(MolecularFeaturizer):
  """Featurizer for MolGAN de-novo molecular generation [1]_.
  The default representation is in the form of a GraphMatrix object.
  It is a wrapper for two matrices containing atom and bond type information.
    The class also provides reverse capabilities """
  The class also provides reverse capabilities."""

  def __init__(
      self,
@@ -97,6 +97,7 @@ class MolGanFeaturizer(MolecularFeaturizer):

  def _featurize(self, mol: RDKitMol) -> GraphMatrix:
    """Calculate adjacency matrix and nodes features for RDKitMol.
    It strips any chirality and charges.

    Parameters
    ----------
@@ -143,7 +144,11 @@ class MolGanFeaturizer(MolecularFeaturizer):
                   sanitize: bool = True,
                   cleanup: bool = True) -> RDKitMol:
    """Recreate RDKitMol from GraphMatrix object.
        Same object needs to be used for featurization and defeaturization.
    The same featurizer needs to be used for featurization and
    defeaturization. It only recreates bond and atom types; additional
    features like chirality or charge are not included.
    Therefore, any check of the form original_smiles == defeaturized_smiles
    will fail on chiral or charged compounds.

    Parameters
    ----------
+75 −52
Original line number Diff line number Diff line
@@ -12,30 +12,42 @@ class TestMolganFeaturizer(unittest.TestCase):
      raise ImportError("This method requires RDKit to be installed.")

    smiles = [
        'C#C[C@@]1(C)[NH2+][CH+]N[C@@H]1C', '[NH-][CH+]Oc1nnon1',
        'Cn1ncc2c1C=CC2', 'O=C[C@@H]1[C@H]2[C@H]3[C@@H]1[C@H]1[C@@H]2[N@@H+]31',
        'C#Cc1[nH]cnc1C#N', 'N#C[C@]1(N)CO[C@H]1CO',
        'C[C@@H]1C(NO)C[C@@H]2O[C@@H]21', 'OC1CC=CC1', 'Cn1c(O)ccc1CO',
        '[NH-]C1OC[C@@]2(C=O)N[C@@H]12', 'incorrect smiles'
        'Cc1ccccc1CO', 'CC1CCC(C)C(N)C1', 'CCC(N)=O', 'Fc1cccc(F)c1', 'CC(C)F',
        'C1COC2NCCC2C1', 'C1=NCc2ccccc21'
    ]

    invalid_smiles = ['axa', 'xyz', 'inv']

    featurizer = MolGanFeaturizer()
    data = featurizer.featurize(smiles)
    valid_data = featurizer.featurize(smiles)
    invalid_data = featurizer.featurize(invalid_smiles)

    # test featurization
    valid_graph = list(filter(lambda x: isinstance(x, GraphMatrix), data))
    invalid_graph = list(filter(lambda x: not isinstance(x, GraphMatrix), data))
    assert len(data) == len(smiles)
    assert len(valid_graph) == len(smiles) - 1
    assert len(invalid_graph) == 1
    valid_graphs = list(
        filter(lambda x: isinstance(x, GraphMatrix), valid_data))
    invalid_graphs = list(
        filter(lambda x: not isinstance(x, GraphMatrix), invalid_data))
    assert len(valid_graphs) == len(smiles)
    assert len(invalid_graphs) == len(invalid_smiles)

    # test defeaturization
    mols = featurizer.defeaturize(data)
    valid_mols = list(filter(lambda x: isinstance(x, Chem.rdchem.Mol), mols))
    valid_mols = featurizer.defeaturize(valid_graphs)
    invalid_mols = featurizer.defeaturize(invalid_graphs)
    valid_mols = list(
        filter(lambda x: isinstance(x, Chem.rdchem.Mol), valid_mols))
    invalid_mols = list(
        filter(lambda x: not isinstance(x, Chem.rdchem.Mol), mols))
    assert len(valid_mols) == len(valid_mols)
    assert len(invalid_mols) == len(invalid_graph)
        filter(lambda x: not isinstance(x, Chem.rdchem.Mol), invalid_mols))
    assert len(valid_graphs) == len(valid_mols)
    assert len(invalid_graphs) == len(invalid_mols)

    mols = list(map(Chem.MolFromSmiles, smiles))
    redone_smiles = list(map(Chem.MolToSmiles, mols))
    # sanity check: make sure the RDKit round trip reproduces the input SMILES
    assert redone_smiles == smiles

    # check if original smiles match defeaturized smiles
    defe_smiles = list(map(Chem.MolToSmiles, valid_mols))
    assert defe_smiles == smiles

  def test_featurizer_rdkit(self):

@@ -45,33 +57,44 @@ class TestMolganFeaturizer(unittest.TestCase):
      raise ImportError("This method requires RDKit to be installed.")

    smiles = [
          'C#C[C@@]1(C)[NH2+][CH+]N[C@@H]1C', '[NH-][CH+]Oc1nnon1',
          'Cn1ncc2c1C=CC2',
          'O=C[C@@H]1[C@H]2[C@H]3[C@@H]1[C@H]1[C@@H]2[N@@H+]31',
          'C#Cc1[nH]cnc1C#N', 'N#C[C@]1(N)CO[C@H]1CO',
          'C[C@@H]1C(NO)C[C@@H]2O[C@@H]21', 'OC1CC=CC1', 'Cn1c(O)ccc1CO',
          '[NH-]C1OC[C@@]2(C=O)N[C@@H]12', 'incorrect smiles'
        'Cc1ccccc1CO', 'CC1CCC(C)C(N)C1', 'CCC(N)=O', 'Fc1cccc(F)c1', 'CC(C)F',
        'C1COC2NCCC2C1', 'C1=NCc2ccccc21'
    ]
      mols = list(map(Chem.MolFromSmiles, smiles))

    invalid_smiles = ['axa', 'xyz', 'inv']

    valid_molecules = list(map(Chem.MolFromSmiles, smiles))
    invalid_molecules = list(map(Chem.MolFromSmiles, invalid_smiles))

    redone_smiles = list(map(Chem.MolToSmiles, valid_molecules))
    # sanity check: make sure the RDKit round trip reproduces the input SMILES
    assert redone_smiles == smiles

    featurizer = MolGanFeaturizer()
      data = featurizer.featurize(mols)
    valid_data = featurizer.featurize(valid_molecules)
    invalid_data = featurizer.featurize(invalid_molecules)

    # test featurization

      valid_graph = list(filter(lambda x: isinstance(x, GraphMatrix), data))
      invalid_graph = list(
          filter(lambda x: not isinstance(x, GraphMatrix), data))
      assert len(data) == len(smiles)
      assert len(valid_graph) == len(smiles) - 1
      assert len(invalid_graph) == 1
    valid_graphs = list(
        filter(lambda x: isinstance(x, GraphMatrix), valid_data))
    invalid_graphs = list(
        filter(lambda x: not isinstance(x, GraphMatrix), invalid_data))
    assert len(valid_graphs) == len(valid_molecules)
    assert len(invalid_graphs) == len(invalid_molecules)

    # test defeaturization
      mols = featurizer.defeaturize(data)
      valid_mols = list(filter(lambda x: isinstance(x, Chem.rdchem.Mol), mols))
    valid_mols = featurizer.defeaturize(valid_graphs)
    invalid_mols = featurizer.defeaturize(invalid_graphs)
    valid_mols = list(
        filter(lambda x: isinstance(x, Chem.rdchem.Mol), valid_mols))
    invalid_mols = list(
          filter(lambda x: not isinstance(x, Chem.rdchem.Mol), mols))
      assert len(valid_mols) == len(valid_mols)
      assert len(invalid_mols) == len(invalid_graph)
        filter(lambda x: not isinstance(x, Chem.rdchem.Mol), invalid_mols))
    assert len(valid_mols) == len(valid_graphs)
    assert len(invalid_mols) == len(invalid_graphs)

    # check if original smiles match defeaturized smiles
    defe_smiles = list(map(Chem.MolToSmiles, valid_mols))
    assert defe_smiles == smiles


if __name__ == '__main__':