Commit 347ac6ac authored by alat-rights
Browse files

roll back changes to RobertaFeaturizer

parent 8afa99e4
Loading
Loading
Loading
Loading
+6 −0
Original line number | Diff line number | Diff line
@@ -73,6 +73,12 @@ try:
except ModuleNotFoundError:
  pass

try:
  from transformers import BertTokenizerFast
  from deepchem.feat.bert_tokenizer import BertFeaturizer
except ModuleNotFoundError:
  pass

try:
  from transformers import RobertaTokenizerFast
  from deepchem.feat.roberta_tokenizer import RobertaFeaturizer
+5 −7
Original line number | Diff line number | Diff line
from deepchem.feat import Featurizer
from typing import Dict, List
from typing import Dict, List, Optional
try:
  from transformers import BertTokenizerFast
except ModuleNotFoundError:
@@ -8,7 +8,7 @@ except ModuleNotFoundError:
  pass


class BertFeaturizer(BertFeaturizerFast, Featurizer):
class BertFeaturizer(Featurizer):
  """Bert Featurizer.

  Bert Featurizer.
@@ -24,15 +24,14 @@ class BertFeaturizer(BertFeaturizerFast, Featurizer):

  Notes
  -----
  This class inherits from BertTokenizerFast.
  This class may contain code and/or documentation taken from the
  RobertaFeaturizer pull request (#2581), which have been moved here due to
  code restructuring.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    return
  def __init__(self, tokenizer: BertTokenizerFast = BertTokenizerFast()):
    self.tokenizer = tokenizer
    return self.tokenizer

  def _featurize(self, datapoint: str, **kwargs) -> List[List[int]]:
    """Calculate encoding using HuggingFace's RobertaTokenizerFast
@@ -49,7 +48,6 @@ class BertFeaturizer(BertFeaturizerFast, Featurizer):
    """

    # the encoding is natively a dictionary with keys 'input_ids', 'token_type_ids', and 'attention_mask'
    # encoding = list(self(smiles_string, **kwargs).values())
    encoding = list(self(datapoint, **kwargs).values())
    return encoding

+1 −4
Original line number | Diff line number | Diff line
@@ -19,6 +19,7 @@ def test_call():
    assert len(embedding['input_ids']) == 2 and len(
        emb['attention_mask']) == 2


def test_featurize():
  """Test that BertFeaturizer.featurize() correctly featurizes all sequences,
  correctly outputs input_ids and attention_mask.
@@ -32,10 +33,6 @@ def test_featurize():
      "Rostlab/prot_bert", do_lower_case=False)
  feats = featurizer.featurize(sequence)
  long_feat = featurizer.featurize(sequence_long)
  """
  for f in feats:
    print(f)
  """
  assert (len(feats) == 2)
  assert (all([len(f) == 3 for f in feats]))
  assert (len(long_feat) == 1)
+14 −0
Original line number | Diff line number | Diff line
@@ -396,6 +396,20 @@ References:
Other Featurizers
-----------------

BertFeaturizer
^^^^^^^^^^^^^^

.. autoclass:: deepchem.feat.BertFeaturizer
    :members:
    :inherited-members:

RobertaFeaturizer
^^^^^^^^^^^^^^^^^

.. autoclass:: deepchem.feat.RobertaFeaturizer
    :members:
    :inherited-members:

BindingPocketFeaturizer
^^^^^^^^^^^^^^^^^^^^^^^