Unverified commit 72659d83, authored by Daiki Nishikawa and committed by GitHub
Browse files

Merge pull request #2252 from nd-02110114/fix-ci-rdkit

Fix a test that fails on the master branch and add options to the RDKit descriptor featurizer
parents 6ccf705d 979f19c8
Loading
Loading
Loading
Loading
+24 −5
Original line number Diff line number Diff line
@@ -23,15 +23,29 @@ class RDKitDescriptors(MolecularFeaturizer):
  This class requires RDKit to be installed.
  """

  def __init__(self, use_fragment=True, ipc_avg=True):
    """Initialize this featurizer.

    Collects the (name, function) pairs from `rdkit.Chem.Descriptors.descList`
    that this featurizer will evaluate, optionally leaving out the fragment
    descriptors.

    Parameters
    ----------
    use_fragment: bool, optional (default True)
      If True, the return value includes the fragment descriptors, i.e. the
      descriptors whose names start with 'fr_'. If False, they are skipped.
    ipc_avg: bool, optional (default True)
      If True, the IPC descriptor is calculated with the avg=True option.
      Please see this issue: https://github.com/rdkit/rdkit/issues/1527.

    Raises
    ------
    ValueError
      If RDKit is not installed.
    """
    try:
      from rdkit.Chem import Descriptors
    except ModuleNotFoundError:
      raise ValueError("This class requires RDKit to be installed.")

    self.use_fragment = use_fragment
    self.ipc_avg = ipc_avg
    # `descriptors` holds the descriptor names and `descList` the matching
    # (name, function) pairs, both in the order of `Descriptors.descList`.
    self.descriptors = []
    self.descList = []
    for descriptor, function in Descriptors.descList:
      # Truthiness test rather than `is False`, so that falsy non-`bool`
      # flags (e.g. `numpy.False_`) also disable the fragment descriptors.
      if not self.use_fragment and descriptor.startswith('fr_'):
        continue
      self.descriptors.append(descriptor)
      self.descList.append((descriptor, function))

@@ -47,9 +61,14 @@ class RDKitDescriptors(MolecularFeaturizer):
    Returns
    -------
    np.ndarray
      1D array of RDKit descriptors for `mol`. The length is 200.
      1D array of RDKit descriptors for `mol`.
      The length is `len(self.descriptors)`.
    """
    rval = []
    features = []
    for desc_name, function in self.descList:
      rval.append(function(mol))
    return np.asarray(rval)
      if desc_name == 'Ipc' and self.ipc_avg:
        feature = function(mol, avg=True)
      else:
        feature = function(mol)
      features.append(feature)
    return np.asarray(features)
+17 −11
Original line number Diff line number Diff line
@@ -25,10 +25,11 @@ class TestRDKitDescriptors(unittest.TestCase):
    """
    Test simple descriptors.
    """
    descriptors = self.featurizer([self.mol])
    assert descriptors.shape == (1, 200)
    featurizer = RDKitDescriptors()
    descriptors = featurizer([self.mol])
    assert descriptors.shape == (1, len(featurizer.descriptors))
    assert np.allclose(
        descriptors[0, self.featurizer.descriptors.index('ExactMolWt')],
        descriptors[0, featurizer.descriptors.index('ExactMolWt')],
        180,
        atol=0.1)

@@ -36,20 +37,25 @@ class TestRDKitDescriptors(unittest.TestCase):
    """
    Test invocation on raw smiles.
    """
    descriptors = self.featurizer('CC(=O)OC1=CC=CC=C1C(=O)O')
    assert descriptors.shape == (1, 200)
    featurizer = RDKitDescriptors()
    descriptors = featurizer('CC(=O)OC1=CC=CC=C1C(=O)O')
    assert descriptors.shape == (1, len(featurizer.descriptors))
    assert np.allclose(
        descriptors[0, self.featurizer.descriptors.index('ExactMolWt')],
        descriptors[0, featurizer.descriptors.index('ExactMolWt')],
        180,
        atol=0.1)

  def test_rdkit_descriptors_with_use_fragment(self):
    """
    Test that `use_fragment=False` drops the 'fr_' fragment descriptors.
    """
    from rdkit.Chem import Descriptors
    featurizer = RDKitDescriptors(use_fragment=False)
    descriptors = featurizer(self.mol)
    # One row per molecule, one column per retained descriptor.
    assert descriptors.shape == (1, len(featurizer.descriptors))
    # Filtering must shrink the list relative to RDKit's full descriptor
    # list and must leave no fragment descriptor behind.
    all_descriptors = Descriptors.descList
    assert len(featurizer.descriptors) < len(all_descriptors)
    assert not any(name.startswith('fr_') for name in featurizer.descriptors)
    # Non-fragment descriptors are still computed as before.
    assert np.allclose(
        descriptors[0, featurizer.descriptors.index('ExactMolWt')],
        180,
        atol=0.1)