Commit e93e03b3 authored by alat-rights's avatar alat-rights
Browse files

yapf/flake

parent abd0b4e8
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
from deepchem.feat import Featurizer
from typing import Dict, List, Optional
from typing import Dict, List
try:
  from transformers import BertTokenizerFast
except ModuleNotFoundError:
@@ -28,6 +28,7 @@ class BertFeaturizer(Featurizer):
  RobertaFeaturizer pull request (#2581), which have been moved here due to
  code restructuring.
  """

  def __init__(self, tokenizer: BertTokenizerFast):
    if not isinstance(tokenizer, BertTokenizerFast):
      raise TypeError(f"""`tokenizer` must be a constructed `BertTokenizerFast`
+3 −4
Original line number Diff line number Diff line
@@ -14,12 +14,11 @@ def test_call():
      "Rostlab/prot_bert", do_lower_case=False)
  featurizer = BertFeaturizer(tokenizer)
  embedding = featurizer(sequence, return_tensors='pt')
  embedding_long = featurizer(
      sequence_long * 2, return_tensors='pt')
  embedding_long = featurizer(sequence_long * 2, return_tensors='pt')
  for emb in [embedding, embedding_long]:
    assert 'input_ids' in emb.keys() and 'attention_mask' in emb.keys()
    assert len(embedding['input_ids']) == 2 and len(
        emb['attention_mask']) == 2
    assert len(embedding['input_ids']) == 2 and len(emb['attention_mask']) == 2


def test_featurize():
  """Test that BertFeaturizer.featurize() correctly featurizes all sequences,