Unverified Commit f6c06e3e authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #10458 from akx/graceful-stop

Graceful server stopping
parents 216b0fa6 875990a2
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -103,3 +103,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+41 −1
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ import datetime
import json
import os
import sys
import threading
import time

import gradio as gr
@@ -110,8 +111,47 @@ class State:
    id_live_preview = 0
    textinfo = None
    time_start = None
    need_restart = False
    server_start = None
    _server_command_signal = threading.Event()
    _server_command: str | None = None

    @property
    def need_restart(self) -> bool:
        # Compatibility getter for need_restart.
        return self.server_command == "restart"

    @need_restart.setter
    def need_restart(self, value: bool) -> None:
        # Compatibility setter for need_restart.
        if value:
            self.server_command = "restart"

    @property
    def server_command(self):
        return self._server_command

    @server_command.setter
    def server_command(self, value: str | None) -> None:
        """
        Set the server command to `value` and signal that it's been set.
        """
        self._server_command = value
        self._server_command_signal.set()

    def wait_for_server_command(self, timeout: float | None = None) -> str | None:
        """
        Wait for server command to get set; return and clear the value and signal.
        """
        if self._server_command_signal.wait(timeout):
            self._server_command_signal.clear()
            req = self._server_command
            self._server_command = None
            return req
        return None

    def request_restart(self) -> None:
        self.interrupt()
        self.server_command = True

    def skip(self):
        self.skipped = True
+1 −5
Original line number Diff line number Diff line
@@ -1609,12 +1609,8 @@ def create_ui():
            outputs=[]
        )

        def request_restart():
            shared.state.interrupt()
            shared.state.need_restart = True

        restart_gradio.click(
            fn=request_restart,
            fn=shared.state.request_restart,
            _js='restart_reload',
            inputs=[],
            outputs=[],
+2 −5
Original line number Diff line number Diff line
@@ -52,9 +52,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
    shared.opts.disabled_extensions = disabled
    shared.opts.disable_all_extensions = disable_all
    shared.opts.save(shared.config_filename)

    shared.state.interrupt()
    shared.state.need_restart = True
    shared.state.request_restart()


def save_config_state(name):
@@ -92,8 +90,7 @@ def restore_config_state(confirmed, config_state_name, restore_type):
    if restore_type == "webui" or restore_type == "both":
        config_states.restore_webui_config(config_state)

    shared.state.interrupt()
    shared.state.need_restart = True
    shared.state.request_restart()

    return ""

+34 −16
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ import warnings
import json
from threading import Thread

from fastapi import FastAPI
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
@@ -241,6 +241,9 @@ def initialize():
        print(f'Interrupted with signal {sig} in {frame}')
        os._exit(0)

    if not os.environ.get("COVERAGE_RUN"):
        # Don't install the immediate-quit handler when running under coverage,
        # as then the coverage report won't be generated.
        signal.signal(signal.SIGINT, sigint_handler)


@@ -262,19 +265,6 @@ def create_api(app):
    return api


def wait_on_server(demo=None):
    while 1:
        time.sleep(0.5)
        if shared.state.need_restart:
            shared.state.need_restart = False
            time.sleep(0.5)
            demo.close()
            time.sleep(0.5)

            modules.script_callbacks.app_reload_callback()
            break


def api_only():
    initialize()

@@ -287,6 +277,12 @@ def api_only():
    print(f"Startup time: {startup_timer.summary()}.")
    api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)


def stop_route(request):
    shared.state.server_command = "stop"
    return Response("Stopping.")


def webui():
    launch_api = cmd_opts.api
    initialize()
@@ -335,6 +331,9 @@ def webui():
            inbrowser=cmd_opts.autolaunch,
            prevent_thread_lock=True
        )
        if cmd_opts.add_stop_route:
            app.add_route("/_stop", stop_route, methods=["POST"])

        # after initial launch, disable --autolaunch for subsequent restarts
        cmd_opts.autolaunch = False

@@ -366,8 +365,27 @@ def webui():
            redirector.get("/")
            gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")

        wait_on_server(shared.demo)
        try:
            while True:
                server_command = shared.state.wait_for_server_command(timeout=5)
                if server_command:
                    if server_command in ("stop", "restart"):
                        break
                    else:
                        print(f"Unknown server command: {server_command}")
        except KeyboardInterrupt:
            print('Caught KeyboardInterrupt, stopping...')
            server_command = "stop"

        if server_command == "stop":
            print("Stopping server...")
            # If we catch a keyboard interrupt, we want to stop the server and exit.
            shared.demo.close()
            break
        print('Restarting UI...')
        shared.demo.close()
        time.sleep(0.5)
        modules.script_callbacks.app_reload_callback()

        startup_timer.reset()