Commit 0f28aee9 authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Refactor gradio auth

parent 674e80c6
Loading
Loading
Loading
Loading
+29 −8
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import re
import warnings
import json
from threading import Thread
from typing import Iterable

from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
@@ -178,6 +179,32 @@ def validate_tls_options():
    startup_timer.record("TLS")


def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
    """
    Convert the gradio_auth and gradio_auth_path commandline arguments into
    an iterable of (username, password) tuples.
    """
    def process_credential_line(s) -> tuple[str, ...] | None:
        s = s.strip()
        if not s:
            return None
        return tuple(s.split(':', 1))

    if cmd_opts.gradio_auth:
        for cred in cmd_opts.gradio_auth.split(','):
            cred = process_credential_line(cred)
            if cred:
                yield cred

    if cmd_opts.gradio_auth_path:
        with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
            for line in file.readlines():
                for cred in line.strip().split(','):
                    cred = process_credential_line(cred)
                    if cred:
                        yield cred


def configure_sigint_handler():
    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):
@@ -316,13 +343,7 @@ def webui():
        if not cmd_opts.no_gradio_queue:
            shared.demo.queue(64)

        gradio_auth_creds = []
        if cmd_opts.gradio_auth:
            gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
        if cmd_opts.gradio_auth_path:
            with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
                for line in file.readlines():
                    gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
        gradio_auth_creds = list(get_gradio_auth_creds()) or None

        # this restores the missing /docs endpoint
        if launch_api and not hasattr(FastAPI, 'original_setup'):
@@ -343,7 +364,7 @@ def webui():
            ssl_certfile=cmd_opts.tls_certfile,
            ssl_verify=cmd_opts.disable_tls_verify,
            debug=cmd_opts.gradio_debug,
            auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
            auth=gradio_auth_creds,
            inbrowser=cmd_opts.autolaunch,
            prevent_thread_lock=True,
            allowed_paths=cmd_opts.gradio_allowed_path,