Unverified Commit 7c42ebb7 authored by Bharath Ramsundar, committed by GitHub

Merge pull request #2642 from alat-rights/bert_to_merge

BertFeaturizer (Replaces #2608)
parents 79ceb479 39367e6b
@@ -979,7 +979,7 @@ class FASTALoader(DataLoader):
      # (X, y, w, ids)
      yield X, None, None, ids

-    def _read_file(input_file: str, auto_add_annotations: bool = False):
+    def _read_file(input_file: str):
      """
      Convert the FASTA file to a numpy array of FASTA-format strings.
      """
@@ -1012,7 +1012,7 @@ class FASTALoader(DataLoader):
        # TODO log attempts to add empty sequences every shard
        return np.array([])

      # Annotate start/stop of sequence
-      if auto_add_annotations:
+      if self.auto_add_annotations:
        sequence = np.insert(sequence, 0, "[CLS]")
        sequence = np.append(sequence, "[SEP]")
      new_sequence = ''.join(sequence)
......
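As a side note, the `auto_add_annotations` branch above wraps each raw residue sequence in BERT-style markers before it reaches the featurizer. A minimal standalone sketch of that step (the `annotate` helper is hypothetical; the real logic lives inside `FASTALoader` and reads the flag from `self.auto_add_annotations`):

```python
import numpy as np

# Hypothetical standalone version of the annotation step; object dtype keeps
# np.insert/np.append from truncating multi-character tokens like "[CLS]".
def annotate(sequence: np.ndarray) -> str:
  sequence = np.insert(sequence, 0, "[CLS]")  # start-of-sequence marker
  sequence = np.append(sequence, "[SEP]")     # end-of-sequence marker
  return ''.join(sequence)

print(annotate(np.array(['D', 'L', 'I', 'P'], dtype=object)))  # [CLS]DLIP[SEP]
```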
@@ -73,6 +73,12 @@ try:
except ModuleNotFoundError:
  pass

+try:
+  from transformers import BertTokenizerFast
+  from deepchem.feat.bert_tokenizer import BertFeaturizer
+except ModuleNotFoundError:
+  pass

try:
  from transformers import RobertaTokenizerFast
  from deepchem.feat.roberta_tokenizer import RobertaFeaturizer
......
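The guarded import above follows the same soft-dependency pattern as the surrounding blocks: `BertFeaturizer` is exported from `deepchem.feat` only when `transformers` is importable. A minimal sketch of how downstream code can probe for it rather than fail at import time (the `HAS_BERT_FEATURIZER` flag is illustrative):

```python
# If transformers is missing, deepchem.feat never defines BertFeaturizer,
# so this import raises ImportError and the feature is simply unavailable.
try:
  from deepchem.feat import BertFeaturizer
  HAS_BERT_FEATURIZER = True
except ImportError:
  HAS_BERT_FEATURIZER = False
```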
from deepchem.feat import Featurizer
from typing import List

try:
  from transformers import BertTokenizerFast
except ModuleNotFoundError:
  raise ImportError(
      'Transformers must be installed for BertFeaturizer to be used!')


class BertFeaturizer(Featurizer):
  """Bert Featurizer.

  The Bert Featurizer is a wrapper class for HuggingFace's BertTokenizerFast.
  This class intends to allow users to use the BertTokenizer API while
  remaining inside the DeepChem ecosystem.

  Examples
  --------
  >>> from deepchem.feat import BertFeaturizer
  >>> from transformers import BertTokenizerFast
  >>> tokenizer = BertTokenizerFast.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
  >>> featurizer = BertFeaturizer(tokenizer)
  >>> feats = featurizer.featurize('D L I P [MASK] L V T')

  Notes
  -----
  Examples are based on RostLab's ProtBert documentation.
  """

  def __init__(self, tokenizer: BertTokenizerFast):
    if not isinstance(tokenizer, BertTokenizerFast):
      raise TypeError(
          f"`tokenizer` must be a constructed `BertTokenizerFast` object,"
          f" not {type(tokenizer)}")
    self.tokenizer = tokenizer

  def _featurize(self, datapoint: str, **kwargs) -> List[List[int]]:
    """Calculate the encoding using HuggingFace's BertTokenizerFast.

    Parameters
    ----------
    datapoint: str
      Arbitrary string sequence to be tokenized.

    Returns
    -------
    encoding: List
      List containing three lists: `input_ids`, `token_type_ids`, and
      `attention_mask`.
    """
    # The encoding is natively a dictionary with keys 'input_ids',
    # 'token_type_ids', and 'attention_mask'; keep only the values.
    encoding = list(self.tokenizer(datapoint, **kwargs).values())
    return encoding
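Since `_featurize` forwards `**kwargs` to the wrapped tokenizer, batching options such as padding and truncation can be passed straight through `featurize`, just as in the RobertaFeaturizer example below. A minimal usage sketch (assuming `transformers` is installed and the Rostlab/prot_bert checkpoint can be downloaded):

```python
from transformers import BertTokenizerFast
from deepchem.feat.bert_tokenizer import BertFeaturizer

tokenizer = BertTokenizerFast.from_pretrained(
    "Rostlab/prot_bert", do_lower_case=False)
featurizer = BertFeaturizer(tokenizer)

# Tokenizer kwargs pass through featurize() to BertTokenizerFast.
feats = featurizer.featurize(['D L I P T S', 'V K K A F F'],
                             padding=True, truncation=True)
# Each datapoint maps to [input_ids, token_type_ids, attention_mask].
```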
@@ -20,7 +20,7 @@ class RobertaFeaturizer(RobertaTokenizerFast, Featurizer):
  Examples
  --------
-  >>> from deepchem.feat.molecule_featurizers import RobertaFeaturizer
+  >>> from deepchem.feat import RobertaFeaturizer
  >>> smiles = ["Cn1c(=O)c2c(ncn2C)n(C)c1=O", "CC(=O)N1CN(C(C)=O)C(O)C1O"]
  >>> featurizer = RobertaFeaturizer.from_pretrained("seyonec/SMILES_tokenized_PubChem_shard00_160k")
  >>> featurizer.featurize(smiles, add_special_tokens=True, truncation=True)
......
import pytest
from os.path import join, realpath, dirname


@pytest.mark.torch
def test_featurize():
  """Test that BertFeaturizer.featurize() correctly featurizes all sequences
  and correctly outputs the input_ids and attention_mask."""
  from deepchem.feat.bert_tokenizer import BertFeaturizer
  from transformers import BertTokenizerFast

  sequences = [
      '[CLS] D L I P T S S K L V [SEP]', '[CLS] V K K A F F A L V T [SEP]'
  ]
  sequence_long = ['[CLS] D L I P T S S K L V V K K A F F A L V T [SEP]']
  tokenizer = BertTokenizerFast.from_pretrained(
      "Rostlab/prot_bert", do_lower_case=False)
  featurizer = BertFeaturizer(tokenizer)

  feats = featurizer(sequences)
  long_feat = featurizer(sequence_long)
  assert (len(feats) == 2)
  assert (all([len(f) == 3 for f in feats]))
  assert (len(long_feat) == 1)
  # The single long sequence also yields three output lists.
  assert (len(long_feat[0]) == 3)

@pytest.mark.torch
def test_loading():
  """Test that the FASTA loader can load with this featurizer."""
  from transformers import BertTokenizerFast
  from deepchem.feat.bert_tokenizer import BertFeaturizer
  from deepchem.data.data_loader import FASTALoader

  tokenizer = BertTokenizerFast.from_pretrained(
      "Rostlab/prot_bert", do_lower_case=False)
  featurizer = BertFeaturizer(tokenizer)
  loader = FASTALoader(
      featurizer=featurizer, legacy=False, auto_add_annotations=True)

  file_loc = realpath(__file__)
  directory = dirname(file_loc)
  data = loader.create_dataset(
      input_files=join(directory, "data/uniprot_truncated.fasta"))
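  # 61 annotated sequences, each featurized into the tokenizer's three output
  # lists (input_ids, token_type_ids, attention_mask), each of length 5 here.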
  assert data.X.shape == (61, 3, 5)
@@ -396,6 +396,20 @@ References:
Other Featurizers
-----------------

BertFeaturizer
^^^^^^^^^^^^^^

.. autoclass:: deepchem.feat.BertFeaturizer
  :members:
  :inherited-members:

RobertaFeaturizer
^^^^^^^^^^^^^^^^^

.. autoclass:: deepchem.feat.RobertaFeaturizer
  :members:
  :inherited-members:

BindingPocketFeaturizer
^^^^^^^^^^^^^^^^^^^^^^^
......