Commit ae9b5670 authored by Bharath Ramsundar
Browse files

Changes

parent e8d4765b
Loading
Loading
Loading
Loading
+18 −14
Original line number Diff line number Diff line
@@ -44,9 +44,11 @@ class Featurizer(object):
    for i, point in enumerate(datapoints):
      if i % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % i)
      if point is not None:
      try:
        features.append(self._featurize(point))
      else:
      except:
        logger.warning(
            "Failed to featurize datapoint %d. Appending empty array")
        features.append(np.array([]))

    features = np.asarray(features)
@@ -139,12 +141,12 @@ class MolecularFeaturizer(Featurizer):
  In general, subclasses of this class will require RDKit to be installed.
  """

  def featurize(self, mols, log_every_n=1000):
  def featurize(self, molecules, log_every_n=1000):
    """Calculate features for molecules.

    Parameters
    ----------
    mols : RDKit Mol / SMILES string /iterable
    molecules : RDKit Mol / SMILES string / iterable
        RDKit Mol, or SMILES string or iterable sequence of RDKit mols/SMILES
        strings.

@@ -159,22 +161,24 @@ class MolecularFeaturizer(Featurizer):
    except ModuleNotFoundError:
      raise ValueError("This class requires RDKit to be installed.")
    # Special case handling of single molecule
    if isinstance(mols, str) or isinstance(mols, Mol):
      mols = [mols]
    if isinstance(molecules, str) or isinstance(molecules, Mol):
      molecules = [molecules]
    else:
      # Convert iterables to list
      mols = list(mols)
      molecules = list(molecules)
    features = []
    for i, mol in enumerate(mols):
    for i, mol in enumerate(molecules):
      if i % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % i)
      if mol is not None:
      try:
        # Process only case of SMILES strings.
        if isinstance(mol, str):
          # mol must be a SMILES string so parse
          mol = Chem.MolFromSmiles(mol)
        features.append(self._featurize(mol))
      else:
      except:
        logger.warning(
            "Failed to featurize datapoint %d. Appending empty array")
        features.append(np.array([]))

    features = np.asarray(features)
@@ -191,16 +195,16 @@ class MolecularFeaturizer(Featurizer):
    """
    raise NotImplementedError('Featurizer is not defined.')

  def __call__(self, mols):
  def __call__(self, molecules):
    """
    Calculate features for molecules.

    Parameters
    ----------
    mols : iterable
        RDKit Mol objects.
    molecules : iterable
        An iterable yielding RDKit Mol objects or SMILES strings.
    """
    return self.featurize(mols)
    return self.featurize(molecules)


class UserDefinedFeaturizer(Featurizer):
+1 −1
Original line number Diff line number Diff line
@@ -115,7 +115,7 @@ def safe_index(l, e):


class GraphConvConstants(object):
  """Allowed Atom Types."""
  """This class defines a collection of constants which are useful for graph convolutions on molecules."""
  possible_atom_list = [
      'C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Mg', 'Na', 'Br', 'Fe', 'Ca', 'Cu',
      'Mc', 'Pd', 'Pb', 'K', 'I', 'Al', 'Ni', 'Mn'