Commit 1ad80e45 authored by Jun Zhao

update DA-seq algorithm

parent 14b07d94

+343 −72
library(RANN)
library(reticulate)
library(ggplot2)
library(cowplot)
library(RColorBrewer)
@@ -15,32 +16,29 @@ library(Seurat)
#' @param labels.1 vector, label name(s) that represent condition 1
#' @param labels.2 vector, label name(s) that represent condition 2
#' @param k.vector vector, k values to create the score vector
#' @param do.diffuse a logical value indicating whether to compute diffusion coordinates for X, default FALSE
#' @param neigen number of diffusion coordinates, default 20
#' @param k.folds integer, number of data splits used in the neural network, default 10
#' @param n.runs integer, number of times to run the neural network to get the predictions, default 1
#' @param pred.thres length-2 vector, lower and upper quantile thresholds on the predictions from the neural network, default c(0.05,0.95)
#' @param do.plot a logical value indicating whether to return ggplot objects showing the results, default TRUE
#' @param plot.embedding size N-by-2 matrix, 2D embedding for the cells
#' @param size cell size to use in the plot, default 0.5
#' @param python.use character string, the Python binary to use, default "/usr/bin/python"
#' @param source.code character string, path to the neural network source code, default "./DA_logit.py"
#' @param GPU which GPU to use, default "", i.e. run on CPU
#' 
#' @return a list of results
#'         da.ratio: DA score vector for each cell
#'         da.pred: (mean) prediction from the neural network
#'         da.std: standard deviation of the predictions across runs, NULL when n.runs = 1
#'         da.cell.idx: indices of cells in the most DA neighborhoods
#'         pred.plot: ggplot object showing the predictions of the neural network on plot.embedding
#'         da.cells.plot: ggplot object highlighting cells of da.cell.idx on plot.embedding
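#'
#' @examples
#' \dontrun{
#' # hedged usage sketch: pca.emb, sample.labels and tsne.emb are hypothetical inputs
#' da.cells <- getDAcells(
#'   X = pca.emb, cell.labels = sample.labels,
#'   labels.1 = c("ctrl.1","ctrl.2"), labels.2 = c("stim.1","stim.2"),
#'   k.vector = seq(50, 500, 50), plot.embedding = tsne.emb
#' )
#' }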

getDAcells <- function(
  X, cell.labels, labels.1, labels.2, k.vector,
  do.diffuse = F, neigen = 20,
  k.folds = 10, n.runs = 1, pred.thres = c(0.05,0.95),
  do.plot = T, plot.embedding = NULL, size = 0.5, 
  python.use = "/usr/bin/python", source.code = "./DA_logit.py", GPU = ""
){
  # get diffusion coordinates
  if(do.diffuse){
    library(diffusionMap)
    cat("Calculating diffusion coordinates.\n")
    X.input <- X
    X <- diffuse(D = dist(X.input), neigen = neigen)$X
  }
  
  # get DA score vector for each cell
  cat("Calculating DA score vector.\n")
@@ -52,39 +50,147 @@ getDAcells <- function(
    k.vector = k.vector
  )
  
  
  # prepare data for python script
  use_python(python = python.use, required = T)
  
  # encode conditions as a numeric 0/1 vector (robust to character labels)
  binary.labels <- rep(0.0, length(cell.labels))
  binary.labels[cell.labels %in% labels.2] <- 1.0
  
  binary.labels_py <- r_to_py(as.matrix(binary.labels))
  X.knn.ratio_py <- r_to_py(as.matrix(X.knn.ratio))
  
  
  # set GPU device (import os first so the environment variable can be set)
  py_run_string("import os")
  py_run_string(paste("os.environ['CUDA_VISIBLE_DEVICES'] = '", GPU, "'", sep = ""))
  
  # get neural network predictions for each cell
  cat("Running neural network classification.\n")
  source_python(file = source.code)
  # py_run_string(paste("epochs = ", epochs, sep = ""))
  py_run_string(paste("k_folds = ", k.folds, sep = ""))
  
  
  # run the classifier once, or average predictions over n.runs repeats
  if(n.runs == 1){
    X.pred <- k_fold_predict_linear(X.knn.ratio_py, binary.labels_py, py$k_folds)
    X.pred <- as.numeric(X.pred)
    X.std <- NULL
  } else {
    cat("Running ", n.runs, " runs.\n", sep = "")
    X.pred.all <- list()
    for(ii in 1:n.runs){
      X.pred.all[[ii]] <- as.numeric(k_fold_predict_linear(X.knn.ratio_py, binary.labels_py, py$k_folds))
    }
    X.pred.all <- do.call("rbind", X.pred.all)
    X.pred <- colMeans(X.pred.all)
    X.std <- apply(X.pred.all, 2, sd)
  }
  
  
  # select DA cells whose predictions fall outside the quantile thresholds
  cat("Selecting top DA cells.\n")
  pred.thres <- sort(pred.thres, decreasing = F)
  X.da.idx <- which(
    X.pred < quantile(X.pred, pred.thres[1]) | X.pred > quantile(X.pred, pred.thres[2])
  )
  
  # plot results
  if(do.plot & is.null(plot.embedding)){
    warning("plot.embedding must be provided by user if do.plot = T")
    X.pred.plot <- NULL
    X.da.cells.plot <- NULL
  } else if(do.plot & !is.null(plot.embedding)){
    X.pred.plot <- plotCellScore(
      X = plot.embedding, score = X.pred, size = size
    ) + theme(legend.title = element_blank())
    X.da.cells.plot <- plotDAsite(
      X = plot.embedding, 
      site.list = list(
        which(X.pred < quantile(X.pred, pred.thres[1])),
        which(X.pred > quantile(X.pred, pred.thres[2]))
      ), 
      size = size, cols = c("blue","red")
    )
  } else {
    X.pred.plot <- NULL
    X.da.cells.plot <- NULL
  }
  
  # return result
  return(list(
    da.ratio = X.knn.ratio,
    da.pred = X.pred, 
    da.std = X.std, 
    da.cell.idx = X.da.idx,
    pred.plot = X.pred.plot,
    da.cells.plot = X.da.cells.plot
  ))
}



## Step 1.1: update DA cells with new prediction thresholds

#' @param X output from getDAcells()
#' @param pred.thres length-2 vector, lower and upper quantile thresholds on the predictions from the neural network, default c(0.05,0.95)
#' @param do.plot a logical value indicating whether to return ggplot objects showing the results, default TRUE
#' @param plot.embedding size N-by-2 matrix, 2D embedding for the cells
#' @param size cell size to use in the plot, default 0.5
#' 
#' @return a list of results with updated DA cells
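#'
#' @examples
#' \dontrun{
#' # hedged usage sketch: reuse a getDAcells() result, changing only the thresholds
#' da.cells.strict <- updateDAcells(
#'   X = da.cells, pred.thres = c(0.02,0.98), plot.embedding = tsne.emb
#' )
#' }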

updateDAcells <- function(
  X, pred.thres = c(0.05,0.95), do.plot = T, plot.embedding = NULL, size = 0.5
){
  # select DA cells
  X.pred <- X$da.pred
  pred.thres <- sort(pred.thres, decreasing = F)
  X.da.idx <- which(
    X.pred < quantile(X.pred, pred.thres[1]) | X.pred > quantile(X.pred, pred.thres[2])
  )
  
  # plot results
  if(do.plot & is.null(plot.embedding)){
    warning("plot.embedding must be provided by user if do.plot = T")
    X.pred.plot <- NULL
    X.da.cells.plot <- NULL
  } else if(do.plot & !is.null(plot.embedding)){
    X.pred.plot <- plotCellScore(
      X = plot.embedding, score = X.pred, size = size
    ) + theme(legend.title = element_blank())
    X.da.cells.plot <- plotDAsite(
      X = plot.embedding, 
      site.list = list(
        which(X.pred < quantile(X.pred, pred.thres[1])),
        which(X.pred > quantile(X.pred, pred.thres[2]))
      ), 
      size = size, cols = c("blue","red")
    )
  } else {
    X.pred.plot <- NULL
    X.da.cells.plot <- NULL
  }
  
  # return result
  return(list(
    da.ratio = X$da.ratio,
    da.pred = X.pred, 
    da.cell.idx = X.da.idx,
    pred.plot = X.pred.plot,
    da.cells.plot = X.da.cells.plot
  ))
}
@@ -114,8 +220,9 @@ getDAregion <- function(
  X, cell.idx, k, alpha, restr.fact = 50,
  cell.labels, labels.1, labels.2, 
  do.plot = T, plot.embedding = NULL, size = 0.5, 
  seed = 0, ...
){
  set.seed(seed)
  X.tclust <- runtclust(X, cell.idx, k, alpha, restr.fact, ...)
  X.n.da <- length(unique(X.tclust)) - 1
  X.da.stat <- matrix(0, nrow = X.n.da, ncol = 3)
@@ -131,10 +238,18 @@ getDAregion <- function(
    warning("plot.embedding must be provided by user if do.plot = T")
    X.region.plot <- NULL
  } else if(do.plot & !is.null(plot.embedding)){
    X.da.label <- rep(0,nrow(X))
    X.da.label[cell.idx] <- X.tclust
    X.da.order <- order(X.da.label, decreasing = F)
    X.region.plot <- plotCellLabel(
      X = plot.embedding[X.da.order,], label = as.factor(X.da.label[X.da.order]), 
      size = size, do.label = F, return.plot = T
    ) + scale_color_manual(values = c("gray",hue_pal()(X.n.da)), breaks = c(1:X.n.da))
  } else {
    X.region.plot <- NULL
  }
@@ -150,6 +265,116 @@ getDAregion <- function(

## Step 3: detect genes that characterize DA regions from Step 2

#' @param X matrix, normalized expression matrix of all cells in the dataset, genes are in rows, rownames must be gene names
#' @param cell.idx result "da.cell.idx" from the output of function getDAcells
#' @param da.region.label result "cluster.res" from the output of function getDAregion
#' @param da.regions.to.run numeric (vector), which DA regions to run the marker finder on, default is to run all regions
#' @param lambda numeric, regularization parameter that weights the number of selected genes, a larger lambda leads to fewer genes, default 1
#' @param n.runs integer, number of runs to run the model, default 5
#' @param python.use character string, the Python to use, default "/usr/bin/python"
#' @param source.code character string, path to the neural network source code, default "./DA_STG.py"
#' @param return.model a logical value indicating whether to return the trained STG model, default FALSE
#' @param GPU which GPU to use, default "", i.e. run on CPU
#' 
#' @return a list of results:
#'         da.markers: a list of data frames with markers for each DA region
#'         accuracy: a numeric vector showing mean accuracy for each DA region
#'         model: a list of models, one for each DA region, where each model contains:
#'                model: the model of STG of the final run
#'                features: features used to train the model
#'                selected.features: the selected features of the final run
#'                pred: the linear prediction value for each cell from the model
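#'
#' @examples
#' \dontrun{
#' # hedged usage sketch: norm.expr (genes-by-cells) and da.regions are hypothetical
#' stg.markers <- STGmarkerFinder(
#'   X = norm.expr, cell.idx = da.cells$da.cell.idx,
#'   da.region.label = da.regions$cluster.res, lambda = 1, n.runs = 5
#' )
#' head(stg.markers$da.markers[["1"]])
#' }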

STGmarkerFinder <- function(
  X, cell.idx, da.region.label,
  da.regions.to.run = NULL, 
  lambda = 1, n.runs = 5, return.model = F, 
  python.use = "/usr/bin/python", source.code = "./DA_STG.py", GPU = ""
){
  # set Python
  use_python(python = python.use, required = T)
  
  # set GPU device (import os first so the environment variable can be set)
  py_run_string("import os")
  py_run_string(paste("os.environ['CUDA_VISIBLE_DEVICES'] = '", GPU, "'", sep = ""))
  
  source_python(file = source.code)
  
  # turn X into Python format
  X.py <- r_to_py(as.matrix(X))
  
  # get DA regions to run
  n.da <- length(unique(da.region.label)) - 1
  if(is.null(da.regions.to.run)){
    da.regions.to.run <- c(1:n.da)
  }
  
  # create DA label vector
  n.cells <- ncol(X)
  da.label <- rep(0, n.cells)
  da.label[cell.idx] <- da.region.label
  
  
  # run model for each da region
  da.markers <- list()
  da.accr <- vector("numeric")
  da.model <- list()
  for(ii in da.regions.to.run){
    # prepare labels
    da.label.bin <- (da.label == ii)
    da.label.bin.py <- r_to_py(as.matrix(da.label.bin))
    da.label.bin.py <- da.label.bin.py$flatten()
    
    py_run_string(sprintf("num_run = %s; lam = %s", n.runs, lambda))
    stg.out <- STG_FS(X.py, da.label.bin.py, py$num_run, py$lam)
    da.markers[[as.character(ii)]] <- rownames(X)[unique(unlist(stg.out[[1]])) + 1]
    da.accr[[as.character(ii)]] <- mean(stg.out[[2]])
    da.model[[as.character(ii)]] <- list()
    da.model[[as.character(ii)]][["model"]] <- stg.out[[3]]
    da.model[[as.character(ii)]][["features"]] <- rownames(X)[stg.out[[4]] + 1]
    da.model[[as.character(ii)]][["selected.features"]] <- rownames(X)[stg.out[[1]][[n.runs]] + 1]
    da.model[[as.character(ii)]][["pred"]] <- stg.out[[5]][[1]][,2] - stg.out[[5]][[1]][,1]
    da.model[[as.character(ii)]][["alpha"]] <- as.numeric(stg.out[[6]])
    names(da.model[[as.character(ii)]][["alpha"]]) <- da.model[[as.character(ii)]][["features"]]
  }
  
  
  ## Get statistics for each gene
  da.markers.result <- list()
  for(ii in da.regions.to.run){
    da.markers.logfc <- sapply(da.markers[[as.character(ii)]], function(x, x.data1, x.data2){
      log2(mean(x.data1[x,] + 1/100000) / mean(x.data2[x,] + 1/100000))
    }, x.data1 = X[,da.label == ii], x.data2 = X[,da.label != ii])
    da.markers.pval <- sapply(da.markers[[as.character(ii)]], function(x, x.data1, x.data2){
      wilcox.test(x = x.data1[x,], y = x.data2[x,])$p.value
    }, x.data1 = X[,da.label == ii], x.data2 = X[,da.label != ii])
    da.markers.result[[as.character(ii)]] <- data.frame(
      gene = da.markers[[as.character(ii)]],
      avg_logFC = da.markers.logfc,
      p_value = da.markers.pval, 
      stringsAsFactors = F
    )
    
    da.markers.result[[as.character(ii)]] <- da.markers.result[[as.character(ii)]][
      order(da.markers.result[[as.character(ii)]][,"p_value"]),
    ]
  }
  
  # return output
  if(return.model){
    return(list(
      da.markers = da.markers.result,
      accuracy = da.accr,
      model = da.model
    ))
  } else {
    return(list(
      da.markers = da.markers.result,
      accuracy = da.accr
    ))
  }
}


#' @param cell.idx result "da.cell.idx" from the output of function getDAcells
#' @param da.region.label result "cluster.res" from the output of function getDAregion
#' @param obj Seurat object that contain ALL cells in the analysis
@@ -213,9 +438,13 @@ daPerCell <- function(
  for(ii in 1:n.cells){
    for(kk in k.vector){
      i.kk.label <- cell.labels[knn.graph[ii,1:kk]]
      i.kk.ratio1 <- sum(i.kk.label %in% labels.1) / sum(cell.labels %in% labels.1)
      i.kk.ratio2 <- sum(i.kk.label %in% labels.2) / sum(cell.labels %in% labels.2)
      knn.diff.ratio[ii,as.character(kk)] <- (i.kk.ratio2 - i.kk.ratio1) / (i.kk.ratio2 + i.kk.ratio1)
    }
  }
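  
  # Toy illustration of the updated score (hedged, numbers made up): with 100
  # cells in condition 1 and 50 in condition 2, a k = 20 neighborhood holding
  # 5 and 15 of them respectively gives ratio1 = 5/100 = 0.05 and
  # ratio2 = 15/50 = 0.3, so the score is (0.3 - 0.05)/(0.3 + 0.05) = 0.714,
  # i.e. the neighborhood is strongly enriched for condition 2.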
@@ -225,40 +454,6 @@ daPerCell <- function(






# get DA regions with tclust
runtclust <- function(X, cell.idx, k, alpha, restr.fact = 50, ...){
  X.tclust.res <- tclust(
@@ -301,6 +496,22 @@ getDAscore <- function(cell.labels, cell.idx, labels.1, labels.2){



# plot a score for each cell
plotCellScore <- function(X, score, cell.col = c("blue","white","red"), size = 0.5){
  # Add colnames for X
  colnames(X) <- c("Dim1","Dim2")
  
  # Plot cells with labels
  myggplot <- ggplot() + theme_cowplot() +
    geom_point(data = data.frame(Dim1 = X[,1], Dim2 = X[,2], Score = score), 
               aes(x = Dim1, y = Dim2, col = Score), size = size) + 
    scale_color_gradientn(colours = cell.col)
  
  return(myggplot)
}



# plot da site
plotDAsite <- function(X, site.list, size = 0.5, cols = NULL){
  colnames(X) <- c("Dim1","Dim2")
@@ -310,7 +521,7 @@ plotDAsite <- function(X, site.list, size = 0.5, cols = NULL){
    site.label[site.list[[ii]]] <- ii
  }
  
  myggplot <- ggplot() + theme_cowplot() + 
    geom_point(data = data.frame(X), aes(Dim1, Dim2), col = "gray", size = size) + 
    geom_point(
      data = data.frame(X[unlist(site.list),]), 
@@ -328,7 +539,7 @@ plotDAsite <- function(X, site.list, size = 0.5, cols = NULL){



plotCellLabel <- function(X, label, cell.col = NULL, size = 0.5, do.label = T, return.plot = T){
  # Add colnames for X
  colnames(X) <- c("Dim1","Dim2")
  
@@ -356,3 +567,63 @@ plotCellLabel <- function(X, label, cell.col = NULL, size = 0.5, do.label = T, r
  if(return.plot) {return(myggplot)} else {print(myggplot)}
}
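
# Quick hedged sketch for the plotting helpers (toy data, all hypothetical):
# emb <- matrix(rnorm(200), ncol = 2)
# plotCellLabel(X = emb, label = factor(rep(c("A","B"), each = 50)), size = 1)
# plotCellScore(X = emb, score = rnorm(100), size = 1)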



# STG in general
runSTG <- function(
  X, X.labels, lambda = 1, n.runs = 5, return.model = F, 
  python.use = "/usr/bin/python", source.code = "./STG_model.py"
){
  # set Python
  use_python(python = python.use, required = T)
  source_python(file = source.code)
  
  X.py <- r_to_py(as.matrix(X))
  
  X.label.bin <- (X.labels == 1)
  X.label.bin.py <- r_to_py(as.matrix(X.label.bin))
  X.label.bin.py <- X.label.bin.py$flatten()
  
  py_run_string(sprintf("num_run = %s; lam = %s", n.runs, lambda))
  stg.out <- STG_FS(X.py, X.label.bin.py, py$num_run, py$lam)
  da.markers <- rownames(X)[unique(unlist(stg.out[[1]])) + 1]
  da.accr <- mean(stg.out[[2]])
  da.model <- list()
  da.model[["model"]] <- stg.out[[3]]
  da.model[["features"]] <- rownames(X)[stg.out[[4]] + 1]
  da.model[["selected.features"]] <- rownames(X)[stg.out[[1]][[n.runs]] + 1]
  da.model[["pred"]] <- stg.out[[5]][[1]][,2] - stg.out[[5]][[1]][,1]
  
  da.markers.logfc <- sapply(da.markers, function(x, x.data1, x.data2){
    log2(mean(x.data1[x,] + 1/100000) / mean(x.data2[x,] + 1/100000))
  }, x.data1 = X[,X.labels == 1], x.data2 = X[,X.labels == 0])
  da.markers.pval <- sapply(da.markers, function(x, x.data1, x.data2){
    wilcox.test(x = x.data1[x,], y = x.data2[x,])$p.value
  }, x.data1 = X[,X.labels == 1], x.data2 = X[,X.labels == 0])
  da.markers.result <- data.frame(
    gene = da.markers,
    avg_logFC = da.markers.logfc,
    p_value = da.markers.pval, 
    stringsAsFactors = F
  )
  
  da.markers.result <- da.markers.result[
    order(da.markers.result[,"p_value"]),
  ]
  
  # return results
  if(return.model){
    return(list(
      da.markers = da.markers.result,
      accuracy = da.accr,
      model = da.model
    ))
  } else {
    return(list(
      da.markers = da.markers.result,
      accuracy = da.accr
    ))
  }
}
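
# Hedged usage sketch for runSTG (X.input and bin.labels are hypothetical):
# stg.res <- runSTG(X = X.input, X.labels = bin.labels, lambda = 1, n.runs = 5)
# stg.res$da.markers  # genes ranked by Wilcoxon p-value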

+155 −34  File changed (preview collapsed).

DA_STG.py (new file, 0 → 100644)  +504 −0  File added (preview collapsed).

DA_logit.py (new file, 0 → 100644)  +141 −0
import os, sys
import numpy as np
import pandas as pd

from keras import backend as K
from keras.models import Model
from keras.layers import Input, Dense
from keras.activations import relu
from keras.callbacks import EarlyStopping

def k_fold_split(x, p, k):
    # Split x, reordered by permutation p, into the folds defined by split points k.
    return [x[p_] for p_ in np.split(p, k)]

def rev_split(xs, p):
    # Undo permutation p: concatenate the folds and restore the original order.
    n = len(p)
    inv_p = np.empty(n)
    inv_p[p] = np.arange(n)
    inv_p = inv_p.astype(int)
    return np.concatenate(xs)[inv_p]

def make_splits(n, k):
    # Cumulative split points for np.split: k folds, the last absorbs the remainder.
    sizes = n // k
    splits = [sizes] * (k - 1) + [sizes + n % k]
    for i in range(k - 1):
        splits[i + 1] = splits[i] + splits[i + 1]
    return splits[:-1]
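
# Example: make_splits(10, 3) -> [3, 6]; np.split at these points yields
# folds of sizes 3, 3 and 4.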

def k_fold_predict(data, labels, k_folds, architecture=[8]*8, activations='relu', end_activation='sigmoid'):
    # Out-of-fold prediction: every cell is predicted by a model that never saw
    # it during training. A fresh model is built per fold so that no trained
    # weights carry over between folds.
    y_tests = []
    p = np.random.permutation(len(data))

    for i in range(k_folds):
        val_idx = (i - 1) % k_folds
        test_idx = i

        k_folds_ = make_splits(len(data), k_folds)
        x_full, y_full = k_fold_split(data, p, k_folds_), k_fold_split(labels, p, k_folds_)
        x_val_, y_val_ = x_full[val_idx], y_full[val_idx]
        x_test_, y_test_ = x_full[test_idx], y_full[test_idx]

        # remove the validation and test sets from the training set
        if k_folds > 2:
            del x_full[val_idx]
            del y_full[val_idx]
        # if val_idx came before test_idx, remove the (test_idx - 1)th element,
        # as deleting val_idx has shifted the index of the test fold
        if k_folds > 1 and val_idx < test_idx:
            del x_full[test_idx - 1]
            del y_full[test_idx - 1]
        elif k_folds > 1:
            # otherwise, simply remove the (test_idx)th element
            del x_full[test_idx]
            del y_full[test_idx]

        x_train_, y_train_ = np.concatenate(x_full), np.concatenate(y_full)

        # build the neural network: `architecture` hidden layers, one sigmoid output
        x = x0 = Input(shape=data.shape[1:])
        for width in architecture:
            x = Dense(width, activation=activations)(x)
        x = Dense(1, activation=end_activation)(x)
        model = Model(inputs=x0, outputs=x)
        model.compile('adam', loss='binary_crossentropy', metrics=['acc'])

        # full-batch training with early stopping on the validation fold
        epochs = 1000
        batch_size = len(x_train_)
        model.fit(
            x=x_train_,
            y=y_train_,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_val_, y_val_),
            callbacks=[EarlyStopping(patience=10)],
            verbose=0)

        print("Finished {} / {} folds.".format(i + 1, k_folds))

        y_tests.append(model.predict(x_test_).reshape((-1,)))

    # restore the original cell order so predictions align with the input rows
    y_full = rev_split(y_tests, p)
    return y_full


def k_fold_predict_linear(data, labels, k_folds):
    # no hidden layers: logistic regression trained with the same k-fold scheme
    return k_fold_predict(data, labels, k_folds, architecture=[], activations=None, end_activation='sigmoid')
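
# Hedged demo (synthetic data; in the DA-seq pipeline these arrays come from R
# via reticulate, as in getDAcells):
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    scores = rng.randn(100, 5)                         # stands in for the DA score matrix
    labels = np.repeat([0.0, 1.0], 50).reshape(-1, 1)  # binary condition labels
    pred = k_fold_predict_linear(scores, labels, 10)
    print(pred.shape)  # (100,) out-of-fold predictions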