Commit dfd8f1e2 authored by smorabit's avatar smorabit
Browse files

mt cor plotting

parent 92cb15cd
Loading
Loading
Loading
Loading
+49 −12
Original line number Diff line number Diff line
@@ -1821,10 +1821,30 @@ ModuleTraitCorrelation <- function(
    }
  }

  # also correlate all cells:
  cor_list <- list()
  cor_list[['all_cells']] <- cor(as.matrix(trait_df), as.matrix(MEs), method=cor_method)
  # correlate all cells:
  cor_list <- list(); pval_list <- list(); fdr_list <- list()

  # testing other correlation function:
  temp <- Hmisc::rcorr(as.matrix(trait_df), as.matrix(MEs), type=cor_method)

  cur_cor <- temp$r[traits,mods]
  cur_p <- temp$P[traits,mods]

  # compute FDR:
  p_df <- cur_p %>%
    reshape2::melt() %>%
    dplyr::mutate(fdr=p.adjust(value, method='fdr')) %>%
    dplyr::select(c(Var1, Var2, fdr))

  # reshape to match cor & pval
  cur_fdr <- reshape2::dcast(p_df, Var1 ~ Var2, value.var='fdr')
  rownames(cur_fdr) <- cur_fdr$Var1
  cur_fdr <- cur_fdr[,-1]

  # add to list
  cor_list[["all_cells"]] <- cur_cor
  pval_list[["all_cells"]] <- cur_p
  fdr_list[["all_cells"]] <- cur_fdr

  trait_df <- cbind(trait_df, seurat_obj@meta.data[,group.by])
  colnames(trait_df)[ncol(trait_df)] <- 'group'
@@ -1845,21 +1865,37 @@ ModuleTraitCorrelation <- function(
  names(ME_list) <- group_names

  for(i in names(trait_list)){
    cor_list[[i]] <- cor(as.matrix(trait_list[[i]]), as.matrix(ME_list[[i]]), method=cor_method)
  }
    # cor_list[[i]] <- cor(as.matrix(trait_list[[i]]), as.matrix(ME_list[[i]]), method=cor_method)

    # testing other correlation function:
    temp <- Hmisc::rcorr(as.matrix(trait_list[[i]]), as.matrix(ME_list[[i]]))

  # compute the correlation matrix
  #cor_mat <- cor(as.matrix(trait_df), as.matrix(MEs), method=cor_method)
    cur_cor <- temp$r[traits,mods]
    cur_p <- temp$P[traits,mods]

  # this is the wgcna way but I think t-test doesn't make sense for single-cell data right???
  # cor_p  <- corPvalueStudent(cor_mat, ncol(seurat_obj))
    # compute FDR:
    p_df <- cur_p %>%
      reshape2::melt() %>%
      dplyr::mutate(fdr=p.adjust(value, method='fdr')) %>%
      dplyr::select(c(Var1, Var2, fdr))

    # reshape to match cor & pval
    cur_fdr <- reshape2::dcast(p_df, Var1 ~ Var2, value.var='fdr')
    rownames(cur_fdr) <- cur_fdr$Var1
    cur_fdr <- cur_fdr[,-1]

    # add to list
    cor_list[[i]] <- cur_cor
    pval_list[[i]] <- cur_p
    fdr_list[[i]] <- as.matrix(cur_fdr)

  }

  # add Module-trait correlations to the seruat object:
  mt_cor <- list(
    'cor_mat' = cor_list,
    'pval' = NA,
    'fdr' = NA
    'cor' = cor_list,
    'pval' = pval_list,
    'fdr' = fdr_list
  )

  seurat_obj <- SetModuleTraitCorrelation(seurat_obj, mt_cor, wgcna_name)
@@ -1868,6 +1904,7 @@ ModuleTraitCorrelation <- function(




#' ModulePreservation
#'
#' Computes module preservation statistics in a query dataset for a given reference dataset
+189 −0
Original line number Diff line number Diff line
@@ -1644,3 +1644,192 @@ PlotModulePreservation <- function(
  plot_list

}

#' PlotModuleTraitCorrelation
#'
#' Plotting function for Module Preservation statistics
#'
#' @param seurat_obj A Seurat object
#' @param
#' @param
#' @param
#' @param
#' @param plot_labels logical determining whether to plot the module labels#' @param wgcna_name The name of the scWGCNA experiment in the seurat_obj@misc slot
#' @keywords scRNA-seq
#' @export
#' @examples
#' PlotModulePreservation
PlotModuleTraitCorrelation <- function(
  seurat_obj,
  high_color = 'red',
  mid_color = 'grey90',
  low_color = 'blue',
  label = NULL,
  label_symbol = 'stars',
  plot_max = NULL,
  text_size = 2,
  text_color = 'black',
  text_digits = 3,
  combine = TRUE,
  wgcna_name = NULL
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}


  # get the module trait correlation results:
  temp <- GetModuleTraitCorrelation(seurat_obj)
  cor_list <- temp$cor
  pval_list <- temp$pval
  fdr_list <- temp$fdr

  # get module colors:
  modules <- GetModules(seurat_obj, wgcna_name)
  module_colors <- modules %>% dplyr::select(c(module, color)) %>% distinct %>% subset(module != 'grey')
  mod_colors <- module_colors$color

  # dummy variable
  module_colors$var <- 1

  # make the colorbar as its own heatmap
  module_colorbar <- module_colors %>%
    ggplot(aes(x=module, y=var, fill=module)) +
    geom_tile() +
    scale_fill_manual(values=mod_colors) +
    NoLegend() +
    RotatedAxis() +
    theme(
      plot.title=element_blank(),
      axis.line=element_blank(),
      axis.ticks.y=element_blank(),
      axis.text.y = element_blank(),
      axis.title = element_blank(),
      plot.margin=margin(0,0,0,0)
    )


  plot_list <- list()
  for(i in names(cor_list)){
    cor_mat <- as.matrix(cor_list[[i]])
    pval_mat <- as.matrix(pval_list[[i]])
    fdr_mat <- as.matrix(fdr_list[[i]])
    print(i)

    plot_df <- reshape2::melt(cor_mat)
    colnames(plot_df) <- c("Trait", "Module", "cor")

    #p_df <- reshape2::melt(pval_mat)
    if(!is.null(label)){
      if(label == 'fdr'){
        p_df <- reshape2::melt(fdr_mat)
      } else if(label == 'pval'){
        p_df <- reshape2::melt(pval_mat)
      }
      colnames(p_df) <- c("Trait", "Module", "pval")

      # add pval to plot_df
      plot_df$pval <- p_df$pval
      print(levels(plot_df$Trait))

      if(label_symbol == 'stars'){
        plot_df$significance <- gtools::stars.pval(plot_df$pval)
      } else if(label_symbol == 'numeric'){
        plot_df$significance <- ifelse(
          plot_df$pval <= 0.05,
          formatC(plot_df$pval, digits=text_digits), ''
        )
      } else{
        stop('Invalid input for label_symbol. Valid choices are stars or numeric.')
      }
    }

    # get limits for plot:
    if(is.null(plot_max)){
      max_plot <- max(abs(range(plot_df$cor)))
    } else{
      max_plot <- plot_max

      # fix values outside of the specified range:
      plot_df$cor <- ifelse(abs(plot_df$cor) >= plot_max, plot_max * sign(plot_df$cor), plot_df$cor)
    }

    p <- ggplot(plot_df, aes(x=Module, y=as.numeric(Trait), fill=cor)) +
      geom_tile() +
      scale_fill_gradient2(
        limits=c(-1*max_plot,max_plot),
        high=high_color,
        mid=mid_color,
        low=low_color,
        guide = guide_colorbar(ticks=FALSE, barwidth=16, barheight=0.5)
      ) +
      scale_y_continuous(
        breaks = 1:length(levels(plot_df$Trait)),
        labels=levels(plot_df$Trait),
        sec.axis = sec_axis(
          ~.,
          breaks = 1:length(levels(plot_df$Trait)),
          labels=levels(plot_df$Trait)
        )
      ) +
      RotatedAxis() + ylab('') + xlab('') + ggtitle(i) +
      theme(
        plot.title=element_text(hjust=0.5),
        legend.title=element_blank(),
        axis.line=element_blank(),
        axis.ticks.y=element_blank(),
        axis.text.y.left = element_blank()
      )

    if(!is.null(label)){
      p <- p + geom_text(label=plot_df$significance, color=text_color, size=text_size)
    }

    plot_list[[i]] <- p

  }

  if(combine){

    #plot_list <- c(plot_list, cbar)
    #names(plot_list)[length(plot_list)] <- 'module'

    for(i in 1:length(plot_list)){

      plot_list[[i]] <- plot_list[[i]] +
        ylab(names(plot_list)[i]) +
        theme(
          plot.margin = margin(t = 0, r = 0, b = 0, l = 0),
          axis.title.x = element_blank(),
          plot.title = element_blank(),
          legend.position='bottom'
        )
      #
      # if(i != length(plot_list)){
      #   plot_list[[i]] <- plot_list[[i]] +
      #   theme(
      #     axis.text.x = element_blank(),
      #     axis.ticks=element_blank()
      #   )
      # }
    }

    # assemble with patchwork:

    out <- wrap_plots(c(plot_list, module_colorbar), ncol=1) +
      plot_layout(guides = 'collect') +
      plot_annotation(
        title='Module Trait Correlation',
        theme=theme(
          plot.title=element_text(hjust=0.5),
          legend.position = 'bottom',
          legend.justification = 0.5
        )
      )

    return(out)

  } else{
    return(plot_list)
  }

}