Commit 98497006 authored by siutin's avatar siutin
Browse files

multi users support

parent 70ab21e6
Loading
Loading
Loading
Loading
+14 −9
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ import threading
import traceback
import time

import gradio as gr
from modules import shared, progress

queue_lock = threading.Lock()
@@ -20,40 +21,44 @@ def wrap_queued_call(func):


def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):
    def f(request: gr.Request, *args, **kwargs):
        user = request.username

        # if the first argument is a string that says "task(...)", it is treated as a job id
        if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
            id_task = args[0]
            progress.add_task_to_queue(id_task)
            progress.add_task_to_queue(user, id_task)
        else:
            id_task = None

        with queue_lock:
            shared.state.begin()
            progress.start_task(id_task)
            progress.start_task(user, id_task)

            try:
                res = func(*args, **kwargs)
            finally:
                progress.finish_task(id_task)
                progress.set_last_task_result(id_task, res)
                progress.finish_task(user, id_task)
                progress.set_last_task_result(user, id_task, res)

            shared.state.end()

        return res

    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, add_request=True)


def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False, add_request=False):
    def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs):
        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
        if run_memmon:
            shared.mem_mon.monitor()
        t = time.perf_counter()

        try:
            if add_request:
              res = list(func(request, *args, **kwargs))
            else: 
              res = list(func(*args, **kwargs))
        except Exception as e:
            # When printing out our debug argument list, do not print out more than a MB of text
+43 −17
Original line number Diff line number Diff line
@@ -4,7 +4,9 @@ import time

import gradio as gr
from pydantic import BaseModel, Field
from typing import List
from typing import Optional
from fastapi import Depends, Security
from fastapi.security import APIKeyCookie

from modules import call_queue
from modules.shared import opts
@@ -12,57 +14,71 @@ from modules.shared import opts
import modules.shared as shared


current_task_user = None
current_task = None
pending_tasks = {}
finished_tasks = []


def start_task(id_task):
def start_task(user, id_task):
    global current_task
    global current_task_user

    current_task_user = user
    current_task = id_task
    pending_tasks.pop(id_task, None)
    pending_tasks.pop((user, id_task), None)


def finish_task(id_task):
def finish_task(user, id_task):
    global current_task
    global current_task_user

    if current_task == id_task:
        current_task = None

    finished_tasks.append(id_task)
    if current_task_user == user:
        current_task_user = None

    finished_tasks.append((user, id_task))
    if len(finished_tasks) > 16:
        finished_tasks.pop(0)


def add_task_to_queue(id_job):
    pending_tasks[id_job] = time.time()
def add_task_to_queue(user, id_job):
    pending_tasks[(user, id_job)] = time.time()

last_task_id = None
last_task_result = None
last_task_user = None

def set_last_task_result(user, id_job, result):

def set_last_task_result(id_job, result):
  global last_task_id
  global last_task_result
  global last_task_user

  last_task_id = id_job
  last_task_result = result
  last_task_user = user


def restore_progress_call():
def restore_progress_call(request: gr.Request):
    if current_task is None:

      # image, generation_info, html_info, html_log
      return tuple(list([None, None, None, None]))

    else:
      user = request.username

      if current_task_user == user:
        t_task = current_task
        with call_queue.queue_lock_condition:
          call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id)

        return last_task_result

      return tuple(list([None, None, None, None]))

class CurrentTaskResponse(BaseModel):
  current_task: str = Field(default=None, title="Task ID", description="id of the current progress task")
@@ -87,6 +103,19 @@ def setup_progress_api(app):
    return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)

def setup_current_task_api(app):

    def get_current_user(token: Optional[str] = Security(APIKeyCookie(name="access-token", auto_error=False))):
      return None if token is None else app.tokens.get(token)

    def current_task_api(current_user: str = Depends(get_current_user)):

      if app.auth is None or current_task_user == current_user:
        current_user_task = current_task
      else:
        current_user_task = None

      return CurrentTaskResponse(current_task=current_user_task)

    return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse)

def progressapi(req: ProgressRequest):
@@ -128,6 +157,3 @@ def progressapi(req: ProgressRequest):
        live_preview = None

    return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
 No newline at end of file

def current_task_api():
  return CurrentTaskResponse(current_task=current_task)
 No newline at end of file
+2 −2
Original line number Diff line number Diff line
@@ -582,7 +582,7 @@ def create_ui():
            res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)

            restore_progress_button.click(
              fn=lambda: restore_progress_call(),
              fn=restore_progress_call,
              _js="() => restoreProgress('txt2img')",
              inputs=[],
              outputs=[
@@ -914,7 +914,7 @@ def create_ui():
            res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)

            restore_progress_button.click(
              fn=lambda: restore_progress_call(),
              fn=restore_progress_call,
              _js="() => restoreProgress('img2img')",
              inputs=[],
              outputs=[