Commit 875990a2 authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Add option for /_stop route (for graceful shutdown)

parent 85b4f899
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -103,3 +103,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+11 −2
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ import warnings
import json
from threading import Thread

from fastapi import FastAPI
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
@@ -270,6 +270,12 @@ def api_only():
    print(f"Startup time: {startup_timer.summary()}.")
    api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)


def stop_route(request):
    shared.state.server_command = "stop"
    return Response("Stopping.")


def webui():
    launch_api = cmd_opts.api
    initialize()
@@ -318,6 +324,8 @@ def webui():
            inbrowser=cmd_opts.autolaunch,
            prevent_thread_lock=True
        )
        if cmd_opts.add_stop_route:
            app.add_route("/_stop", stop_route, methods=["POST"])

        # after initial launch, disable --autolaunch for subsequent restarts
        cmd_opts.autolaunch = False
@@ -359,11 +367,12 @@ def webui():
                    else:
                        print(f"Unknown server command: {server_command}")
        except KeyboardInterrupt:
            print('Caught KeyboardInterrupt, stopping...')
            server_command = "stop"

        if server_command == "stop":
            print("Stopping server...")
            # If we catch a keyboard interrupt, we want to stop the server and exit.
            print('Caught KeyboardInterrupt, stopping...')
            shared.demo.close()
            break
        print('Restarting UI...')