Unverified Commit 7c42ebb7 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2642 from alat-rights/bert_to_merge

BertFeaturizer (Replaces #2608)
parents 79ceb479 39367e6b
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -979,7 +979,7 @@ class FASTALoader(DataLoader):
        # (X, y, w, ids)
        yield X, None, None, ids

    def _read_file(input_file: str, auto_add_annotations: bool = False):
    def _read_file(input_file: str):
      """
      Convert the FASTA file to a numpy array of FASTA-format strings.
      """
@@ -1012,7 +1012,7 @@ class FASTALoader(DataLoader):
          # TODO log attempts to add empty sequences every shard
          return np.array([])
        # Annotate start/stop of sequence
        if auto_add_annotations:
        if self.auto_add_annotations:
          sequence = np.insert(sequence, 0, "[CLS]")
          sequence = np.append(sequence, "[SEP]")
        new_sequence = ''.join(sequence)
+6 −0
Original line number Diff line number Diff line
@@ -73,6 +73,12 @@ try:
except ModuleNotFoundError:
  pass

try:
  from transformers import BertTokenizerFast
  from deepchem.feat.bert_tokenizer import BertFeaturizer
except ModuleNotFoundError:
  pass

try:
  from transformers import RobertaTokenizerFast
  from deepchem.feat.roberta_tokenizer import RobertaFeaturizer
+56 −0
Original line number Diff line number Diff line
from deepchem.feat import Featurizer
from typing import List
try:
  from transformers import BertTokenizerFast
except ModuleNotFoundError:
  # This module is useless without transformers, so surface a clear,
  # actionable error instead of failing later with a NameError.
  # (Removed an unreachable `pass` that followed the raise.)
  raise ImportError(
      'Transformers must be installed for BertFeaturizer to be used!')


class BertFeaturizer(Featurizer):
  """Bert Featurizer.

  The Bert Featurizer is a wrapper class for HuggingFace's BertTokenizerFast.
  This class intends to allow users to use the BertTokenizer API while
  remaining inside the DeepChem ecosystem.

  Examples
  --------
  >>> from deepchem.feat import BertFeaturizer
  >>> from transformers import BertTokenizerFast
  >>> tokenizer = BertTokenizerFast.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
  >>> featurizer = BertFeaturizer(tokenizer)
  >>> feats = featurizer.featurize('D L I P [MASK] L V T')

  Notes
  -----
  Examples are based on RostLab's ProtBert documentation.
  """

  def __init__(self, tokenizer: BertTokenizerFast):
    """Initialize the featurizer with a constructed tokenizer.

    Parameters
    ----------
    tokenizer: BertTokenizerFast
      An already-constructed HuggingFace tokenizer to wrap.

    Raises
    ------
    TypeError
      If `tokenizer` is not a `BertTokenizerFast` instance.
    """
    if not isinstance(tokenizer, BertTokenizerFast):
      raise TypeError(f"""`tokenizer` must be a constructed `BertTokenizerFast`
                       object, not {type(tokenizer)}""")
    self.tokenizer = tokenizer

  def _featurize(self, datapoint: str, **kwargs) -> List[List[int]]:
    """
    Calculate encoding using HuggingFace's BertTokenizerFast.

    Parameters
    ----------
    datapoint: str
      Arbitrary string sequence to be tokenized.
    **kwargs
      Extra keyword arguments forwarded to the tokenizer call
      (e.g. `truncation`, `padding`).

    Returns
    -------
    encoding: List
      List containing three lists: the `input_ids`, `token_type_ids`,
      and `attention_mask`.
    """
    # The tokenizer natively returns a dict-like encoding with keys
    # 'input_ids', 'token_type_ids', and 'attention_mask'; we expose
    # only the values, in that order, as a plain list.
    encoding = list(self.tokenizer(datapoint, **kwargs).values())
    return encoding
+1 −1
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ class RobertaFeaturizer(RobertaTokenizerFast, Featurizer):

  Examples
  --------
  >>> from deepchem.feat.molecule_featurizers import RobertaFeaturizer
  >>> from deepchem.feat import RobertaFeaturizer
  >>> smiles = ["Cn1c(=O)c2c(ncn2C)n(C)c1=O", "CC(=O)N1CN(C(C)=O)C(O)C1O"]
  >>> featurizer = RobertaFeaturizer.from_pretrained("seyonec/SMILES_tokenized_PubChem_shard00_160k")
  >>> featurizer.featurize(smiles, add_special_tokens=True, truncation=True)
+354 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading