Unverified Commit 371c4b99 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #4218 from bamarillo/utils-endpoints

[API][Feature] Utils endpoints
parents f674c488 17bd3f4e
Loading
Loading
Loading
Loading
+80 −5
Original line number Diff line number Diff line
@@ -2,14 +2,17 @@ import base64
import io
import time
import uvicorn
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo

from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from typing import List

def upscaler_to_index(name: str):
    try:
@@ -37,7 +40,7 @@ def encode_pil_to_base64(image):


class Api:
    def __init__(self, app, queue_lock):
    def __init__(self, app: FastAPI, queue_lock: Lock):
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
@@ -48,6 +51,18 @@ class Api:
        self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
        self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
        self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
        self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
        self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
        self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
        self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
        self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
        self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
        self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
        self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -191,6 +206,66 @@ class Api:

        return {}
        
    def get_config(self):
        options = {}
        for key in shared.opts.data.keys():
            metadata = shared.opts.data_labels.get(key)
            if(metadata is not None):
                options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
            else:
                options.update({key: shared.opts.data.get(key, None)})
        
        return options
        
    def set_config(self, req: OptionsModel):
        reqDict = vars(req)
        for o in reqDict:
            setattr(shared.opts, o, reqDict[o])

        shared.opts.save(shared.config_filename)
        return

    def get_cmd_flags(self):
        return vars(shared.cmd_opts)

    def get_samplers(self):
        return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]

    def get_upscalers(self):
        upscalers = []
        
        for upscaler in shared.sd_upscalers:
            u = upscaler.scaler
            upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
        
        return upscalers
        
    def get_sd_models(self):
        return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]

    def get_hypernetworks(self):
        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

    def get_face_restorers(self):
        return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]

    def get_realesrgan_models(self):
        return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
    
    def get_promp_styles(self):
        styleList = []
        for k in shared.prompt_styles.styles:
            style = shared.prompt_styles.styles[k] 
            styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})

        return styleList

    def get_artists_categories(self):
        return shared.artist_db.cats

    def get_artists(self):
        return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]

    def launch(self, server_name, port):
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)
+67 −3
Original line number Diff line number Diff line
import inspect
from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing import Any, Optional, Union
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers
from modules.shared import sd_upscalers, opts, parser

API_NOT_ALLOWED = [
    "self",
@@ -166,3 +165,68 @@ class ProgressResponse(BaseModel):
    eta_relative: float = Field(title="ETA in secs")
    state: dict = Field(title="State", description="The current state snapshot")
    current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")

fields = {}
for key, value in opts.data.items():
    metadata = opts.data_labels.get(key)
    optType = opts.typemap.get(type(value), type(value))

    if (metadata is not None):
        fields.update({key: (Optional[optType], Field(
            default=metadata.default ,description=metadata.label))})
    else:
        fields.update({key: (Optional[optType], Field())})

OptionsModel = create_model("Options", **fields)

flags = {}
_options = vars(parser)['_option_string_actions']
for key in _options:
    if(_options[key].dest != 'help'):
        flag = _options[key]
        _type = str 
        if(_options[key].default != None): _type = type(_options[key].default) 
        flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})

FlagsModel = create_model("Flags", **flags)

class SamplerItem(BaseModel):
    name: str = Field(title="Name")
    aliases: list[str]  = Field(title="Aliases")
    options: dict[str, str] = Field(title="Options")

class UpscalerItem(BaseModel):
    name: str = Field(title="Name")
    model_name: str | None = Field(title="Model Name")
    model_path: str | None = Field(title="Path")
    model_url: str | None = Field(title="URL")

class SDModelItem(BaseModel):
    title: str = Field(title="Title")
    model_name: str = Field(title="Model Name")
    hash: str = Field(title="Hash")
    filename: str = Field(title="Filename")
    config: str = Field(title="Config file")

class HypernetworkItem(BaseModel):
    name: str = Field(title="Name")
    path: str | None = Field(title="Path")

class FaceRestorerItem(BaseModel):
    name: str = Field(title="Name")
    cmd_dir: str | None = Field(title="Path")

class RealesrganItem(BaseModel):
    name: str = Field(title="Name")
    path: str | None = Field(title="Path")
    scale: int | None = Field(title="Scale")

class PromptStyleItem(BaseModel):
    name: str = Field(title="Name")
    prompt: str | None = Field(title="Prompt")
    negative_prompt: str | None = Field(title="Negative Prompt")

class ArtistItem(BaseModel):
    name: str = Field(title="Name")
    score: float = Field(title="Score")
    category: str = Field(title="Category")
 No newline at end of file

test/utils_test.py

0 → 100644
+63 −0
Original line number Diff line number Diff line
import unittest
import requests

class UtilsTests(unittest.TestCase):
  def setUp(self):
    self.url_options = "http://localhost:7860/sdapi/v1/options"
    self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
    self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
    self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
    self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
    self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
    self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
    self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
    self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
    self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories"
    self.url_artists = "http://localhost:7860/sdapi/v1/artists"

  def test_options_get(self):
    self.assertEqual(requests.get(self.url_options).status_code, 200)

  def test_options_write(self):
    response = requests.get(self.url_options)
    self.assertEqual(response.status_code, 200)
    
    pre_value = response.json()["send_seed"]

    self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)

    response = requests.get(self.url_options)
    self.assertEqual(response.status_code, 200)
    self.assertEqual(response.json()["send_seed"], not pre_value)

    requests.post(self.url_options, json={"send_seed": pre_value})

  def test_cmd_flags(self):
    self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)

  def test_samplers(self):
    self.assertEqual(requests.get(self.url_samplers).status_code, 200)

  def test_upscalers(self):
    self.assertEqual(requests.get(self.url_upscalers).status_code, 200)

  def test_sd_models(self):
    self.assertEqual(requests.get(self.url_sd_models).status_code, 200)

  def test_hypernetworks(self):
    self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)

  def test_face_restorers(self):
    self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
  
  def test_realesrgan_models(self):
    self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
  
  def test_prompt_styles(self):
    self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
  
  def test_artist_categories(self):
    self.assertEqual(requests.get(self.url_artist_categories).status_code, 200)

  def test_artists(self):
    self.assertEqual(requests.get(self.url_artists).status_code, 200)
 No newline at end of file