Commit 45d1b493 authored by alat-rights

fix __call__()

parent 347ac6ac
@@ -28,13 +28,16 @@ class BertFeaturizer(Featurizer):
   RobertaFeaturizer pull request (#2581), which have been moved here due to
   code restructuring.
   """

-  def __init__(self, tokenizer: BertTokenizerFast = BertTokenizerFast()):
-    self.tokenizer = tokenizer
-    return self.tokenizer
+  def __init__(self, tokenizer: BertTokenizerFast):
+    if not isinstance(tokenizer, BertTokenizerFast):
+      raise TypeError(f"""`tokenizer` must be a constructed `BertTokenizerFast`
+                      object, not {type(tokenizer)}""")
+    else:
+      self.tokenizer = tokenizer

   def _featurize(self, datapoint: str, **kwargs) -> List[List[int]]:
-    """Calculate encoding using HuggingFace's RobertaTokenizerFast
+    """
+    Calculate encoding using HuggingFace's RobertaTokenizerFast

     Parameters
     ----------
@@ -48,8 +51,8 @@ class BertFeaturizer(Featurizer):
     """
     # the encoding is natively a dictionary with keys 'input_ids', 'token_type_ids', and 'attention_mask'
-    encoding = list(self(datapoint, **kwargs).values())
+    encoding = list(self.tokenizer(datapoint, **kwargs).values())
     return encoding

   def __call__(self, *args, **kwargs) -> Dict[str, List[int]]:
-    return super().__call__(*args, **kwargs)
+    return self.tokenizer.__call__(*args, **kwargs)
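
For reference, a minimal usage sketch of the fixed class (an illustration, not part of the commit; it assumes the `transformers` package and the same `Rostlab/prot_bert` checkpoint the tests below use). With this change, `__call__` forwards to the wrapped `BertTokenizerFast` instead of the base `Featurizer`:

    from transformers import BertTokenizerFast
    from deepchem.feat.bert_tokenizer import BertFeaturizer

    # Construct the tokenizer first, then hand it to the featurizer;
    # __init__ now rejects anything that is not a BertTokenizerFast.
    tokenizer = BertTokenizerFast.from_pretrained(
        "Rostlab/prot_bert", do_lower_case=False)
    featurizer = BertFeaturizer(tokenizer)

    # __call__ delegates to BertTokenizerFast.__call__, so it accepts the
    # tokenizer's keyword arguments and returns a mapping with
    # 'input_ids', 'token_type_ids', and 'attention_mask' keys.
    encoding = featurizer('[CLS] D L I P T S S K L V [SEP]')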
@@ -5,12 +5,14 @@ import pytest
 def test_call():
   """Test BertFeaturizer.__call__(), which is based on BertTokenizerFast."""
   from deepchem.feat.bert_tokenizer import BertFeaturizer
+  from transformers import BertTokenizerFast
   sequence = [
       '[CLS] D L I P T S S K L V [SEP]', '[CLS] V K K A F F A L V T [SEP]'
   ]
   sequence_long = ['[CLS] D L I P T S S K L V V K K A F F A L V T [SEP]']
-  featurizer = BertFeaturizer.from_pretrained(
+  tokenizer = BertTokenizerFast.from_pretrained(
       "Rostlab/prot_bert", do_lower_case=False)
+  featurizer = BertFeaturizer(tokenizer)
   embedding = featurizer(sequence, return_tensors='pt')
   embedding_long = featurizer(
       sequence_long * 2, return_tensors='pt')
@@ -19,18 +21,18 @@ def test_call():
   assert len(embedding['input_ids']) == 2 and len(
       embedding['attention_mask']) == 2


 def test_featurize():
   """Test that BertFeaturizer.featurize() correctly featurizes all sequences,
-  correctly outputs input_ids and attention_mask.
-  """
+  correctly outputs input_ids and attention_mask."""
   from deepchem.feat.bert_tokenizer import BertFeaturizer
+  from transformers import BertTokenizerFast
   sequence = [
       '[CLS] D L I P T S S K L V [SEP]', '[CLS] V K K A F F A L V T [SEP]'
   ]
   sequence_long = ['[CLS] D L I P T S S K L V V K K A F F A L V T [SEP]']
-  featurizer = BertFeaturizer.from_pretrained(
+  tokenizer = BertTokenizerFast.from_pretrained(
       "Rostlab/prot_bert", do_lower_case=False)
+  featurizer = BertFeaturizer(tokenizer)
   feats = featurizer.featurize(sequence)
   long_feat = featurizer.featurize(sequence_long)
   assert (len(feats) == 2)
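One behavior worth noting (a sketch under the same assumptions as above, not a test from this commit): the constructor now requires an already-constructed `BertTokenizerFast`, so passing a bare checkpoint name string raises a `TypeError`.

    import pytest
    from deepchem.feat.bert_tokenizer import BertFeaturizer

    # A plain string is not a BertTokenizerFast, so __init__ raises TypeError.
    with pytest.raises(TypeError):
      BertFeaturizer("Rostlab/prot_bert")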