Commit 7cd554ce authored by Trent Hauck's avatar Trent Hauck
Browse files

Path 2

parent 238935e0
Loading
Loading
Loading
Loading
+17 −10
Original line number Diff line number Diff line
@@ -108,7 +108,8 @@ def load_csv_files(filenames, shard_size=None, verbose=True):
        yield df


def seq_one_hot_encode(sequences):

def seq_one_hot_encode(sequences, letters='ATCGN'):
  """One hot encodes list of genomic sequences.

  Sequences encoded have shape (N_sequences, 4, sequence_length, 1).
@@ -129,33 +130,39 @@ def seq_one_hot_encode(sequences):
  -------
  np.ndarray: Shape (N_sequences, 4, sequence_length, 1).
  """

  sequence_length = len(sequences[0])
  letters_length = len(letters)

  # depends on Python version
  integer_type = np.int32

  # The label encoder is given characters for ACGTN
  label_encoder = LabelEncoder().fit(np.array(('ACGTN',)).view(integer_type))
  # These are transformed in 0, 1, 2, 3, 4 in input sequence
  letters_array = np.array((letters,))
  label_encoder = LabelEncoder().fit(letters).view(integer_type))

  integer_array = []

  # TODO(rbharath): Unlike the DRAGONN implementation from which this
  # was ported, I couldn't transform the "ACGT..." strings into
  # integers all at once. Had to do one at a time. Might be worth
  # figuring out what's going on under the hood.

  for sequence in sequences:
    if len(sequence) != sequence_length:
      raise ValueError("All sequences must be of same length")
    integer_seq = label_encoder.transform(
        np.array((sequence,)).view(integer_type))
    integer_array.append(integer_seq)

  integer_array = np.concatenate(integer_array)
  integer_array = integer_array.reshape(len(sequences), sequence_length)
  one_hot_encoding = OneHotEncoder(
      sparse=False, n_values=5, dtype=integer_type).fit_transform(integer_array)

  return one_hot_encoding.reshape(len(sequences), sequence_length, 5,
                                  1).swapaxes(1, 2)[:, [0, 1, 2, 4], :, :]

  return one_hot_encoding.reshape(len(sequences), sequence_length, letters_length, 1).swapaxes(1, 2)

def encode_fasta_sequence(fname):
def encode_fasta_sequence(fname, letters='ATCGN'):
  """
  Loads fasta file and returns an array of one-hot sequences.

@@ -178,7 +185,7 @@ def encode_fasta_sequence(fname):
  if name is not None:
    sequences.append(''.join(seq_chars).upper())

  return seq_one_hot_encode(np.array(sequences))
  return seq_one_hot_encode(np.array(sequences), letters)

# This could just be ambiguous_dna_letters, but that would be much higher dim.
class IUPACUnambiguousDNAWithN(Alphabet):