Commit 06744bcc authored by seyonechithrananda's avatar seyonechithrananda
Browse files

mpass complexes as datapoints

parent e844cdce
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -34,7 +34,6 @@ from deepchem.feat.molecule_featurizers import RawFeaturizer
from deepchem.feat.molecule_featurizers import RDKitDescriptors
from deepchem.feat.molecule_featurizers import SmilesToImage
from deepchem.feat.molecule_featurizers import SmilesToSeq, create_char_to_idx
from deepchem.feat.molecule_featurizers import RobertaFeaturizer
from deepchem.feat.molecule_featurizers import MATFeaturizer

# complex featurizers
@@ -74,5 +73,7 @@ try:
except ModuleNotFoundError:
  pass

from deepchem.feat.roberta_tokenizer import RobertaFeaturizer

# support classes
from deepchem.feat.molecule_featurizers import GraphMatrix
+8 −5
Original line number Diff line number Diff line
@@ -156,14 +156,14 @@ class ComplexFeaturizer(Featurizer):
  """

  def featurize(self,
                complexes: Iterable[Tuple[str, str]],
                datapoints: Iterable[Tuple[str, str]] = None,
                log_every_n: int = 100, **kwargs) -> np.ndarray:
    """
    Calculate features for mol/protein complexes.

    Parameters
    ----------
    complexes: Iterable[Tuple[str, str]]
    datapoints: Iterable[Tuple[str, str]]
      List of filenames (PDB, SDF, etc.) for ligand molecules and proteins.
      Each element should be a tuple of the form (ligand_filename,
      protein_filename).
@@ -174,10 +174,13 @@ class ComplexFeaturizer(Featurizer):
      Array of features
    """

    if not isinstance(complexes, Iterable):
      complexes = [cast(Tuple[str, str], complexes)]
    if 'complexes' in kwargs:
      datapoints = kwargs.get("complexes")
      raise DeprecationWarning('Complexes is being phased out as a parameter, please pass "datapoints" instead.')
    if not isinstance(datapoints, Iterable):
      datapoints = [cast(Tuple[str, str], datapoints)]
    features, failures, successes = [], [], []
    for idx, point in enumerate(complexes):
    for idx, point in enumerate(datapoints):
      if idx % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % idx)
      try: