Commit c771c8c8 authored by alat-rights's avatar alat-rights
Browse files

modified unit test, modified __call__()

parent 04d4fc4e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -979,7 +979,7 @@ class FASTALoader(DataLoader):
        # (X, y, w, ids)
        yield X, None, None, ids

    def _read_file(input_file: str):
    def _read_file(input_file: str) -> np.ndarray:
      """
      Convert the FASTA file to a numpy array of FASTA-format strings.
      """
+1 −2
Original line number Diff line number Diff line
from deepchem.feat import Featurizer
import numpy as np
from icecream import ic
from typing import Dict, List
from typing import List
try:
  from transformers import BertTokenizerFast
except ModuleNotFoundError:
+2 −20
Original line number Diff line number Diff line
@@ -2,24 +2,6 @@ import pytest


@pytest.mark.torch
def test_call():
  """Test BertFeaturizer.__call__(), which is based on BertTokenizerFast."""
  from deepchem.feat.bert_tokenizer import BertFeaturizer
  from transformers import BertTokenizerFast
  sequence = [
      '[CLS] D L I P T S S K L V [SEP]', '[CLS] V K K A F F A L V T [SEP]'
  ]
  sequence_long = ['[CLS] D L I P T S S K L V V K K A F F A L V T [SEP]']
  tokenizer = BertTokenizerFast.from_pretrained(
      "Rostlab/prot_bert", do_lower_case=False)
  featurizer = BertFeaturizer(tokenizer)
  embedding = featurizer(sequence, return_tensors='pt')
  embedding_long = featurizer(sequence_long * 2, return_tensors='pt')
  for emb in [embedding, embedding_long]:
    assert 'input_ids' in emb.keys() and 'attention_mask' in emb.keys()
    assert len(embedding['input_ids']) == 2 and len(emb['attention_mask']) == 2


def test_featurize():
  """Test that BertFeaturizer.featurize() correctly featurizes all sequences,
  correctly outputs input_ids and attention_mask."""
@@ -32,8 +14,8 @@ def test_featurize():
  tokenizer = BertTokenizerFast.from_pretrained(
      "Rostlab/prot_bert", do_lower_case=False)
  featurizer = BertFeaturizer(tokenizer)
  feats = featurizer.featurize(sequence)
  long_feat = featurizer.featurize(sequence_long)
  feats = featurizer(sequence)
  long_feat = featurizer(sequence_long)
  assert (len(feats) == 2)
  assert (all([len(f) == 3 for f in feats]))
  assert (len(long_feat) == 1)