Commit 866b36d7 authored by Bruno Seoane's avatar Bruno Seoane
Browse files

Move processing's models into models.py

It didn't make sense to have two differente files for the same and
"models" is a more descriptive name.
parent e0ca4dfb
Loading
Loading
Loading
Loading
+9 −48
Original line number Diff line number Diff line
from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
import modules.shared as shared
import uvicorn
from gradio import processing_utils
from fastapi import APIRouter, HTTPException
import json
import io
import base64
import modules.shared as shared
from modules.api.models import *
from PIL import Image
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras
from gradio import processing_utils

def upscaler_to_index(name: str):
    try:
@@ -20,29 +15,6 @@ def upscaler_to_index(name: str):

sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)

# def img_to_base64(img: str):
#     buffer = io.BytesIO()
#     img.save(buffer, format="png")
#     return base64.b64encode(buffer.getvalue())

# def base64_to_bytes(base64Img: str):
#     if "," in base64Img:
#         base64Img = base64Img.split(",")[1]
#     return io.BytesIO(base64.b64decode(base64Img))

# def base64_to_images(base64Imgs: list[str]):
#     imgs = []
#     for img in base64Imgs:
#         img = Image.open(base64_to_bytes(img))
#         imgs.append(img)
#     return imgs

class ImageToImageResponse(BaseModel):
    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: dict
    info: str


class Api:
    def __init__(self, app, queue_lock):
        self.router = APIRouter()
@@ -51,15 +23,7 @@ class Api:
        self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
        self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
        self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)

    # def __base64_to_image(self, base64_string):
    #     # if has a comma, deal with prefix
    #     if "," in base64_string:
    #         base64_string = base64_string.split(",")[1]
    #     imgdata = base64.b64decode(base64_string)
    #     # convert base64 to PIL image
    #     return Image.open(io.BytesIO(imgdata))
        self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -81,7 +45,7 @@ class Api:
        
        b64images = list(map(processing_utils.encode_pil_to_base64, processed.images))
        
        return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.info)
        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info)

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        sampler_index = sampler_to_index(img2imgreq.sampler_index)
@@ -120,10 +84,7 @@ class Api:
            processed = process_images(p)
        
        b64images = list(map(processing_utils.encode_pil_to_base64, processed.images))
        # for i in processed.images:
        #     buffer = io.BytesIO()
        #     i.save(buffer, format="png")
        #     b64images.append(base64.b64encode(buffer.getvalue()))
       
        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info)

    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
@@ -134,12 +95,12 @@ class Api:
        reqDict.pop('upscaler_1')
        reqDict.pop('upscaler_2')

        reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image'])
        reqDict['image'] = processing_utils.decode_base64_to_image(reqDict['image'])

        with self.queue_lock:
            result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="")

        return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2])
        return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2])

    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
        upscaler1Index = upscaler_to_index(req.upscaler_1)
+110 −2
Original line number Diff line number Diff line
from pydantic import BaseModel, Field, Json
import inspect
from pydantic import BaseModel, Field, Json, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers

API_NOT_ALLOWED = [
    "self",
    "kwargs",
    "sd_model",
    "outpath_samples",
    "outpath_grids",
    "sampler_index",
    "do_not_save_samples",
    "do_not_save_grid",
    "extra_generation_params",
    "overlay_images",
    "do_not_reload_embeddings",
    "seed_enable_extras",
    "prompt_for_display",
    "sampler_noise_scheduler_override",
    "ddim_discretize"
]

class ModelDef(BaseModel):
    """Assistance Class for Pydantic Dynamic Model Generation"""

    field: str
    field_alias: str
    field_type: Any
    field_value: Any


class PydanticModelGenerator:
    """
    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
    source_data is a snapshot of the default values produced by the class
    params are the names of the actual keys required by __init__
    """

    def __init__(
        self,
        model_name: str = None,
        class_instance = None,
        additional_fields = None,
    ):
        def field_type_generator(k, v):
            # field_type = str if not overrides.get(k) else overrides[k]["type"]
            # print(k, v.annotation, v.default)
            field_type = v.annotation
            
            return Optional[field_type]
        
        def merge_class_params(class_):
            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
            parameters = {}
            for classes in all_classes:
                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
            return parameters
            
                
        self._model_name = model_name
        self._class_data = merge_class_params(class_instance)
        self._model_def = [
            ModelDef(
                field=underscore(k),
                field_alias=k,
                field_type=field_type_generator(k, v),
                field_value=v.default
            )
            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
        ]
        
        for fields in additional_fields:
            self._model_def.append(ModelDef(
                field=underscore(fields["key"]), 
                field_alias=fields["key"], 
                field_type=fields["type"],
                field_value=fields["default"]))

    def generate_model(self):
        """
        Creates a pydantic BaseModel
        from the json and overrides provided at initialization
        """
        fields = {
            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
        }
        DynamicModel = create_model(self._model_name, **fields)
        DynamicModel.__config__.allow_population_by_field_name = True
        DynamicModel.__config__.allow_mutation = True
        return DynamicModel
    
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingTxt2Img", 
    StableDiffusionProcessingTxt2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()

StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingImg2Img", 
    StableDiffusionProcessingImg2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
).generate_model()

class TextToImageResponse(BaseModel):
    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: str
    parameters: dict
    info: str

class ImageToImageResponse(BaseModel):
    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: dict
    info: str

class ExtrasBaseRequest(BaseModel):

modules/api/processing.py

deleted100644 → 0
+0 −106
Original line number Diff line number Diff line
from array import array
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
import inspect


API_NOT_ALLOWED = [
    "self",
    "kwargs",
    "sd_model",
    "outpath_samples",
    "outpath_grids",
    "sampler_index",
    "do_not_save_samples",
    "do_not_save_grid",
    "extra_generation_params",
    "overlay_images",
    "do_not_reload_embeddings",
    "seed_enable_extras",
    "prompt_for_display",
    "sampler_noise_scheduler_override",
    "ddim_discretize"
]

class ModelDef(BaseModel):
    """Assistance Class for Pydantic Dynamic Model Generation"""

    field: str
    field_alias: str
    field_type: Any
    field_value: Any


class PydanticModelGenerator:
    """
    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
    source_data is a snapshot of the default values produced by the class
    params are the names of the actual keys required by __init__
    """

    def __init__(
        self,
        model_name: str = None,
        class_instance = None,
        additional_fields = None,
    ):
        def field_type_generator(k, v):
            # field_type = str if not overrides.get(k) else overrides[k]["type"]
            # print(k, v.annotation, v.default)
            field_type = v.annotation
            
            return Optional[field_type]
        
        def merge_class_params(class_):
            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
            parameters = {}
            for classes in all_classes:
                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
            return parameters
            
                
        self._model_name = model_name
        self._class_data = merge_class_params(class_instance)
        self._model_def = [
            ModelDef(
                field=underscore(k),
                field_alias=k,
                field_type=field_type_generator(k, v),
                field_value=v.default
            )
            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
        ]
        
        for fields in additional_fields:
            self._model_def.append(ModelDef(
                field=underscore(fields["key"]), 
                field_alias=fields["key"], 
                field_type=fields["type"],
                field_value=fields["default"]))

    def generate_model(self):
        """
        Creates a pydantic BaseModel
        from the json and overrides provided at initialization
        """
        fields = {
            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
        }
        DynamicModel = create_model(self._model_name, **fields)
        DynamicModel.__config__.allow_population_by_field_name = True
        DynamicModel.__config__.allow_mutation = True
        return DynamicModel
    
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingTxt2Img", 
    StableDiffusionProcessingTxt2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()

StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingImg2Img", 
    StableDiffusionProcessingImg2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
).generate_model()
 No newline at end of file