Commit 285bc538 authored by alat-rights's avatar alat-rights
Browse files

Fixed stylistic problems suggested by Bharath

parent 22933d68
Loading
Loading
Loading
Loading
+13 −21
Original line number Diff line number Diff line
@@ -21,11 +21,11 @@ ZINC_CHARSET = [
class OneHotFeaturizer(MolecularFeaturizer):
  """Encodes SMILES or any arbitrary string as a one-hot array.

  This featurizer encodes a string or any arbitrary string as a one-hot array.
  This featurizer encodes either a SMILES string or any arbitrary string as a one-hot array.

  Note
  ----
  This class requires RDKit to be installed to work with RDKit molecules.
  This class needs RDKit to be installed in order to work with RDKit molecules.
  """

  def __init__(self, charset: List[str] = ZINC_CHARSET, max_length: int = 100):
@@ -50,40 +50,32 @@ class OneHotFeaturizer(MolecularFeaturizer):

    Parameters
    ----------
    datapoints: A list of either strings or RDKit molecules.
    datapoints: list
      A list of either strings or RDKit molecules.
    log_every_n: int, optional (default 1000)
      How many elements are featurized every time a featurization is logged.
    """
    datapoints = list(datapoints)
    if (len(datapoints) < 1):
      print(
          "No datapoints are present in the parameter Iterable, so we return an empty array."
      )
      return np.array([])

    # Featurize str data
    if (type(datapoints[0]) == str):
      # Calls featurize() in grandparent class, which takes Iterable[Any].
    # Featurize data using featurize() in grandparent class
    return Featurizer.featurize(self, datapoints, log_every_n)
    # Featurize mol data
    else:
      # Calls featurize() in parent class, which takes molecules.
      return MolecularFeaturizer.featurize(self, datapoints, log_every_n)

  def _featurize(self, datapoint: Any):
    # Featurize str data
    if (type(datapoint) == str):
      return self._featurizeString(datapoint)
      return self._featurize_string(datapoint)
    # Featurize mol data
    else:
      return self._featurizeMol(datapoint)
      return self._featurize_mol(datapoint)

  def _featurizeString(self, string: str) -> np.ndarray:
  def _featurize_string(self, string: str) -> np.ndarray:
    """Compute one-hot featurization of string.

    Parameters
    ----------
    str: An arbitrary string to be featurized.
    string: str
      An arbitrary string to be featurized.

    Returns
    -------
@@ -105,7 +97,7 @@ class OneHotFeaturizer(MolecularFeaturizer):
        for val in string
    ])

  def _featurizeMol(self, mol: RDKitMol) -> np.ndarray:
  def _featurize_mol(self, mol: RDKitMol) -> np.ndarray:
    """Compute one-hot featurization of this molecule.

    Parameters
@@ -125,7 +117,7 @@ class OneHotFeaturizer(MolecularFeaturizer):
    except ModuleNotFoundError:
      raise ImportError("This class requires RDKit to be installed.")
    smiles = Chem.MolToSmiles(mol)  # Convert mol to SMILES string.
    return self._featurizeString(smiles)  # Use string featurization.
    return self._featurize_string(smiles)  # Use string featurization.

  def pad_smile(self, smiles: str) -> str:
    """Pad smile string to `self.pad_length`