Commit 2b28449a authored by smorabit's avatar smorabit
Browse files

added module preservation analysis

parent bc726104
Loading
Loading
Loading
Loading
+86 −1
Original line number Diff line number Diff line
@@ -1552,7 +1552,9 @@ OverlapModulesMotifs <- function(
  overlap_df <- overlap_df %>% dplyr::select(c(module, tf, color, odds_ratio, pval, fdr, Significance, Jaccard, size_intersection))

  # add overlap df to Seurat obj
  seruat_obj <- SetMotifOverlap(seurat_obj, overlap_df, wgcna_name)
  seurat_obj <- SetMotifOverlap(seurat_obj, overlap_df, wgcna_name)

  seurat_obj

}

@@ -1850,3 +1852,86 @@ ModuleTraitCorrelation <- function(
  seurat_obj <- SetModuleTraitCorrelation(seurat_obj, mt_cor, wgcna_name)
  seurat_obj
}



#' ModulePreservation
#'
#' Computes module preservation statistics in a query dataset for a given
#' reference dataset, using \code{WGCNA::modulePreservation}, and stores the
#' observed and Z statistics in the query Seurat object.
#'
#' @param seurat_obj A Seurat object containing the query dataset
#' @param seurat_ref A Seurat object containing the reference dataset, whose
#' modules are tested for preservation in the query
#' @param name The name under which the module preservation results are stored
#' in seurat_obj (retrieve them with GetModulePreservation)
#' @param n_permutations Passed to WGCNA::modulePreservation as nPermutations
#' @param parallel Passed to WGCNA::modulePreservation as parallelCalculation
#' @param seed Passed to WGCNA::modulePreservation as randomSeed
#' @param return_raw If TRUE, return the unprocessed output of
#' WGCNA::modulePreservation instead of the updated Seurat object
#' @param wgcna_name The name of the scWGCNA experiment in the seurat_obj@misc slot.
#' Defaults to the active experiment of seurat_obj.
#' @param wgcna_name_ref The name of the scWGCNA experiment in the seurat_ref@misc slot.
#' Defaults to the active experiment of seurat_ref.
#' @param ... Additional arguments forwarded to WGCNA::modulePreservation
#' @return seurat_obj with the preservation statistics stored under \code{name},
#' or the raw WGCNA::modulePreservation output when return_raw is TRUE
#' @keywords scRNA-seq
#' @export
#' @examples
#' ModulePreservation
ModulePreservation <- function(
  seurat_obj,
  seurat_ref,
  name,
  n_permutations = 500,
  parallel = FALSE,
  seed = 12345,
  return_raw = FALSE,
  wgcna_name = NULL,
  wgcna_name_ref = NULL,
  ...
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}
  if(is.null(wgcna_name_ref)){wgcna_name_ref <- seurat_ref@misc$active_wgcna}

  print('Setup datasets')

  # get datExpr for reference and query:
  datExpr_ref <- GetDatExpr(seurat_ref, wgcna_name_ref)
  datExpr_query <- GetDatExpr(seurat_obj, wgcna_name)

  # set up multiExpr:
  setLabels <- c("ref", "query")
  multiExpr <- list(
    ref = list(data=datExpr_ref),
    query = list(data=datExpr_query)
  )

  # Use the same reference experiment for the module labels as for datExpr_ref;
  # otherwise a non-active wgcna_name_ref would pair the expression matrix of
  # one experiment with the module assignments of another.
  # NOTE(review): assumes the module vector is ordered like the columns of
  # datExpr_ref -- confirm against GetModules / GetDatExpr.
  ref_modules <- list(ref = GetModules(seurat_ref, wgcna_name_ref)$module)

  print('ref:')
  print(dim(datExpr_ref))
  print('query:')
  print(dim(datExpr_query))

  print('Run Module Preservation')

  # run the module preservation test:
  mp <- WGCNA::modulePreservation(
    multiExpr,
    ref_modules,
    referenceNetworks = 1,
    nPermutations = n_permutations,
    randomSeed = seed,
    quickCor = 0,
    parallelCalculation = parallel,
    ...
  )

  # early exit: hand back the raw WGCNA result without touching seurat_obj
  if(return_raw){return(mp)}

  # get the stats obs and stats Z tables
  # (indices into multiExpr: 1 = reference, 2 = query)
  ref <- 1
  test <- 2
  statsObs <- cbind(
    mp$quality$observed[[ref]][[test]],
    mp$preservation$observed[[ref]][[test]]
  )
  statsZ <- cbind(
    mp$quality$Z[[ref]][[test]],
    mp$preservation$Z[[ref]][[test]]
  )

  # add stats to the seurat object:
  mod_pres <- list('obs'  = statsObs, 'Z' = statsZ)
  seurat_obj <- SetModulePreservation(seurat_obj, mod_pres, name, wgcna_name)

  seurat_obj

}
+21 −0
Original line number Diff line number Diff line
@@ -682,6 +682,27 @@ GetModuleTraitCorrelation <- function(seurat_obj, wgcna_name=NULL){
  seurat_obj@misc[[wgcna_name]]$mt_cor
}

############################
# ModulePreservation
###########################

# Store a module preservation result (mod_pres) under mod_name inside the
# module_preservation list of the given scWGCNA experiment.
# wgcna_name defaults to the active experiment. Returns the updated object.
SetModulePreservation <- function(seurat_obj, mod_pres, mod_name, wgcna_name=NULL){
  if(is.null(wgcna_name)){
    wgcna_name <- seurat_obj@misc$active_wgcna
  }

  # pull out the existing results, starting from an empty list when
  # module preservation hasn't been stored for this experiment yet
  stored <- seurat_obj@misc[[wgcna_name]]$module_preservation
  if(is.null(stored)){
    stored <- list()
  }

  # add / overwrite the entry and write the list back
  stored[[mod_name]] <- mod_pres
  seurat_obj@misc[[wgcna_name]]$module_preservation <- stored

  seurat_obj
}

# Fetch the module preservation result stored under mod_name for the given
# scWGCNA experiment (defaults to the active experiment).
GetModulePreservation <- function(seurat_obj, mod_name, wgcna_name=NULL){
  if(is.null(wgcna_name)){
    wgcna_name <- seurat_obj@misc$active_wgcna
  }
  wgcna_data <- seurat_obj@misc[[wgcna_name]]
  wgcna_data$module_preservation[[mod_name]]
}


############################
# Reset module names:
+106 −0
Original line number Diff line number Diff line
@@ -1524,3 +1524,109 @@ DoHubGeneHeatmap <- function(
  out

}



#' PlotModulePreservation
#'
#' Plots module preservation statistics (stored by ModulePreservation) against
#' module size, one scatter plot per statistic.
#'
#' @param seurat_obj A Seurat object
#' @param name The name of the module preservation result to plot, as given to
#' ModulePreservation
#' @param statistics Which statistics to plot: 'summary' (medianRank.pres and
#' Zsummary.pres), 'all' (every column of the observed and Z tables), or a
#' custom vector of statistic names
#' @param plot_labels If TRUE, label each point with its module name
#' @param label_size Text size of the module labels
#' @param mod_point_size Size of the module points
#' @param wgcna_name The name of the scWGCNA experiment in the seurat_obj@misc slot.
#' Defaults to the active experiment.
#' @return A single ggplot object when exactly one statistic is plotted,
#' otherwise a list of ggplot objects named by statistic
#' @keywords scRNA-seq
PlotModulePreservation <- function(
  seurat_obj,
  name,
  statistics = 'summary', # can be summary, all, or a custom list
  plot_labels = TRUE,
  label_size = 4,
  mod_point_size = 4,
  wgcna_name = NULL
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # get the module preservation stats:
  mod_pres <- GetModulePreservation(seurat_obj, name, wgcna_name)
  if(is.null(mod_pres)){
    # fail early with a clear message instead of an obscure error further down
    stop(
      paste0("No module preservation results found for name '", name, "'."),
      call. = FALSE
    )
  }
  obs_df <- mod_pres$obs
  Z_df <- mod_pres$Z

  # get module colors:
  modules <- GetModules(seurat_obj, wgcna_name)
  module_colors <- modules %>% dplyr::select(c(module, color)) %>% distinct
  mods <- rownames(Z_df)
  mod_colors <- module_colors$color[match(mods, module_colors$module)]
  # modules without a matching color (dropped from the plots below) get 'gold'
  mod_colors <- ifelse(is.na(mod_colors), 'gold', mod_colors)

  # what are we going to plot?
  # The length check keeps the condition scalar, so a custom vector of several
  # statistic names does not raise "the condition has length > 1".
  is_keyword <- length(statistics) == 1
  if(is_keyword && statistics == 'summary'){
    stat_list <- c("medianRank.pres", "Zsummary.pres")
  } else if(is_keyword && statistics == 'all'){
    stat_list <- c(colnames(obs_df[,-1]), colnames(Z_df[,-1]))
  } else{
    stat_list <- statistics
  }

  # module size is the x axis, not a statistic to plot
  stat_list <- stat_list[stat_list != 'moduleSize']


  plot_list <- list()
  for(statistic in stat_list){

    print(statistic)

    if(statistic %in% colnames(obs_df)){
      values <- obs_df[,statistic]
    } else if(statistic %in% colnames(Z_df)){
      values <- Z_df[,statistic]
    } else{
      stop("Invalid name for statistic.")
    }

    # setup plotting df
    plot_df <- data.frame(
      module = mods,
      color = mod_colors,
      value = values,
      size = Z_df$moduleSize
    )

    # don't include grey & gold:
    plot_df <- plot_df %>% subset(!(module %in% c('grey', 'gold')))

    # Named color vector: ggplot2 matches named values to the module they
    # belong to. An unnamed vector would be assigned in sorted level order,
    # which need not be the row order of plot_df.
    color_map <- stats::setNames(
      as.character(plot_df$color),
      as.character(plot_df$module)
    )

    if(grepl("Rank", statistic)){
      # rank statistics: reverse the y axis
      cur_p <-  plot_df %>% ggplot(aes(x=size, y=value, fill=module, color=module)) +
        geom_point(size=mod_point_size, pch=21, color='black') +
        scale_y_reverse()
    } else{
      # other statistics: shade the bands below 2 and between 2 and 10
      cur_p <- plot_df %>% ggplot(aes(x=size, y=value, fill=module, color=module)) +
        geom_rect(
          data = plot_df[1,],
          aes(xmin=0, xmax=Inf, ymin=-Inf, ymax=2), fill='grey75', alpha=0.8, color=NA) +
        geom_rect(
          data=plot_df[1,],
          aes(xmin=0, xmax=Inf, ymin=2, ymax=10), fill='grey92', alpha=0.8, color=NA) +
        geom_point(size=mod_point_size, pch=21, color='black')
    }

    cur_p <- cur_p +
      scale_fill_manual(values=color_map) +
      scale_color_manual(values=color_map) +
      scale_x_continuous(trans='log10') +
      ylab(statistic) +
      xlab("Module Size") +
      ggtitle(statistic) +
      NoLegend() +
      theme(
        plot.title = element_text(hjust = 0.5)
      )


    if(plot_labels){
      cur_p <- cur_p + geom_text_repel(label = plot_df$module, size=label_size)
    }

    plot_list[[statistic]] <- cur_p

  }

  # a single plot is returned directly rather than as a one-element list
  if(length(plot_list) == 1){return(plot_list[[1]])}

  plot_list

}
+2 −0
Original line number Diff line number Diff line
@@ -415,6 +415,8 @@ the downstream tasks.

# Re-load the saved Seurat object from disk
cur_seurat <- readRDS(file='data/test_wgcna_seurat.rds')

# Call SetModules with NULL
# (presumably clears any stored module assignments -- confirm against SetModules)
cur_seurat <- SetModules(cur_seurat, NULL)

######################################################