Commit 964fa216 authored by alat-rights's avatar alat-rights
Browse files

added new test addressing data loader

parent bf060749
Loading
Loading
Loading
Loading
+17 −1
Original line number Diff line number Diff line
import pytest


@pytest.mark.torch
def test_featurize():
  """Test that BertFeaturizer.featurize() correctly featurizes all sequences,
@@ -20,3 +19,20 @@ def test_featurize():
  assert (all([len(f) == 3 for f in feats]))
  assert (len(long_feat) == 1)
  assert (len(long_feat[0] == 2))

@pytest.mark.torch
def test_loading():
  """Test that the FASTA loader can load with this featurizer."""
  from transformers import BertModel, BertTokenizerFast
  from deepchem.feat.bert_tokenizer import BertFeaturizer
  from deepchem.data.data_loader import FASTALoader
  import re

  tokenizer = BertTokenizerFast.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
  model = BertModel.from_pretrained("Rostlab/prot_bert")
  featurizer = BertFeaturizer(tokenizer)

  loader = FASTALoader(featurizer = featurizer, legacy = False, auto_add_annotations = True)
  data = loader.create_dataset(input_files = "../../data/tests/uniprot_truncated.fasta")

  assert data.X.shape == (61, 3, 5)