Commit 3b96893e authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Bugfixes

parent b081bc50
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -321,7 +321,8 @@ class FASTALoader(DataLoader):
    def shard_generator():
      """Yield (X, y, w, ids) shard tuples, one input file at a time.

      `input_files` is taken from the enclosing scope. `y` and `w` are always
      yielded as None.

      NOTE(review): this listing is a rendered diff without +/- markers, so
      the two consecutive `yield` statements below are presumably the removed
      and the added version of the same line rather than two real yields --
      confirm against the actual file.
      """
      for input_file in input_files:
        # Encode the contents of this file via encode_fasta_sequence.
        X = encode_fasta_sequence(input_file)
        # Placeholder ids: an array of ones with one entry per row of X.
        ids = np.ones(len(X))
        # (X, y, w, ids)
        yield X, None, None, None
        yield X, None, None, ids

    return DiskDataset.create_dataset(shard_generator(), data_dir)
+2 −2
Original line number Diff line number Diff line
>seq0
ACGTCCCACACGATGCATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGAT
ACGTCCCACACGATGCATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGAT
>seq1
GTCGATGCATGCTAGCTAGCTAGCTAGCTACGATCGATCGATCGTACGATCGATCGAT
>seq2
ACACATCATCATTACTATATATTATATATCGATCGATCGATCGATCGTACGTAGCTAGCTAGCA
ACACATCATCATTACTATATATTATATATCGATCGATCGATCGATCGTACGTAGCTAG
+7 −2
Original line number Diff line number Diff line
@@ -28,5 +28,10 @@ class TestDataLoader(unittest.TestCase):
    input_file = os.path.join(self.current_dir,
                              "../../data/tests/example.fasta")
    loader = dc.data.FASTALoader()
    loader.featurize(input_file)
    assert 0 == 1
    sequences = loader.featurize(input_file)
    print("sequences.X.shape")
    print(sequences.X.shape)
    # example.fasta contains 3 sequences each of length 58.
    # The one-hot encoding turns base-pairs into vectors of length 4.
    # There is one "image channel".
    assert sequences.X.shape == (3, 4, 58, 1)
+18 −17
Original line number Diff line number Diff line
@@ -109,8 +109,9 @@ def load_csv_files(filenames, shard_size=None, verbose=True):
def seq_one_hot_encode(sequences):
  """One hot encodes list of genomic sequences.

  Sequences encoded have shape (N_sequences, 1, 4, sequence_length).
  Sequences encoded have shape (N_sequences, 4, sequence_length, 1).
  Here 4 is for the 4 basepairs (ACGT) present in genomic sequences.
  These sequences will be processed as images with one color channel.

  Parameters
  ----------
@@ -121,25 +122,25 @@ def seq_one_hot_encode(sequences):
  # depends on Python version
  # TODO(rbharath): Can this be removed?
  integer_type = np.int8 if sys.version_info[0] == 2 else np.int32
  ######################################################## DEBUG
  print("sequences")
  print(sequences)
  print("sequences.view(integer_type)")
  print(sequences.view(integer_type))
  ######################################################## DEBUG
  integer_array = LabelEncoder().fit(
      np.array(('ACGTN',)).view(integer_type)).transform(
          sequences.view(integer_type)).reshape(
              len(sequences), sequence_length)
  ################################################## DEBUG
  print("integer_array")
  print(integer_array)
  ################################################## DEBUG
  # The label encoder is given characters for ACGTN
  label_encoder = LabelEncoder().fit(np.array(('ACGTN',)).view(integer_type))
  # These are transformed into 0, 1, 2, 3, 4 in the input sequence
  integer_array = []
  # TODO(rbharath): Unlike the DRAGONN implementation from which this
  # was ported, I couldn't transform the "ACGT..." strings into
  # integers all at once. Had to do one at a time. Might be worth
  # figuring out what's going on under the hood.
  for sequence in sequences:
    integer_seq = label_encoder.transform(
        np.array((sequence,)).view(integer_type))
    integer_array.append(integer_seq)
  integer_array = np.concatenate(integer_array)
  integer_array = integer_array.reshape(len(sequences), sequence_length)
  one_hot_encoding = OneHotEncoder(
      sparse=False, n_values=5, dtype=integer_type).fit_transform(integer_array)

  return one_hot_encoding.reshape(len(sequences), 1, sequence_length,
                                  5).swapaxes(2, 3)[:, :, [0, 1, 2, 4], :]
  return one_hot_encoding.reshape(len(sequences), sequence_length, 5,
                                  1).swapaxes(1, 2)[:, [0, 1, 2, 4], :, :]


def encode_fasta_sequence(fname):