Commit 3e5b3c79 authored by siutin's avatar siutin
Browse files

replace with #wrap_session_call

parent 98497006
Loading
Loading
Loading
Loading
+9 −7
Original line number Diff line number Diff line
@@ -10,6 +10,11 @@ from modules import shared, progress
queue_lock = threading.Lock()
queue_lock_condition = threading.Condition(lock=queue_lock)

def wrap_session_call(func):
  def f(request: gr.Request, *args, **kwargs):
    return func(request, *args, **kwargs)
  return f

def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:
@@ -45,20 +50,17 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):

        return res

    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, add_request=True)
    return wrap_session_call(wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True))


def wrap_gradio_call(func, extra_outputs=None, add_stats=False, add_request=False):
    def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
        if run_memmon:
            shared.mem_mon.monitor()
        t = time.perf_counter()

        try:
            if add_request:
              res = list(func(request, *args, **kwargs))
            else: 
            res = list(func(*args, **kwargs))
        except Exception as e:
            # When printing out our debug argument list, do not print out more than a MB of text