Unverified Commit e613643b authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2458 from alat-rights/onehot

[WIP] Generalize OneHotFeaturizer to Support Arbitrary Strings
parents 361207f8 198233c2
Loading
Loading
Loading
Loading
+89 −27
Original line number Diff line number Diff line
@@ -5,7 +5,8 @@ import numpy as np

from deepchem.utils.typing import RDKitMol
from deepchem.utils.molecule_feature_utils import one_hot_encode
from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.feat.base_classes import Featurizer
from typing import Any, Iterable

logger = logging.getLogger(__name__)

@@ -16,14 +17,19 @@ ZINC_CHARSET = [
]


class OneHotFeaturizer(MolecularFeaturizer):
  """Encodes SMILES as a one-hot array.
class OneHotFeaturizer(Featurizer):
  """Encodes any arbitrary string or molecule as a one-hot array.

  This featurizer encodes SMILES string as a one-hot array.
  This featurizer encodes the characters within any given string as a one-hot
  array. It also works with RDKit molecules: it can convert RDKit molecules to
  SMILES strings and then one-hot encode the characters in said strings.

  Note
  ----
  This class requires RDKit to be installed.
  This class needs RDKit to be installed in order to accept RDKit molecules as
  inputs.

  It does not need RDKit to be installed to work with arbitrary strings.
  """

  def __init__(self, charset: List[str] = ZINC_CHARSET, max_length: int = 100):
@@ -42,57 +48,113 @@ class OneHotFeaturizer(MolecularFeaturizer):
    self.charset = charset
    self.max_length = max_length

  def _featurize(self, mol: RDKitMol) -> np.ndarray:
    """Compute one-hot featurization of this molecule.
  def featurize(self, datapoints: Iterable[Any],
                log_every_n: int = 1000) -> np.ndarray:
    """Featurize a collection of strings or RDKit molecules.

    Parameters
    ----------
    datapoints: Iterable[Any]
      The strings or RDKit molecules to featurize. The iterable is
      materialized into a list before anything else happens.
    log_every_n: int, optional (default 1000)
      Forwarded unchanged to `Featurizer.featurize`; how many elements are
      featurized every time a featurization is logged.

    Returns
    -------
    np.ndarray
      An empty array when `datapoints` holds no elements, otherwise whatever
      `Featurizer.featurize` returns for the materialized list.
    """
    items = list(datapoints)
    # Nothing to featurize: short-circuit before delegating.
    if not items:
      return np.array([])
    # Call the base-class implementation explicitly, passing self by hand.
    return Featurizer.featurize(self, items, log_every_n)

  def _featurize(self, datapoint: Any):
    # Featurize str data
    if (type(datapoint) == str):
      return self._featurize_string(datapoint)
    # Featurize mol data
    else:
      return self._featurize_mol(datapoint)

  def _featurize_string(self, string: str) -> np.ndarray:
    """Compute one-hot featurization of string.

    Parameters
    ----------
    string: str
      An arbitrary string to be featurized.

    Returns
    -------
    np.ndarray
      A one-hot array encoded from the arbitrary input string, with one row
      per character of the string after padding to `max_length`.
      The shape is `(max_length, len(charset) + 1)`.
      The index of unknown character is `len(charset)`.
      An empty array is returned when the string is longer than `max_length`.
    """
    # validation
    if len(string) > self.max_length:
      # Pass the offending string as a lazy logging argument; the message
      # previously carried a placeholder that was never filled in.
      logger.info(
          "The length of %s is longer than `max_length`. "
          "So we return an empty array.", string)
      return np.array([])

    padded = self.pad_string(string)  # Padding
    return np.array([
        one_hot_encode(val, self.charset, include_unknown_set=True)
        for val in padded
    ])

  def _featurize_mol(self, mol: RDKitMol) -> np.ndarray:
    """Compute one-hot featurization of a molecule via its SMILES string.

    Parameters
    ----------
    mol: rdkit.Chem.rdchem.Mol
      RDKit Mol object

    Returns
    -------
    np.ndarray
      The result of `_featurize_string` applied to the SMILES string that
      `Chem.MolToSmiles` produces for `mol`.

    Raises
    ------
    ImportError
      If RDKit cannot be imported.
    """
    # RDKit is imported lazily so that string-only use does not require it.
    try:
      from rdkit import Chem
    except ModuleNotFoundError:
      raise ImportError("This class requires RDKit to be installed.")
    # Convert the molecule to SMILES and reuse the string featurization.
    return self._featurize_string(Chem.MolToSmiles(mol))

  def pad_smile(self, smiles: str) -> str:
    """Pad SMILES string to `self.max_length`

    Parameters
    ----------
    smiles: str
      The smiles string to be padded.
      The SMILES string to be padded.

    Returns
    -------
    str
      SMILES string space-padded to self.max_length
    """
    return smiles.ljust(self.max_length)
    return self.pad_string(smiles)

  def pad_string(self, string: str) -> str:
    """Right-pad a string with spaces up to `self.max_length`.

    Parameters
    ----------
    string: str
      The string to be padded.

    Returns
    -------
    str
      The input left-justified in a field of width `self.max_length`, filled
      with spaces. A string already at least that long comes back unpadded.
    """
    target_width = self.max_length
    return string.ljust(target_width, " ")

  def untransform(self, one_hot_vectors: np.ndarray) -> str:
    """Convert from one hot representation back to SMILES
    """Convert from one hot representation back to original string

    Parameters
    ----------
@@ -102,13 +164,13 @@ class OneHotFeaturizer(MolecularFeaturizer):
    Returns
    -------
    str
      SMILES string for an one hot encoded array.
      Original string for an one hot encoded array.
    """
    smiles = ""
    string = ""
    for one_hot in one_hot_vectors:
      try:
        idx = np.argmax(one_hot)
        smiles += self.charset[idx]
        string += self.charset[idx]
      except IndexError:
        smiles += ""
    return smiles
        string += ""
    return string
+55 −7
Original line number Diff line number Diff line
@@ -11,9 +11,24 @@ class TestOneHotFeaturizert(unittest.TestCase):
  Test OneHotFeaturizer.
  """

  def test_onehot_featurizer(self):
  def test_onehot_featurizer_arbitrary(self):
    """
    Test simple one hot encoding for arbitrary string.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    text = "abcdefghijklmnopqrstuvwxyzwebhasw"
    # One column per charset entry plus one for unknown characters.
    width = len(alphabet) + 1
    default_max_length = 100
    featurizer = OneHotFeaturizer(alphabet)
    # Calling the featurizer object presumably forwards to featurize().
    encoded = featurizer([text])
    assert encoded.shape == (1, default_max_length, width)
    # untransform: decoding the array must give back the input string
    decoded = featurizer.untransform(encoded[0])
    assert text == decoded

  def test_onehot_featurizer_SMILES(self):
    """
    Test simple one hot encoding for SMILES strings.
    """
    from rdkit import Chem
    length = len(ZINC_CHARSET) + 1
@@ -21,13 +36,27 @@ class TestOneHotFeaturizert(unittest.TestCase):
    mol = Chem.MolFromSmiles(smiles)
    featurizer = OneHotFeaturizer()
    feature = featurizer([mol])
    assert feature.shape == (1, 100, length)

    defaultMaxLength = 100
    assert feature.shape == (1, defaultMaxLength, length)
    # untransform
    undo_smiles = featurizer.untransform(feature[0])
    assert smiles == undo_smiles

  def test_onehot_featurizer_with_max_length(self):
  def test_onehot_featurizer_arbitrary_with_max_length(self):
    """
    Test one hot encoding of an arbitrary string with an explicit max_length.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    text = "abcdefghijklmnopqrstuvwxyzvewqmc"
    # One column per charset entry plus one for unknown characters.
    width = len(alphabet) + 1
    featurizer = OneHotFeaturizer(alphabet, max_length=120)
    encoded = featurizer([text])
    assert encoded.shape == (1, 120, width)
    # untransform: decoding the array must give back the input string
    decoded = featurizer.untransform(encoded[0])
    assert text == decoded

  def test_onehot_featurizer_SMILES_with_max_length(self):
    """
    Test one hot encoding with max_length.
    """
@@ -38,12 +67,11 @@ class TestOneHotFeaturizert(unittest.TestCase):
    featurizer = OneHotFeaturizer(max_length=120)
    feature = featurizer([mol])
    assert feature.shape == (1, 120, length)

    # untransform
    undo_smiles = featurizer.untransform(feature[0])
    assert smiles == undo_smiles

  def test_correct_transformation(self):
  def test_correct_transformation_SMILES(self):
    """
    Test correct one hot encoding.
    """
@@ -62,3 +90,23 @@ class TestOneHotFeaturizert(unittest.TestCase):
    # untransform
    undo_smiles = featurizer.untransform(feature[0])
    assert smiles == undo_smiles

  def test_correct_transformation_arbitrary(self):
    """
    Test correct one hot encoding for an arbitrary (non-SMILES) string.
    """
    charset = "1234567890"
    string = "12345"
    featurizer = OneHotFeaturizer(charset=charset, max_length=100)
    feature = featurizer([string])
    # Character i of "12345" sits at index i of the charset, so row i must be
    # hot at column i. The extra final column is the unknown-character slot.
    for position in range(len(string)):
      expected = np.zeros(len(charset) + 1)
      expected[position] = 1
      assert np.allclose(feature[0][position], expected)
    # untransform: replaces a placeholder assertion on a string literal, which
    # could never fail, with the same round-trip check the other tests use.
    undo_string = featurizer.untransform(feature[0])
    assert string == undo_string