Commit 7a2e36b5 authored by Bruno Seoane's avatar Bruno Seoane
Browse files

Add config and lists endpoints

parent d98eacea
Loading
Loading
Loading
Loading
+92 −5
Original line number Diff line number Diff line
@@ -2,14 +2,17 @@ import base64
import io
import time
import uvicorn
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo

from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from typing import List

def upscaler_to_index(name: str):
    try:
@@ -37,7 +40,7 @@ def encode_pil_to_base64(image):


class Api:
    def __init__(self, app, queue_lock):
    def __init__(self, app: FastAPI, queue_lock: Lock):
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
@@ -48,6 +51,19 @@ class Api:
        self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
        self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
        self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
        self.app.add_api_route("/sdapi/v1/info", self.get_info, methods=["GET"])
        self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
        self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
        self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
        self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
        self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
        self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
        self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
        self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -191,6 +207,77 @@ class Api:

        return {}
        
    def get_config(self):
        options = {}
        for key in shared.opts.data.keys():
            metadata = shared.opts.data_labels.get(key)
            if(metadata is not None):
                options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
            else:
                options.update({key: shared.opts.data.get(key, None)})
        
        return options
        
    def set_config(self, req: OptionsModel):
        reqDict = vars(req)
        for o in reqDict:
            setattr(shared.opts, o, reqDict[o])

        shared.opts.save(shared.config_filename)
        return

    def get_cmd_flags(self):
        return vars(shared.cmd_opts)

    def get_info(self):

        return {
            "hypernetworks": [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks],
            "face_restorers": [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers],
            "realesrgan_models":[{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)],
            "promp_styles":[shared.prompt_styles.styles[k] for k in shared.prompt_styles.styles],
            "artists_categories": shared.artist_db.cats,
            # "artists": [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
        }

    def get_samplers(self):
        return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]

    def get_upscalers(self):
        upscalers = []
        
        for upscaler in shared.sd_upscalers:
            u = upscaler.scaler
            upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
        
        return upscalers
        
    def get_sd_models(self):
        return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]

    def get_hypernetworks(self):
        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

    def get_face_restorers(self):
        return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]

    def get_realesrgan_models(self):
        return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
    
    def get_promp_styles(self):
        styleList = []
        for k in shared.prompt_styles.styles:
            style = shared.prompt_styles.styles[k] 
            styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})

        return styleList

    def get_artists_categories(self):
        return shared.artist_db.cats

    def get_artists(self):
        return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]

    def launch(self, server_name, port):
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)
+67 −3
Original line number Diff line number Diff line
import inspect
from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing import Any, Optional, Union
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers
from modules.shared import sd_upscalers, opts, parser

API_NOT_ALLOWED = [
    "self",
@@ -165,3 +164,68 @@ class ProgressResponse(BaseModel):
    eta_relative: float = Field(title="ETA in secs")
    state: dict = Field(title="State", description="The current state snapshot")
    current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")

fields = {}
for key, value in opts.data.items():
    metadata = opts.data_labels.get(key)
    optType = opts.typemap.get(type(value), type(value))

    if (metadata is not None):
        fields.update({key: (Optional[optType], Field(
            default=metadata.default ,description=metadata.label))})
    else:
        fields.update({key: (Optional[optType], Field())})

OptionsModel = create_model("Options", **fields)

flags = {}
_options = vars(parser)['_option_string_actions']
for key in _options:
    if(_options[key].dest != 'help'):
        flag = _options[key]
        _type = str 
        if(_options[key].default != None): _type = type(_options[key].default) 
        flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})

FlagsModel = create_model("Flags", **flags)

class SamplerItem(BaseModel):
    name: str = Field(title="Name")
    aliases: list[str]  = Field(title="Aliases")
    options: dict[str, str] = Field(title="Options")

class UpscalerItem(BaseModel):
    name: str = Field(title="Name")
    model_name: str | None = Field(title="Model Name")
    model_path: str | None = Field(title="Path")
    model_url: str | None = Field(title="URL")

class SDModelItem(BaseModel):
    title: str = Field(title="Title")
    model_name: str = Field(title="Model Name")
    hash: str = Field(title="Hash")
    filename: str = Field(title="Filename")
    config: str = Field(title="Config file")

class HypernetworkItem(BaseModel):
    name: str = Field(title="Name")
    path: str | None = Field(title="Path")

class FaceRestorerItem(BaseModel):
    name: str = Field(title="Name")
    cmd_dir: str | None = Field(title="Path")

class RealesrganItem(BaseModel):
    name: str = Field(title="Name")
    path: str | None = Field(title="Path")
    scale: int | None = Field(title="Scale")

class PromptStyleItem(BaseModel):
    name: str = Field(title="Name")
    prompt: str | None = Field(title="Prompt")
    negative_prompt: str | None = Field(title="Negative Prompt")

class ArtistItem(BaseModel):
    name: str = Field(title="Name")
    score: float = Field(title="Score")
    category: str = Field(title="Category")
 No newline at end of file