Commit 0875273d authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Review comments

parent bb9858fa
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -195,7 +195,7 @@ class DataLoader(object):
    Parameters
    ----------
    input_files: list
      List of input files
      List of input filenames.
    data_dir: str
      (Optional) Directory to store featurized dataset.
    shard_size: int
+0 −4
Original line number Diff line number Diff line
@@ -10,8 +10,6 @@ __license__ = "MIT"

import os
import unittest
import tempfile
import shutil
import deepchem as dc


@@ -29,8 +27,6 @@ class TestFASTALoader(unittest.TestCase):
                              "../../data/tests/example.fasta")
    loader = dc.data.FASTALoader()
    sequences = loader.featurize(input_file)
    print("sequences.X.shape")
    print(sequences.X.shape)
    # example.fasta contains 3 sequences each of length 58.
    # The one-hot encoding turns base-pairs into vectors of length 4.
    # There is one "image channel")
+11 −0
Original line number Diff line number Diff line
@@ -117,6 +117,15 @@ def seq_one_hot_encode(sequences):
  ----------
  sequences: np.ndarray 
    Array of genetic sequences 

  Raises
  ------
  ValueError:
    If sequences are of different lengths.

  Returns
  -------
  np.ndarray: Shape (N_sequences, 4, sequence_length, 1).
  """
  sequence_length = len(sequences[0])
  # depends on Python version
@@ -130,6 +139,8 @@ def seq_one_hot_encode(sequences):
  # integers all at once. Had to do one at a time. Might be worth
  # figuring out what's going on under the hood.
  for sequence in sequences:
    if len(sequence) != sequence_length:
      raise ValueError("All sequences must be of same length")
    integer_seq = label_encoder.transform(
        np.array((sequence,)).view(integer_type))
    integer_array.append(integer_seq)
+35 −0
Original line number Diff line number Diff line
"""
Tests that sequence handling utilities work. 
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

__author__ = "Bharath Ramsundar"
__license__ = "MIT"

import numpy as np
import unittest
import deepchem as dc


class TestSeq(unittest.TestCase):
  """
  Tests sequence handling utilities.
  """

  def test_one_hot_simple(self):
    sequences = np.array(["ACGT", "GATA", "CGCG"])
    sequences = dc.utils.save.seq_one_hot_encode(sequences)
    assert sequences.shape == (3, 4, 4, 1)

  def test_one_hot_mismatch(self):
    # One sequence has length longer than others. This should throw a
    # value error.
    thrown = False
    try:
      sequences = np.array(["ACGTA", "GATA", "CGCG"])
      sequences = dc.utils.save.seq_one_hot_encode(sequences)
    except ValueError:
      thrown = True
    assert thrown