Commit 4fb2ac80 authored by alat-rights's avatar alat-rights
Browse files

downstreaming

parent 500467de
Loading
Loading
Loading
Loading
+57 −0
Original line number Diff line number Diff line
from typing import Dict, List

from deepchem.feat import Featurizer

try:
  from transformers import BertTokenizerFast
except ModuleNotFoundError as e:
  # transformers is an optional dependency; surface a clear, actionable error.
  raise ImportError(
      'Transformers must be installed for BertFeaturizer to be used!') from e


class BertFeaturizer(BertTokenizerFast, Featurizer):
  """Bert Featurizer.

  The Bert Featurizer is a wrapper class for HuggingFace's BertTokenizerFast.
  This class intends to allow users to use the BertTokenizer API while
  remaining inside the DeepChem ecosystem.

  Examples
  --------
  >>> from deepchem.feat import BertFeaturizer
  >>> featurizer = BertFeaturizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
  >>> inputs = featurizer.featurize('D L I P [MASK] L V T', return_tensors="pt")

  Notes
  -----
  This class inherits from BertTokenizerFast.
  This class may contain code and/or documentation taken from the
  RobertaFeaturizer pull request (#2581), which have been moved here due to
  code restructuring.
  """

  def __init__(self, **kwargs):
    """Forward all keyword arguments to BertTokenizerFast."""
    super().__init__(**kwargs)

  def _featurize(self, datapoint: str, **kwargs) -> List[List[int]]:
    """Calculate encoding using HuggingFace's BertTokenizerFast.

    Parameters
    ----------
    datapoint: str
      Arbitrary string sequence to be tokenized.

    Returns
    -------
    encoding: List
      List containing three lists: the `input_ids`, `token_type_ids`, and
      `attention_mask`.
    """
    # The tokenizer natively returns a dict keyed by 'input_ids',
    # 'token_type_ids' and 'attention_mask'; expose only the values so the
    # Featurizer machinery can stack them into an array.
    encoding = list(self(datapoint, **kwargs).values())
    return encoding

  def __call__(self, *args, **kwargs) -> Dict[str, List[int]]:
    """Delegate tokenization directly to BertTokenizerFast.__call__."""
    return super().__call__(*args, **kwargs)
+37 −0
Original line number Diff line number Diff line
import unittest
from deepchem.feat import BertFeaturizer
from transformers import BertTokenizerFast


class TestBertFeaturizer(unittest.TestCase):
  """Tests for BertFeaturizer, based on tests for RobertaFeaturizer and
  usage examples in Rostlab prot_bert documentation hosted by HuggingFace."""

  def setUp(self):
    # Two short pre-tokenized protein sequences plus one long sequence;
    # the tokenizer weights are fetched from the HuggingFace hub.
    self.sequence = [
        '[CLS] D L I P T S S K L V [SEP]', '[CLS] V K K A F F A L V T [SEP]'
    ]
    self.sequence_long = ['[CLS] D L I P T S S K L V V K K A F F A L V T [SEP]']
    self.featurizer = BertFeaturizer.from_pretrained(
        "Rostlab/prot_bert", do_lower_case=False)

  def test_call(self):
    """Test BertFeaturizer.__call__(), which is based on BertTokenizerFast."""
    embedding = self.featurizer(self.sequence, return_tensors='pt')
    embedding_long = self.featurizer(self.sequence_long * 2, return_tensors='pt')
    for emb in [embedding, embedding_long]:
      assert 'input_ids' in emb.keys() and 'attention_mask' in emb.keys()
      # Both batches contain exactly two sequences. (Checks `emb`, not the
      # first embedding, so the long batch is actually validated.)
      assert len(emb['input_ids']) == 2 and len(emb['attention_mask']) == 2

  def test_featurize(self):
    """Test that BertFeaturizer.featurize() correctly featurizes all sequences,
    correctly outputs input_ids and attention_mask.
    """
    feats = self.featurizer.featurize(self.sequence)
    long_feat = self.featurizer.featurize(self.sequence_long)

    assert len(feats) == 2
    # Each featurization yields input_ids, token_type_ids, attention_mask.
    assert all(len(f) == 3 for f in feats)
    assert len(long_feat) == 1
    # Fixed mis-parenthesization: the original asserted
    # len(long_feat[0] == 3), which is vacuous.
    assert len(long_feat[0]) == 3