Commit d0c494f1 authored by Jun Zhao's avatar Jun Zhao
Browse files

update GPU setting

parent 09a59791
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -43,11 +43,11 @@ STGmarkerFinder <- function(
  # set Python
  use_python(python = python.use, required = T)

  source_python(file = paste(system.file(package="DAseq"), "DA_STG.py", sep = "/"))

  # set GPU device
  py_run_string(paste("os.environ['CUDA_VISIBLE_DEVICES'] = '", GPU, "'", sep = ""))

  source_python(file = paste(system.file(package="DAseq"), "DA_STG.py", sep = "/"))

  # turn X into Python format
  X.py <- r_to_py(as.matrix(X))

@@ -164,8 +164,8 @@ runSTG <- function(
){
  # set Python
  use_python(python = python.use, required = T)
  py_run_string(paste("os.environ['CUDA_VISIBLE_DEVICES'] = '", GPU, "'", sep = ""))
  source_python(file = paste(system.file(package="DAseq"), "DA_STG.py", sep = "/"))
  py_run_string(paste("os.environ['CUDA_VISIBLE_DEVICES'] = '", GPU, "'", sep = ""))

  if(!is.null(label.2)){
    X.use <- which(X.labels %in% c(label.1,label.2))
+4 −4
Original line number Diff line number Diff line
@@ -42,6 +42,7 @@ getDAcells <- function(
  do.plot = T, plot.embedding = NULL, size = 0.5,
  python.use = "/usr/bin/python", GPU = ""
){
#  cat("Using GPU ", GPU, ".\n", sep = "")

  # get DA score vector for each cell
  cat("Calculating DA score vector.\n")
@@ -64,16 +65,15 @@ getDAcells <- function(
  binary.labels_py <- r_to_py(as.matrix(binary.labels))
  X.knn.ratio_py <- r_to_py(as.matrix(X.knn.ratio))


  # set GPU device
  py_run_string(paste("os.environ['CUDA_VISIBLE_DEVICES'] = '", GPU, "'", sep = ""))

  # get neural network predictions for each cell
  cat("Running logistic regression.\n")
  source_python(file = paste(system.file(package="DAseq"), "DA_logit.py", sep = "/"))
  # py_run_string(paste("epochs = ", epochs, sep = ""))
  py_run_string(paste("k_folds = ", k.folds, sep = ""))

  # set GPU device
  py_run_string(paste("os.environ['CUDA_VISIBLE_DEVICES'] = '", GPU, "'", sep = ""))


  # check n.runs
  if(n.runs == 1){