Commit a44f5729 authored by smorabit's avatar smorabit
Browse files

wrote ROC curve functions

parent 21d149f8
Loading
Loading
Loading
Loading
+179 −0
Original line number Diff line number Diff line
@@ -900,3 +900,182 @@ TransferModuleGenome <- function(
  modules

}


#' ComputeROC
#'
#'
#' @keywords scRNA-seq
#' @export
#' @examples
#' ComputeROC
ComputeROC <- function(
  seurat_obj, group.by=NULL,
  split_col=NULL, # needs to be a logical!!!
  features = 'hMEs',
  seurat_test=NULL,
  harmony_group_vars=NULL,
  scale_genes=TRUE,
  verbose=FALSE,
  exp_thresh = 0.75,
  wgcna_name=NULL, wgcna_name_test=NULL
){

  #TODO:
  # should be able to skip the ME computation if

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # get Modules
  modules <- GetModules(seurat_obj, wgcna_name)
  mods <- levels(modules$module)
  mods <- mods[mods != 'grey']

  # if group.by column is null, use Idents:
  if(is.null(group.by)){
    group.by <- 'roc_group'
    seurat_obj@meta.data[[group.by]] <- Idents(seurat_obj)
  }

  # get names of different cell groupings
  groups <- as.character(unique(Idents(seurat_obj)))
  groups <- groups[order(groups)]

  # split seurat object into training & testing if the testing is not provided
  if(is.null(seurat_test)){

    print('splitting seurat obj')

    # split into two seurat objects based on train & test split:
    seurat_train <- seurat_obj[,seurat_obj@meta.data[[split_col]]]
    seurat_test <- seurat_obj[,!seurat_obj@meta.data[[split_col]]]

    # project modules for train
    wgcna_name_train <- "ROC"
    seurat_train <- ProjectModules(
      seurat_train,
      seurat_ref = seurat_obj,
      group.by.vars=harmony_group_vars,
      wgcna_name_proj=wgcna_name_train,
      scale_genes=scale_genes, verbose=verbose
    )

  } else{

    if(is.null(wgcna_name_test)){wgcna_name_test <- seurat_test@misc$active_wgcna}

    # if group.by column is null, use Idents:
    if(is.null(group.by)){
      group.by <- 'roc_group'
      seurat_test@meta.data[[group.by]] <- Idents(seurat_test)
    }

    seurat_train <- seurat_obj
    wgcna_name_train <- wgcna_name

  }

  print('here')

  # get names of different cell groupings in test
  groups_test <- as.character(unique(Idents(seurat_test)))
  groups_test <- groups_test[order(groups_test)]

  # check if groups are equal:
  if(sum(groups == groups_test) != length(groups)){
    stop("Different groups present in train & test data. Idents likely do not match.")
  }

  # project modules for test
  seurat_test <- ProjectModules(
    seurat_test,
    seurat_ref = seurat_obj,
    group.by.vars=harmony_group_vars,
    wgcna_name_proj="ROC",
    scale_genes=scale_genes, verbose=verbose
  )

  # get MEs from seurat object
  if(features == 'hMEs'){
    MEs <- GetMEs(seurat_train, TRUE, wgcna_name_train)
    MEs_p <- GetMEs(seurat_test, TRUE, "ROC")
  } else if(features == 'MEs'){
    MEs <- GetMEs(seurat_train, FALSE, wgcna_name_train)
    MEs_p <- GetMEs(seurat_test, FALSE, "ROC")
  } else if(features == 'scores'){
    MEs <- GetModuleScores(seurat_train, wgcna_name_train)
    MEs_p <- GetModuleScores(seurat_test, "ROC")
    stop("Haven't implemented this one yet >.<")
  } else(
    stop('Invalid feature selection. Valid choices: hMEs, MEs, scores.')
  )

  print('train')
  print(dim(MEs))
  print(dim(seurat_train))
  print('test')
  print(dim(MEs_p))
  print(dim(seurat_test))

  # add group column to MEs:
  MEs <- as.data.frame(MEs) %>% mutate(group = seurat_train@meta.data[[group.by]])
  MEs_p <- as.data.frame(MEs_p) %>% mutate(group = seurat_test@meta.data[[group.by]])

  # compute average MEs in each group:
  avg_MEs <- MEs %>% group_by(group) %>% summarise(across(!!mods, mean))
  avg_MEs_p <- MEs_p %>% group_by(group) %>% summarise(across(!!mods, mean))
  groups <- avg_MEs$group

  # scale each column between 0 & 1
  avg_MEs <- avg_MEs %>% summarise(across(!!mods, scale01))
  avg_MEs_p <- avg_MEs_p %>% summarise(across(!!mods, scale01))

  # convert to binary labels:
  labels <- avg_MEs %>% purrr::map(~ifelse(. >= exp_thresh, TRUE, FALSE))
  labels <- as.data.frame(do.call(cbind, labels));
  rownames(labels) <- as.character(groups)

  # loop over modules to compute ROC curves:
  plot_df <- data.frame()
  conf_df <- data.frame()
  auc_list <- list()
  mod_colors <- list()
  roc_list <- list()
  for(cur_mod in mods){
    print(cur_mod)
    cur_color <- modules %>% subset(module == cur_mod) %>% .$color %>% unique
    mod_colors[[cur_mod]] <- cur_color

    # compute ROC
    rocobj <- pROC::roc(labels[,cur_mod], avg_MEs_p[[cur_mod]])
    auc_list[[cur_mod]] <- as.numeric(rocobj$auc)
    roc_list[[cur_mod]] <- rocobj

    # update plotting df
    cur_df <- data.frame(
      specificity = 1-rocobj$specificities,
      sensitivity = rocobj$sensitivities,
      module = cur_mod,
      color = cur_color,
      auc = as.numeric(rocobj$auc)
    )
    plot_df <- rbind(plot_df, cur_df)

    # update confidence interval df
    cur_conf <- as.data.frame(pROC::ci.se(rocobj))
    cur_conf$sensitivity <- 1-as.numeric(rownames(cur_conf))
    cur_conf$module <- cur_mod
    cur_conf$color <- cur_color
    conf_df <- rbind(conf_df, cur_conf)

  }
  colnames(conf_df)[1:3] <- c('lo', 'mid', 'hi')

  # return ROC tables & objects:
  list(
    roc = plot_df,
    conf = conf_df,
    objects = roc_list
  )

}
+37 −21
Original line number Diff line number Diff line
@@ -391,6 +391,10 @@ ResetModuleNames <- function(
  wgcna_name=NULL
){

  #TODO:
  # only re-name things if they exist. For example, skip the avg mod exp step
  # if we never ran it!

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # get modules
@@ -409,6 +413,8 @@ ResetModuleNames <- function(
    new_names <- c(new_names[1:(grey_ind-1)], 'grey', new_names[grey_ind:length(new_names)])
  }

  print('here')

  # update kMEs
  new_kMEs <- paste0('kME_', new_names)
  colnames(modules) <- c(colnames(modules)[1:3], new_kMEs)
@@ -428,40 +434,50 @@ ResetModuleNames <- function(
  seurat_obj <- SetModules(seurat_obj, modules, wgcna_name)

  # update hME table:
  hMEs <- GetMEs(seurat_obj, wgcna_name)
  hMEs <- GetMEs(seurat_obj, harmonized=TRUE, wgcna_name)
  if(!is.na(hMEs)){
    colnames(hMEs) <- new_mod_df$new
    seurat_obj <- SetMEs(seurat_obj, hMEs, harmonized=TRUE, wgcna_name)
  }

  # update ME table
  MEs <- GetMEs(seurat_obj, harmonized=FALSE, wgcna_name)
  if(!is.na(MEs)){
    colnames(MEs) <- new_mod_df$new
    seurat_obj <- SetMEs(seurat_obj, MEs, harmonized=FALSE, wgcna_name)
  }

  # update module scores:
  module_scores <- GetModuleScores(seurat_obj, wgcna_name)
  if(!is.na(module_scores)){
    if(!("grey" %in% colnames(module_scores))){
      colnames(module_scores) <- new_mod_df$new[new_mod_df$new != 'grey']
    } else {
      colnames(module_scores) <- new_mod_df$new
    }
    seurat_obj <- SetModuleScores(seurat_obj, module_scores, wgcna_name)
  }

  # update average module expression:
  avg_exp <- GetAvgModuleExpr(seurat_obj, wgcna_name)
  if(!is.na(avg_exp)){
    if(!("grey" %in% colnames(avg_exp))){
      colnames(avg_exp) <- new_mod_df$new[new_mod_df$new != 'grey']
    } else {
      colnames(avg_exp) <- new_mod_df$new
    }
    seurat_obj <- SetAvgModuleExpr(seurat_obj, avg_exp, wgcna_name)
  }

  # update enrichr table:
  enrich_table <- GetEnrichrTable(seurat_obj, wgcna_name)
  if(!is.na(enrich_table)){
    enrich_table$module <- factor(
      new_mod_df[match(enrich_table$module, new_mod_df$old),'new'],
      levels = as.character(new_mod_df$new)
    )
    seurat_obj <- SetEnrichrTable(seurat_obj, enrich_table, wgcna_name)
  }

  seurat_obj

+66 −0
Original line number Diff line number Diff line
@@ -975,3 +975,69 @@ OverlapBarPlot <- function(
  plot_list

}


#' ROCCurves
#'
#' Makes barplots from Enrichr data
#'
#' @param seurat_obj A Seurat object
#' @param dbs List of EnrichR databases
#' @param max_genes Max number of genes to include per module, ranked by kME.
#' @param wgcna_name
#' @keywords scRNA-seq
#' @export
#' @examples
#' ROCCurves
ROCCurves <- function(
  seurat_obj,
  roc_df, conf_df, wgcna_name=NULL
){

  if(is.null(wgcna_name)){wgcna_name <- seurat_obj@misc$active_wgcna}

  # get Modules
  modules <- GetModules(seurat_obj)
  mods <- levels(modules$module)
  mods <- mods[mods != 'grey']

  # get module colors:
  mod_colors <- modules %>% subset(module %in% mods) %>%
    select(c(module, color)) %>%
    distinct %>%
    arrange(module) %>% .$color

  # plot the ROC curve
  roc_df <- roc_df %>% group_by(module) %>% arrange(sensitivity)
  conf_df <- conf_df %>% group_by(module) %>% arrange(sensitivity)
  auc_df <- distinct(roc_df[,c('module', 'auc')])

  # set factor levels for modules:
  roc_df$module <- factor(roc_df$module, levels=mods)
  conf_df$module <- factor(conf_df$module, levels=mods)
  auc_df$module <- factor(auc_df$module, levels=mods)

  p <- roc_df %>% ggplot(
    aes(x=specificity, y=sensitivity, color=module, fill=module),
  ) +
    geom_line() +
    geom_ribbon(
      data=conf_df,
      aes(x = sensitivity, ymin=lo, ymax=hi, fill=module),
      inherit.aes=FALSE, alpha=0.4
    ) +
    scale_color_manual(values = unlist(mod_colors)) +
    scale_fill_manual(values = unlist(mod_colors)) +
    scale_x_continuous(breaks = c(0, 0.5, 1), labels=c("0", "0.5", "1")) +
    scale_y_continuous(breaks = c(0, 0.5, 1), labels=c("0", "0.5", "1")) +
    xlab("1 - Specificity (FPR)") + ylab("Sensitivity (TPR)") +
    geom_text(
      data = auc_df,
      aes(color=module),
      x=0.75, y=0.1, label=paste0("AUC: ", format(auc_df$auc, digits=2)),
      inherit.aes=FALSE, size=4, color='black'
    )

  p

}
+60 −89
Original line number Diff line number Diff line
@@ -181,8 +181,27 @@ seurat_obj <- ModuleConnectivity(seurat_obj)

# compute module hub gene scores:
seurat_obj <- ModuleExprScore(seurat_obj, n_genes = 25, method='Seurat')
seurat_obj <- AvgModuleExpr(seurat_obj, n_genes = 100)

plot_list <- ModuleFeaturePlot(seurat_obj)
# run RenameModules
seurat_obj <- ResetModuleNames(
  seurat_obj,
  new_name = "M"
)
print(names(GetModules(seurat_obj)))

# reset colors:
library(MetBrewer)
modules <- GetModules(seurat_obj)
mods <- levels(modules$module)
mod_colors <- select(modules, c(module, color)) %>%
  distinct %>% arrange(module) %>% .$color
n_colors <- length(mod_colors) -1

new_colors <- paste0(met.brewer("Tiepolo", n=n_colors))
seurat_obj <- ResetModuleColors(seurat_obj, new_colors)

plot_list <- ModuleFeaturePlot(seurat_obj, order='shuffle')
pdf("figures/test_MEFeaturePlot_human_hMEs.pdf",height=12, width=12)
wrap_plots(plot_list, ncol=4)
dev.off()
@@ -195,6 +214,7 @@ dev.off()

# save processed object:
saveRDS(seurat_obj, file=paste0(data_dir, 'human_AD_NatGen_scWGCNA.rds'))
seurat_obj <- readRDS(file=paste0(data_dir, 'human_AD_NatGen_scWGCNA.rds'))

```

@@ -213,101 +233,58 @@ modules <- GetModules(seurat_obj)
mods <- levels(modules$module)
mods <- mods[mods != 'grey']

seurat_obj$wgcna_train <- ifelse(seurat_obj$wgcna_train == 'train', TRUE, FALSE)
seurat_obj$wgcna_train_binary <- ifelse(seurat_obj$wgcna_train == 'train', TRUE, FALSE)

# ROC with a single seurat obj
test <- ComputeROC(
  seurat_obj,
  group.by = 'monocle_clusters_umap_ID',
  split_col = 'wgcna_train_binary',
  features = 'hMEs',
  harmony_group_vars = "Batch",
  scale_genes = FALSE,
  verbose=TRUE
)

split_col <- 'wgcna_train'
group.by <- 'monocle_clusters_umap_ID'
groups <- as.character(unique(seurat_obj@meta.data[[group.by]]))
# plot the ROC Curves
p <- ROCCurves(
  seurat_obj,
  roc_df = test$roc,
  conf_df = test$conf
)

# split into two seurat objects based on train & test split:
seurat_train <- seurat_obj[,seurat_obj@meta.data[[split_col]]]
seurat_test <- seurat_obj[,!seurat_obj@meta.data[[split_col]]]
pdf(paste0(fig_dir, 'test_ROC_human_func.pdf'), width=8, height=6)
p + facet_wrap(~module, ncol=4) + NoLegend()
dev.off()

# project modules for train
seurat_train <- ProjectModules(
  seurat_train,
  seurat_ref = seurat_obj,
  group.by.vars="Batch",
  wgcna_name_proj="ROC",
  scale_genes=FALSE, verbose=TRUE
)

# project modules for test
seurat_test <- ProjectModules(
  seurat_test,
  seurat_ref = seurat_obj,
  group.by.vars="Batch",
  wgcna_name_proj="ROC",
  scale_genes=FALSE, verbose=TRUE
)

# get hMEs for ROC comparison:
MEs <- as.data.frame(GetMEs(seurat_train, harmonized=TRUE, wgcna_name="ROC")) %>% mutate(group = seurat_train@meta.data[[group.by]])
MEs_p <- as.data.frame(GetMEs(seurat_test, harmonized=TRUE, wgcna_name="ROC")) %>% mutate(group = seurat_test@meta.data[[group.by]])

# compute average MEs in each group:
avg_MEs <- MEs %>% group_by(group) %>% summarise(across(!!mods, mean))
avg_MEs_p <- MEs_p %>% group_by(group) %>% summarise(across(!!mods, mean))
groups <- avg_MEs$group

# scale each column between 0 & 1
avg_MEs <- avg_MEs %>% summarise(across(!!mods, scale01))
avg_MEs_p <- avg_MEs_p %>% summarise(across(!!mods, scale01))

# convert to binary labels:
thresh <- 0.75
labels <- avg_MEs %>% purrr::map(~ifelse(. >= thresh, TRUE, FALSE))
labels <- as.data.frame(do.call(cbind, labels));
rownames(labels) <- as.character(groups)

predictions <- avg_MEs_p %>% purrr::map(~ifelse(. >= thresh, TRUE, FALSE))
predictions <- as.data.frame(do.call(cbind, predictions));
rownames(predictions) <- as.character(groups)

# loop over modules to compute ROC curves:
plot_df <- data.frame()
conf_df <- data.frame()
auc_list <- list()
mod_colors <- list()
for(cur_mod in mods){
  print(cur_mod)
  cur_color <- modules %>% subset(module == cur_mod) %>% .$color %>% unique
  mod_colors[[cur_mod]] <- cur_color

  # compute ROC
  rocobj <- pROC::roc(labels[,cur_mod], avg_MEs_p[[cur_mod]])
  auc_list[[cur_mod]] <- as.numeric(rocobj$auc)
  # confidence intervals:
  # sens_ci <- ci.se(rocobj)

  cur_df <- data.frame(
    specificity = 1-rocobj$specificities,
    sensitivity = rocobj$sensitivities,
    #auc = rocobj$auc,
    module = cur_mod,
    color = cur_color,
    auc = as.numeric(rocobj$auc)
  #  ci_lo = sens_ci[,1],
  #  ci_med = sens_ci[,2],
  #  ci_hi = sens_ci[,3]
# ROC with two separate seurat objects (TO-DO)
plot_df <- test$roc
conf_df <- test$conf

  )
  plot_df <- rbind(plot_df, cur_df)
# get Modules
modules <- GetModules(seurat_obj)
mods <- levels(modules$module)
mods <- mods[mods != 'grey']

  cur_conf <- as.data.frame(ci.se(rocobj))
  cur_conf$sensitivity <- 1-as.numeric(rownames(cur_conf))
  cur_conf$module <- cur_mod
  cur_conf$color <- cur_color
  conf_df <- rbind(conf_df, cur_conf)
# get module colors:
mod_colors <- modules %>% subset(module %in% mods) %>%
  select(c(module, color)) %>%
  distinct %>%
  arrange(module) %>% .$color

}
colnames(conf_df)[1:3] <- c('lo', 'mid', 'hi')

# plot the ROC curve
plot_df <- plot_df %>% group_by(module) %>% arrange(sensitivity)
conf_df <- conf_df %>% group_by(module) %>% arrange(sensitivity)
auc_df <- distinct(plot_df[,c('module', 'auc')])

# set factor levels for modules:
plot_df$module <- factor(plot_df$module, levels=mods)
conf_df$module <- factor(conf_df$module, levels=mods)
auc_df$module <- factor(auc_df$module, levels=mods)

p <- plot_df %>% ggplot(
  aes(x=specificity, y=sensitivity, color=module, fill=module),
) +
@@ -328,15 +305,9 @@ p <- plot_df %>% ggplot(
    x=0.75, y=0.1, label=paste0("AUC: ", format(auc_df$auc, digits=2)),
    inherit.aes=FALSE, size=4, color='black'
  )
  #annotate("text", x=0.75, y=0.1, label=paste0("AUC: ", plot_df$auc))
  # theme(
  #   axis.text = element_blank(),
  #   axis.ticks = element_blank()
  # )



pdf(paste0(fig_dir, 'test_ROC_human.pdf'), width=8, height=6)
pdf(paste0(fig_dir, 'test_ROC_human_func.pdf'), width=8, height=6)
p + facet_wrap(~module, ncol=4) + NoLegend()
dev.off()