library(gtools) # Needed for mixedsort

# Getter for the background dispersion model stored in a footprintingProject.
# `mode` selects which stored model to return from the dispModel slot.
setGeneric(name = "dispModel", def = function(x, ...) standardGeneric("dispModel"))

setMethod(f = "dispModel",
          signature = "footprintingProject",
          definition = function(x, mode) {
            x@dispModel[[mode]]
          })

# Setter for the background dispersion model: stores `value` under `mode` in the
# dispModel slot and returns the modified project.
# Usage: dispModel(project, mode) <- model
setGeneric(name = "dispModel<-", def = function(x, value, ...) standardGeneric("dispModel<-"))

setMethod(f = "dispModel<-",
          signature = "footprintingProject",
          definition = function(x, value, mode) {
            x@dispModel[[mode]] <- value
            return(x)
          })

# Getter for footprinting results stored in a footprintingProject.
# `mode` selects which stored result set to return from the footprints slot.
setGeneric(name = "footprints", def = function(x, ...) standardGeneric("footprints"))

setMethod(f = "footprints",
          signature = "footprintingProject",
          definition = function(x, mode) {
            x@footprints[[mode]]
          })

# Setter for footprinting results: stores `value` under `mode` in the footprints
# slot and returns the modified project.
# Usage: footprints(project, mode) <- results
setGeneric(name = "footprints<-", def = function(x, value, ...) standardGeneric("footprints<-"))

setMethod(f = "footprints<-",
          signature = "footprintingProject",
          definition = function(x, value, mode) {
            x@footprints[[mode]] <- value
            return(x)
          })

# Generic for computing footprint scores for the regions of a project.
# The footprintingProject method stores the results in the project under `mode`
# (see footprints()) and returns the updated project.
setGeneric("getFootprints",
           function(project, # footprintingProject object
                    mode, # Character. This is used for retrieving the correct dispersion model.
                    footprintRadius, # Radius of the footprint region
                    flankRadius, # Radius of the flanking region (not including the footprint region)
                    chunkSize = 2000, # Chunk size for parallel processing of regions; NULL means use the project's stored region chunk size
                    returnCellTypeScores = FALSE, # Whether to also return results per cell type
                    nCores = 16 # Number of cores to use
                    ) standardGeneric("getFootprints"))

# Compute footprint scores for all regions of a footprintingProject.
#
# Looks up the dispersion model stored under `mode`, scores every region with
# get_footprints() (chunked, parallel, with per-chunk results cached under the
# project's data directory), stores the results with footprints(project, mode)
# and returns the updated project.
setMethod("getFootprints", 
          signature = c(project = "footprintingProject"),
          function(project,
                   mode,
                   footprintRadius,
                   flankRadius,
                   chunkSize = 2000,
                   returnCellTypeScores = FALSE,
                   nCores = 16) {
            
            # Directory for storing intermediate results
            tmpDir <- dataDir(project)
            
            # Determine chunk size: NULL means use the project's stored chunk size.
            # The chunk size never exceeds the number of regions.
            if (is.null(chunkSize)) {
              chunkSize <- regionChunkSize(project)
            }
            chunkSize <- min(chunkSize, length(regionRanges(project)))
            print(paste0("Using chunk size = ", chunkSize))

            # An empty count tensor means the counts are stored on disk in chunks
            projectCountTensor <- countTensor(project)
            if (length(projectCountTensor) != 0) {
              chunkSize <- min(chunkSize, length(projectCountTensor))
            }
            
            # Retrieve dispersion model from project. A missing model would otherwise
            # only fail inside the parallel workers, after all the setup work.
            dispersionModel <- dispModel(project, mode)
            if (is.null(dispersionModel)) {
              stop("No dispersion model found for mode '", mode, "'. ",
                   "Set one first with dispModel(project, mode) <- model.",
                   call. = FALSE)
            }
            
            # Perform footprinting (results are stored in a list of regions)
            cat("Computing footprinting scores of CREs\n")
            footprintResults <- get_footprints(projectCountTensor = projectCountTensor,
                                               dispersionModel = dispersionModel,
                                               tmpDir = tmpDir,
                                               mode = mode,
                                               footprintRadius = footprintRadius,
                                               flankRadius = flankRadius,
                                               cellTypeLabels = groupCellType(project),
                                               chunkSize = chunkSize,
                                               returnCellTypeScores = returnCellTypeScores,
                                               nCores = nCores)
            footprints(project, mode) <- footprintResults
            
            project
          })

# Calculate footprint scores for genomic regions.
#
# Regions are processed in chunks (intervals from getChunkInterval()). Each chunk
# is scored in parallel with regionFootprintScoring(), and its result is cached in
# "<tmpDir>chunkedFootprintResults/<mode>/chunk_<i>.rds". Chunks whose cache file
# already exists are skipped, so an interrupted run can be resumed. At the end the
# cached results of the current chunks are concatenated, in chunk order, into one
# list with one element per region.
get_footprints <- function(projectCountTensor, # Region-by-position-by-pseudobulk Tn5 insertion count tensor generated by getCountTensor(). If empty, chunks are read from "<tmpDir>chunkedCountTensor/".
                           dispersionModel, # Background dispersion model for center-vs-(center + flank) ratio insertion ratio
                           tmpDir, # Directory to store intermediate results
                           mode, # Character. This is used for retrieving the correct dispersion model.
                           footprintRadius, # Radius of the footprint region
                           flankRadius, # Radius of the flanking region (not including the footprint region)
                           cellTypeLabels, # Character vector specifying cell types for each pseudobulk
                           chunkSize = 2000, # Chunk size for parallel processing of regions
                           returnCellTypeScores = FALSE, # Whether to also return results per cell type
                           nCores = 16, # Number of cores to use
                           project = NULL # footprintingProject supplying region widths, bias, ranges and pseudobulk IDs. If NULL, a variable named `project` visible from the caller is used.
){
  
  # Older callers did not pass the project explicitly and relied on a `project`
  # variable being visible; look it up from the caller (and its enclosures).
  if (is.null(project)) {
    project <- get("project", envir = parent.frame())
  }
  
  # Get region data we need to use later
  width <- regionWidth(project)
  seqBias <- regionBias(project)
  regions <- regionRanges(project)
  
  # When the count tensor is stored on disk in chunks, chunk i must be read from
  # "chunkedCountTensor/chunk_<i>.rds", which was written with the project's region
  # chunk size. Any other chunk size would pair regions with the wrong count data.
  if (length(projectCountTensor) == 0) {
    storedChunkSize <- min(regionChunkSize(project), length(regions))
    if (chunkSize != storedChunkSize) {
      warning("Count tensor is stored in chunks of ", storedChunkSize,
              " regions; overriding chunkSize = ", chunkSize, ".", call. = FALSE)
      chunkSize <- storedChunkSize
    }
  }
  
  # Create a folder for saving intermediate results
  chunkTmpDir <- paste0(tmpDir, "chunkedFootprintResults/", mode, "/")
  if (!dir.exists(chunkTmpDir)) {
    dir.create(chunkTmpDir, recursive = TRUE)
  }
  chunkFile <- function(i) paste0(chunkTmpDir, "chunk_", i, ".rds")
  
  # To reduce memory usage, we chunk the region list in to smaller chunks
  cat("Chunking data ..\n")
  groupIDs <- gtools::mixedsort(groups(project))
  chunkIntervals <- getChunkInterval(regions, chunkSize = chunkSize)
  starts <- chunkIntervals[["starts"]]
  ends <- chunkIntervals[["ends"]]
  
  # Process each chunk
  for (i in seq_along(starts)) {
    
    gc()
    
    print(paste0("Processing region chunk ", i, " out of ", length(starts), " chunks"))
    print(Sys.time())
    
    # Skip current chunk if result already exists (checked before loading any data)
    if (file.exists(chunkFile(i))) {
      next
    }
    
    # Get ATAC insertion data for the current chunk
    chunkRegions <- starts[i]:ends[i]
    if (length(projectCountTensor) == 0) {
      chunkTensorDir <- paste0(tmpDir, "chunkedCountTensor/")
      chunkCountTensor <- readRDS(paste0(chunkTensorDir, "chunk_", i, ".rds"))
    } else {
      chunkCountTensor <- projectCountTensor[chunkRegions]
    }
    # Name entries by global region index so workers can look regions up by name
    names(chunkCountTensor) <- chunkRegions
    
    # Footprint calling and scoring.
    print("Footprint calling and scoring")
    cluster <- prep_cluster(chunkSize, n_cores = nCores)
    opts <- cluster[["opts"]]
    cl <- cluster[["cl"]]
    chunkFootprintResults <- tryCatch(
      foreach(
        regionInd = chunkRegions,
        .options.snow = opts, 
        .packages = c("dplyr", "Matrix", "GenomicRanges"),
        .export = c("regionFootprintScoring", "conv", "footprintScoring", 
                    "findSummits", "getRegionATAC",
                    "footprintWindowSum", "predictDispersion",
                    "groupCellType", "mergeRegions")) %dopar% {
                      
                      # Get position-by-pseudobulk matrix of ATAC data for the current region
                      regionATAC <- getRegionATAC(countData = chunkCountTensor, 
                                                  regionInd = as.character(regionInd), 
                                                  groupIDs = groupIDs, 
                                                  width = width[regionInd])
                      
                      # Calculate footprint scores for the current region
                      regionFootprintScoring(regionATAC = regionATAC,
                                             Tn5Bias = seqBias[regionInd, ],
                                             dispersionModel = dispersionModel,
                                             footprintRadius = footprintRadius,
                                             flankRadius = flankRadius,
                                             regionID = as.character(regions[regionInd, ]),
                                             scoreType = "pval",
                                             cellTypeLabels = cellTypeLabels,
                                             returnCellTypeScores = returnCellTypeScores)
                    },
      # Always release the worker processes, even if scoring a region fails
      finally = stopCluster(cl)
    )
    print("Finished!")
    
    # Save results
    saveRDS(chunkFootprintResults, chunkFile(i))
  }
  
  # Integrate results of the current chunks, in chunk order. Only this run's chunk
  # files are read, so stale files from an earlier run with more chunks are ignored.
  chunkResults <- lapply(seq_along(starts), function(i) readRDS(chunkFile(i)))
  do.call(c, chunkResults)
}

# Calculate footprint score tracks and call footprints for a single genomic region.
#
# Steps:
#   1. Per-pseudobulk footprint p-value tracks (footprintScoring() on each column of regionATAC).
#   2. Per-cell-type p-value tracks, computed on insertions summed over the pseudobulks of each cell type.
#   3. An aggregate score track: when there is more than one cell type, the per-position
#      p-values are combined across cell types with survcomp::combine.test(); the result is
#      -log10 transformed and smoothed (running max, then running mean).
#   4. Per-cell-type score tracks (-log10 of FDR or raw p-values, non-finite set to 0, smoothed).
#   5. Summits of each cell-type score track, turned into footprint ranges of width
#      2 * footprintRadius and merged across cell types with mergeRegions().
#   6. Smoothed per-pseudobulk -log10(p) scores at the merged summits.
#
# Returns a list with regionID, summitPvalscores, aggregateATAC, footprintRanges,
# aggregateScores, summits, bias, summitCellTypeScores and, if returnCellTypeScores
# is TRUE, cellTypeScores.
regionFootprintScoring <- function(regionATAC, # Position-by-pseudobulk matrix of ATAC data for the current region
                                 Tn5Bias, # Numeric vector of predicted Tn5 bias for the current region
                                 dispersionModel, # Background dispersion model for center-vs-(center + flank) ratio insertion ratio
                                 footprintRadius, # Radius of the footprint region
                                 flankRadius, # Radius of the flanking region (not including the footprint region)
                                 regionID = "", # Character, ID of the region
                                 scoreType = "fdr", # Whether to use -log10 of "fdr" or "pval" as scores
                                 cellTypeLabels = NULL, # Character vector specifying cell types for each pseudobulk
                                 returnCellTypeScores = FALSE # Whether to also return results per cell type
                                 ){
  
  # Fail fast on an unsupported scoreType; otherwise it would only surface later
  # as an obscure "object not found" error
  if (length(scoreType) != 1 || !(scoreType %in% c("fdr", "pval"))) {
    stop("scoreType must be either \"fdr\" or \"pval\"", call. = FALSE)
  }
  
  aggregateATAC <- rowSums(regionATAC)
  nGroups <- dim(regionATAC)[2]
  
  # Score a vector of Tn5 insertion counts (one value per position) against the background model
  scoreInsertions <- function(Tn5Insertion) {
    footprintScoring(
      Tn5Insertion = Tn5Insertion,
      Tn5Bias = Tn5Bias,
      dispersionModel = dispersionModel,
      footprintRadius = footprintRadius,
      flankRadius = flankRadius
    )
  }
  
  # Calculate the position-by-pseudobulk footprint pvalue matrix
  # (We made sure rows are positions)
  footprintPvalMatrix <- sapply(
    seq_len(nGroups),
    function(groupInd) scoreInsertions(regionATAC[, groupInd])
  )
  
  # Generate footprint p-val tracks for each cell type
  if (is.null(cellTypeLabels) || length(cellTypeLabels) == 0) {
    cellTypeLabels <- rep("sample", nGroups)
  }
  cellTypes <- unique(cellTypeLabels)
  cellTypePvals <- sapply(
    cellTypes,
    function(cellType) {
      cellTypeFilter <- cellTypeLabels %in% cellType
      scoreInsertions(rowSums(regionATAC[, cellTypeFilter, drop = FALSE]))
    }
  )
  
  # At each position, combine the p-values of the different cell types
  # (survcomp::combine.test, described originally as Fisher's method)
  smoothRadius <- as.integer(footprintRadius / 2)
  if (length(cellTypes) > 1) {
    aggregatePvals <- sapply(
      seq_along(Tn5Bias),
      function(position) {
        rawPval <- cellTypePvals[position, ]
        rawPval <- rawPval[!is.na(rawPval)] # Remove NA values
        survcomp::combine.test(rawPval) # Combine p-values
      }
    )
  } else {
    aggregatePvals <- cellTypePvals
  }
  aggregateScores <- -log10(aggregatePvals)
  aggregateScores <- caTools::runmax(aggregateScores, 2 * smoothRadius)
  aggregateScores <- conv(aggregateScores, smoothRadius) / (2 * smoothRadius)
  
  # Convert pvals to footprint scores and smooth
  cellTypeScores <- sapply(
    cellTypes,
    function(cellType) {
      pvals <- cellTypePvals[, cellType]
      if (scoreType == "fdr") {
        scores <- -log10(p.adjust(pvals, method = "fdr"))
      } else {
        scores <- -log10(pvals)
      }
      scores[!is.finite(scores)] <- 0
      scores <- caTools::runmax(scores, 2 * smoothRadius)
      conv(scores, smoothRadius) / (2 * smoothRadius)
    }
  )
  
  # Filter NA values from the position-by-pseudobulk footprint score matrix
  footprintPvalMatrix[is.na(footprintPvalMatrix)] <- 1 # Set NA values to be pvalue = 1
  
  # Log-transform to get the scores
  pvalScoreMatrix <- -log10(footprintPvalMatrix)
  
  # Smooth the footprint score matrix
  # NOTE(review): the running mean here uses footprintRadius, whereas the other
  # tracks use smoothRadius; kept as-is, confirm this is intended.
  pvalScoreMatrix <- sapply(
    seq_len(nGroups),
    function(groupInd) {
      scores <- caTools::runmax(pvalScoreMatrix[, groupInd], 2 * smoothRadius)
      conv(scores, footprintRadius) / (2 * footprintRadius)
    }
  )
  
  # Summit threshold depends on whether we are using FDR scores or pvalue scores
  threshold <- if (scoreType == "fdr") 1 else -log10(0.01)
  
  # Find footprint summits for each cell type
  cellTypeSummits <- lapply(
    cellTypes,
    function(cellType) {
      
      # Detect footprint score summits
      summits <- findSummits(cellTypeScores[, cellType], 
                             r = footprintRadius,
                             threshold = threshold)
      
      # Get genomic ranges of the footprints centered around the summits
      if (length(summits) > 0) {
        footprintRanges <- IRanges::resize(GenomicRanges::GRanges(paste0("summit:", summits, "-", summits)), 
                                           width = 2 * footprintRadius, fix = "center")
        mcols(footprintRanges)$score <- cellTypeScores[summits, cellType]
        mcols(footprintRanges)$cellType <- cellType
        footprintRanges
      } else {
        NULL
      }
    }
  )
  cellTypeSummits <- cellTypeSummits[!sapply(cellTypeSummits, is.null)]
  cellTypeSummits <- Reduce(c, cellTypeSummits)
  
  # Integrate footprint summits across cell types with mergeRegions()
  # (intended to keep the most significant of any overlapping footprints).
  if (!is.null(cellTypeSummits)) {
    footprintRanges <- mergeRegions(cellTypeSummits)
    summits <- start(IRanges::resize(footprintRanges, width = 1, fix = "center"))
  } else {
    footprintRanges <- NULL
    summits <- NULL
  }
  
  # For the footprints called above, retrieve the smoothed footprint pval scores
  # at their summits in each pseudobulk
  summitPvalscores <- pvalScoreMatrix[summits, , drop = FALSE]
  
  # Integrate the results
  results <- list("regionID" = regionID,
                  "summitPvalscores" = summitPvalscores,
                  "aggregateATAC" = aggregateATAC,
                  "footprintRanges" = footprintRanges,
                  "aggregateScores" = aggregateScores,
                  "summits" = summits,
                  "bias" = Tn5Bias,
                  "summitCellTypeScores" = cellTypeScores[summits, ])
  
  if (returnCellTypeScores) {
    results[["cellTypeScores"]] <- cellTypeScores
  }
  
  results
}

# Running-window sums for a center window and its two flanking windows.
#
# For every position i, with f = footprintRadius and h = as.integer(flankRadius / 2):
#   center:     sum of x[(i - f + 1):(i + f)]          (the window used by conv())
#   leftFlank:  sum of the 2h values immediately left of the center window
#   rightFlank: sum of the 2h values immediately right of the center window
# Positions outside x contribute 0 (x is zero-padded before convolving).
# Returns a list with elements leftFlank, center and rightFlank, each of length(x).
footprintWindowSum <- function(x, # A numerical or integer vector x
                               footprintRadius, # Radius of the footprint region
                               flankRadius # Radius of the flanking region (not including the footprint region)
                               ){
  n <- length(x)
  halfFlank <- as.integer(flankRadius / 2)
  offset <- halfFlank + footprintRadius
  padding <- rep(0, offset)
  
  # Padding on the left moves each running window `offset` positions to the left
  leftSums <- conv(c(padding, x), halfFlank)[1:n]
  
  # Padding on the right and skipping `offset` entries moves each window to the right
  rightSums <- conv(c(x, padding), halfFlank)[(offset + 1):(offset + n)]
  
  list(leftFlank = leftSums,
       center = conv(x, footprintRadius),
       rightFlank = rightSums)
}

# Calculate the footprint p-value track for a single region and a single sample.
#
# For every position, Tn5 insertions (and predicted bias) are summed in a center
# window and in left/right flanking windows (footprintWindowSum()). The dispersion
# model takes the scaled features (left/right/center bias sums and log10 of
# center + left and center + right insertions) and predicts, for each side, a
# background mean and SD of the ratio center / (center + flank). The observed
# ratio is tested with the lower tail of a normal distribution, NA p-values are
# set to 1, and the larger of the left and right p-values is kept. Positions with
# fewer than 1 insertion in center + flank on either side get p = 1.
#
# Returns a numeric vector of p-values, one per position.
footprintScoring <- function(Tn5Insertion, # Integer vector of raw observed Tn5 insertion counts at each single base pair
                             Tn5Bias, # Vector of predicted Tn5 bias. Should be the same length
                             dispersionModel, # Background dispersion model for center-vs-(center + flank) ratio insertion ratio
                             footprintRadius = 10, # Radius of the footprint region
                             flankRadius = 10 # Radius of the flanking region (not including the footprint region)
){
  
  biasSums <- footprintWindowSum(Tn5Bias, footprintRadius, flankRadius)
  countSums <- footprintWindowSum(Tn5Insertion, footprintRadius, flankRadius)
  
  centerCounts <- countSums$center
  leftTotal <- centerCounts + countSums$leftFlank
  rightTotal <- centerCounts + countSums$rightFlank
  
  # Position-by-feature matrix of model inputs, standardized per feature
  features <- matrix(c(biasSums$leftFlank,
                       biasSums$rightFlank,
                       biasSums$center,
                       log10(leftTotal),
                       log10(rightTotal)),
                     ncol = 5)
  scaled <- t((t(features) - dispersionModel$featureMean) / dispersionModel$featureSD)
  
  # Model output columns: left mean, left SD, right mean, right SD (un-standardized here)
  predicted <- predictDispersion(as.matrix(scaled), dispersionModel$modelWeights)
  predicted <- t(t(predicted) * dispersionModel$targetSD + dispersionModel$targetMean)
  
  # Lower-tail normal p-value, treating undefined results as non-significant
  lowerTailP <- function(observed, mu, sigma) {
    pv <- pnorm(observed, mu, sigma)
    pv[is.na(pv)] <- 1
    pv
  }
  pLeft <- lowerTailP(centerCounts / leftTotal, predicted[, 1], predicted[, 2])
  pRight <- lowerTailP(centerCounts / rightTotal, predicted[, 3], predicted[, 4])
  
  # Require depletion on both sides: keep the larger p-value
  combined <- pmax(pLeft, pRight)
  
  # Mask positions with zero coverage on either flanking side
  combined[leftTotal < 1 | rightTotal < 1] <- 1
  
  combined
}

# Generic for computing a position-by-pseudobulk footprint score matrix for one region.
setGeneric(name = "regionFootprintMatrix",
           def = function(project, # footprintingProject object
                          regionInd, # Index of the region
                          lineageGroups = NULL, # If we want to only run on selected pseudobulks, provide their group IDs here
                          footprintRadius = 20, # Radius of the footprint region
                          nCores = 16, # Number of cores to use
                          ...)
             standardGeneric("regionFootprintMatrix"))

# Position-by-pseudobulk footprint score matrix for a single region.
#
# Uses the dispersion model stored under as.character(footprintRadius) and a flank
# radius equal to footprintRadius. For each selected pseudobulk, the p-value track
# from footprintScoring() is floored at 1e-300, -log10 transformed, and smoothed
# (running max over 2 * smoothRadius, then running mean). Pseudobulks are
# processed in parallel with pbmcapply::pbmcmapply().
setMethod(f = "regionFootprintMatrix",
          signature = "footprintingProject",
          definition = function(project, 
                                regionInd, 
                                lineageGroups = NULL, 
                                footprintRadius = 20,
                                nCores = 16) {
            
            Tn5Bias <- regionBias(project)[regionInd, ]
            dispersionModel <- dispModel(project, as.character(footprintRadius))
            smoothRadius <- as.integer(footprintRadius / 2)
            
            # Count data may be held in memory or loaded from the chunk containing the region
            countInfo <- getCountData(project, regionInd)
            
            # Pseudobulks to footprint: all of them unless a subset was requested
            groupIDs <- if (is.null(lineageGroups)) groups(project) else lineageGroups
            
            # Position-by-pseudobulk matrix of ATAC data for the current region
            regionATAC <- getRegionATAC(countInfo[["countData"]],
                                        countInfo[["regionInd"]],
                                        groupIDs,
                                        length(Tn5Bias))
            
            # Smoothed -log10(p) footprint score track for one pseudobulk column
            scoreTrack <- function(groupInd) {
              pvals <- footprintScoring(
                Tn5Insertion = regionATAC[, groupInd],
                Tn5Bias = Tn5Bias,
                dispersionModel = dispersionModel,
                footprintRadius = footprintRadius,
                flankRadius = footprintRadius
              )
              scores <- -log10(pmax(pvals, 1e-300)) # Floor p-values to avoid Inf scores
              smoothed <- caTools::runmax(scores, 2 * smoothRadius)
              conv(smoothed, smoothRadius) / (2 * smoothRadius)
            }
            
            pbmcapply::pbmcmapply(scoreTrack, 1:ncol(regionATAC), mc.cores = nCores)
          })

# Forward pass of a two-layer dense network (ReLU hidden layer, linear output).
#
# modelWeights is a list of (W1, b1, W2, b2): each layer computes input %*% W and
# adds the bias b to every row. Returns the output matrix with one row per row of x.
predictDispersion <- function(x, # Model input
                              modelWeights # Model weights
                              ){
  
  # One dense layer: linear transform followed by a per-column bias
  denseLayer <- function(input, weight, bias) {
    out <- input %*% weight
    t(t(out) + as.numeric(bias))
  }
  
  hidden <- denseLayer(x, modelWeights[[1]], modelWeights[[2]])
  hidden[which(hidden < 0)] <- 0 # ReLU activation
  
  denseLayer(hidden, modelWeights[[3]], modelWeights[[4]]) # Linear output
}

# For a specific region, return the Tn5 insertion count data containing it and the
# index of the region within that data.
#
# If the project holds the full count tensor in memory, it is returned together
# with the unchanged regionInd. Otherwise (count tensor stored on disk in chunks),
# the chunk containing the region is read from
# "<dataDir>chunkedCountTensor/chunk_<k>.rds" and regionInd is converted to the
# position within that chunk, using the chunk starts from getChunkInterval().
# For example, with chunks of 2000 regions, regionInd = 3500 lies in chunk 2
# (regions 2001-4000) and the returned regionInd is 1500.
# Stops with an error if regionInd is outside the chunked region range.
getCountData <- function(project, # footprintingProject object
                         regionInd # Index of the region
                         ){
  
  countData <- countTensor(project)
  if (length(countData) != 0) {
    return(list("countData" = countData, 
                "regionInd" = regionInd))
  }
  
  # The countTensor is only stored in chunks: determine which chunk we need
  chunkIntervals <- getChunkInterval(regionRanges(project),
                                     chunkSize = regionChunkSize(project))
  starts <- chunkIntervals[["starts"]]
  ends <- chunkIntervals[["ends"]]
  
  candidateChunks <- which(regionInd >= starts)
  if (length(candidateChunks) == 0) {
    stop("regionInd ", regionInd, " is outside the range of project regions",
         call. = FALSE)
  }
  chunkInd <- max(candidateChunks)
  if (regionInd > ends[chunkInd]) {
    stop("regionInd ", regionInd, " is outside the range of project regions",
         call. = FALSE)
  }
  
  chunkPath <- paste0(dataDir(project), "chunkedCountTensor/chunk_", chunkInd, ".rds")
  
  # Offset from the actual chunk start, so irregular chunks are handled correctly
  list("countData" = readRDS(chunkPath), 
       "regionInd" = regionInd - starts[chunkInd] + 1)
}

# Retrieve the position-by-pseudobulk ATAC insertion matrix for a specific genomic region.
#
# countData[[regionInd]] is expected to have `position`, `group` and `count`
# columns (one row per observed insertion position and pseudobulk). For each
# requested group, a zero-filled track of length `width` is filled with the counts
# at their positions. Returns a matrix with one row per position and one column per
# group (columns named by group when groupIDs is a character vector).
#
# Uses base subsetting only, so it does not depend on dplyr/magrittr being attached.
getRegionATAC <- function(countData, # Tn5 insertion count tensor (list indexed by region)
                          regionInd, # Index (or name) of the region
                          groupIDs, # IDs of pseudobulks we want to keep
                          width # Width of the region
                          ){
  # Extract the region once instead of once per group
  regionData <- countData[[regionInd]]
  
  sapply(
    groupIDs,
    function(groupID) {
      keep <- regionData$group %in% groupID
      regionTrack <- rep(0, width)
      regionTrack[regionData$position[keep]] <- regionData$count[keep]
      regionTrack
    }
  )
}

# Fast running-window sum with a window of 2 * r values.
# For each position i of x, returns sum(x[(i - r + 1):(i + r)]), treating values
# outside x as 0. Computed via cladoRcpp::rcpp_convolve with a kernel of ones,
# trimmed back to length(x). Expects r >= 1.
conv <- function(x, # Input vector x
                 r # Window radius
                 ){
  n <- length(x)
  ones <- rep(1, times = 2 * r)
  fullConvolution <- cladoRcpp::rcpp_convolve(x, ones)
  fullConvolution[(r + 1):(r + n)]
}

# Find signal summits in a vector.
#
# A summit is a position whose value equals the running max over a window of
# 2 * r (caTools::runmax) and exceeds `threshold`. When consecutive summits are
# r or fewer positions apart, the earlier one is dropped and the later one kept;
# the last summit is always kept. Returns the summit indices.
findSummits <- function(x, # Input vector x
                        r = 10, # Radius of the peak 
                        threshold = 1 # Signal threshold
                        ){
  
  isLocalMax <- x == caTools::runmax(x, 2 * r)
  peaks <- which(isLocalMax & (x > threshold))
  
  nPeaks <- length(peaks)
  if (nPeaks <= 1) {
    return(peaks)
  }
  
  # Keep a summit only if the next one is more than r positions away
  farFromNext <- diff(peaks) > r
  c(peaks[-nPeaks][farFromNext], peaks[nPeaks])
}