Commit 45d1b493 authored by alat-rights's avatar alat-rights
Browse files

fix __call__()

parent 347ac6ac
Loading
Loading
Loading
Loading
+10 −7
Original line number Diff line number Diff line
@@ -28,13 +28,16 @@ class BertFeaturizer(Featurizer):
  RobertaFeaturizer pull request (#2581), which have been moved here due to
  code restructuring.
  """

  def __init__(self, tokenizer: BertTokenizerFast):
    """Wrap a constructed HuggingFace tokenizer.

    Parameters
    ----------
    tokenizer: BertTokenizerFast
      An already-constructed tokenizer instance; all featurization calls
      delegate to it.

    Raises
    ------
    TypeError
      If `tokenizer` is not a `BertTokenizerFast` instance.
    """
    if not isinstance(tokenizer, BertTokenizerFast):
      raise TypeError(f"""`tokenizer` must be a constructed `BertTokenizerFast`
                       object, not {type(tokenizer)}""")
    # NOTE: __init__ must return None; the stray `return self.tokenizer`
    # from the pre-commit version would raise TypeError at instantiation.
    self.tokenizer = tokenizer

  def _featurize(self, datapoint: str, **kwargs) -> List[List[int]]:
    """Calculate encoding using HuggingFace's RobertaTokenizerFast
    """
    Calculate encoding using HuggingFace's RobertaTokenizerFast

    Parameters
    ----------
@@ -48,8 +51,8 @@ class BertFeaturizer(Featurizer):
    """

    # the encoding is natively a dictionary with keys 'input_ids', 'token_type_ids', and 'attention_mask'
    encoding = list(self(datapoint, **kwargs).values())
    encoding = list(self.tokenizer(datapoint, **kwargs).values())
    return encoding

  def __call__(self, *args, **kwargs) -> Dict[str, List[int]]:
    return super().__call__(*args, **kwargs)
    return self.tokenizer.__call__(*args, **kwargs)
+7 −5
Original line number Diff line number Diff line
@@ -5,12 +5,14 @@ import pytest
def test_call():
  """Test BertFeaturizer.__call__(), which is based on BertTokenizerFast."""
  from deepchem.feat.bert_tokenizer import BertFeaturizer
  from transformers import BertTokenizerFast
  sequence = [
      '[CLS] D L I P T S S K L V [SEP]', '[CLS] V K K A F F A L V T [SEP]'
  ]
  sequence_long = ['[CLS] D L I P T S S K L V V K K A F F A L V T [SEP]']
  featurizer = BertFeaturizer.from_pretrained(
  tokenizer = BertTokenizerFast.from_pretrained(
      "Rostlab/prot_bert", do_lower_case=False)
  featurizer = BertFeaturizer(tokenizer)
  embedding = featurizer(sequence, return_tensors='pt')
  embedding_long = featurizer(
      sequence_long * 2, return_tensors='pt')
@@ -19,18 +21,18 @@ def test_call():
    assert len(embedding['input_ids']) == 2 and len(
        emb['attention_mask']) == 2


def test_featurize():
  """Test that BertFeaturizer.featurize() correctly featurizes all sequences,
  correctly outputs input_ids and attention_mask.
  """
  correctly outputs input_ids and attention_mask."""
  from deepchem.feat.bert_tokenizer import BertFeaturizer
  from transformers import BertTokenizerFast
  sequence = [
      '[CLS] D L I P T S S K L V [SEP]', '[CLS] V K K A F F A L V T [SEP]'
  ]
  sequence_long = ['[CLS] D L I P T S S K L V V K K A F F A L V T [SEP]']
  featurizer = BertFeaturizer.from_pretrained(
  tokenizer = BertTokenizerFast.from_pretrained(
      "Rostlab/prot_bert", do_lower_case=False)
  featurizer = BertFeaturizer(tokenizer)
  feats = featurizer.featurize(sequence)
  long_feat = featurizer.featurize(sequence_long)
  assert (len(feats) == 2)