Commit c36de6cd authored by Bharath Ramsundar
Browse files

Changes

parent e7919ece
Loading
Loading
Loading
Loading
+27 −17
Original line number Diff line number Diff line
@@ -102,30 +102,40 @@ def in_silico_mutagenesis(model, X):
  model: TensorGraph
    Currently only SequenceDNN will work, but other models may be added.
  X: ndarray
      Shape (N_sequences, 1, N_letters, sequence_length) 
    Shape (N_sequences, N_letters, sequence_length, 1) 

  Returns
  -------
    (num_tasks, N_sequences, 1, N_letters, sequence_length) ISM score array.
  (num_tasks, N_sequences, N_letters, sequence_length, 1) ISM score array.
  """
  #Shape (N_sequences, N_letters, sequence_length, 1, num_tasks)
  mutagenesis_scores = np.empty(X.shape + (model.num_tasks,), dtype=np.float32)
  # Shape (N_sequences, num_tasks)
  wild_type_predictions = model.predict(NumpyDataset(X))
  # Shape (N_sequences, 1, 1, 1, num_tasks)
  wild_type_predictions = wild_type_predictions[:, np.newaxis, np.newaxis,
                                                np.newaxis]
  for sequence_index, (
      sequence,
      wild_type_prediction) in enumerate(zip(X, wild_type_predictions)):

    # Mutates every position of the sequence to every letter
    # Shape (N_letters * sequence_length, N_letters, sequence_length, 1)
    # Breakdown:
    #  Shape of sequence[np.newaxis] (1, N_letters, sequence_length, 1)
    mutated_sequences = np.repeat(
        sequence[np.newaxis], np.prod(sequence.shape), axis=0)

    # remove wild-type
    # len(arange) = N_letters * sequence_length
    arange = np.arange(len(mutated_sequences))
    horizontal_cycle = np.tile(
        np.arange(sequence.shape[-1]), sequence.shape[-2])
    mutated_sequences[arange, :, :, horizontal_cycle] = 0
    # len(horizontal_cycle) = N_letters * sequence_length
    horizontal_cycle = np.tile(np.arange(sequence.shape[1]), sequence.shape[0])
    mutated_sequences[arange, :, horizontal_cycle, :] = 0

    # add mutant
    vertical_repeat = np.repeat(
        np.arange(sequence.shape[-2]), sequence.shape[-1])
    mutated_sequences[arange, :, vertical_repeat, horizontal_cycle] = 1
    vertical_repeat = np.repeat(np.arange(sequence.shape[0]), sequence.shape[1])
    mutated_sequences[arange, vertical_repeat, horizontal_cycle, :] = 1
    # make mutant predictions
    mutated_predictions = model.predict(NumpyDataset(mutated_sequences))
    mutated_predictions = mutated_predictions.reshape(sequence.shape +
+13 −6
Original line number Diff line number Diff line
@@ -49,13 +49,20 @@ class TestGenomicMetrics(unittest.TestCase):
  def test_in_silico_mutagenesis(self):
    """Test in-silico mutagenesis returns correct shape."""
    # Construct and train SequenceDNN model
    X = np.random.rand(10, 1, 4, 50)
    y = np.random.randint(0, 2, size=(10, 1))
    dataset = dc.data.NumpyDataset(X, y)
    sequences = np.array(["ACGTA", "GATAG", "CGCGC"])
    sequences = dc.utils.save.seq_one_hot_encode(sequences, letters=LETTERS)
    labels = np.array([1, 0, 0])
    labels = np.reshape(labels, (3, 1))
    self.assertEqual(sequences.shape, (3, 4, 5, 1))

    #X = np.random.rand(10, 1, 4, 50)
    #y = np.random.randint(0, 2, size=(10, 1))
    #dataset = dc.data.NumpyDataset(X, y)
    dataset = dc.data.NumpyDataset(sequences, labels)
    model = dc.models.SequenceDNN(
        50, "binary_crossentropy", num_filters=[1, 1], kernel_size=[15, 15])
        5, "binary_crossentropy", num_filters=[1, 1], kernel_size=[15, 15])
    model.fit(dataset, nb_epoch=1)

    # Call in-silico mutagenesis
    mutagenesis_scores = in_silico_mutagenesis(model, X)
    self.assertEqual(mutagenesis_scores.shape, (1, 10, 1, 4, 50))
    mutagenesis_scores = in_silico_mutagenesis(model, sequences)
    self.assertEqual(mutagenesis_scores.shape, (1, 3, 4, 5, 1))
+1 −1
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ author = DeepChem contributors
summary = Deep-learning models for drug discovery, quantum chemistry, and the life sciences.
home-page = https://github.com/deepchem/deepchem
license = MIT
version = 2.1.0
version = 2.1.1
classifier =
    Development Status :: 4 - Beta
    Environment :: Console