Commit e7f48085 authored by arcticfaded's avatar arcticfaded
Browse files

provide sampler by name

parent 8d5d863a
Loading
Loading
Loading
Loading
+10 −2
Original line number Diff line number Diff line
from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
from modules.sd_samplers import samplers_k_diffusion
import modules.shared as shared
import uvicorn
from fastapi import Body, APIRouter
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
import json
import io
import base64

sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None)

class TextToImageResponse(BaseModel):
    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: Json
@@ -23,9 +26,14 @@ class Api:
        self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])

    def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
        
        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found") 
        
        populate = txt2imgreq.copy(update={ # Override __init__ params
            "sd_model": shared.sd_model, 
            "sampler_index": 0,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True
            }
+14 −2
Original line number Diff line number Diff line
@@ -42,7 +42,8 @@ class PydanticModelGenerator:
    def __init__(
        self,
        model_name: str = None,
        class_instance = None
        class_instance = None,
        additional_fields = None,
    ):
        def field_type_generator(k, v):
            # field_type = str if not overrides.get(k) else overrides[k]["type"]
@@ -71,6 +72,13 @@ class PydanticModelGenerator:
            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
        ]
        
        for fields in additional_fields:
            self._model_def.append(ModelDef(
                field=underscore(fields["key"]), 
                field_alias=fields["key"], 
                field_type=fields["type"],
                field_value=fields["default"]))

    def generate_model(self):
        """
        Creates a pydantic BaseModel
@@ -84,4 +92,8 @@ class PydanticModelGenerator:
        DynamicModel.__config__.allow_mutation = True
        return DynamicModel
    
StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model()
StableDiffusionProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingTxt2Img", 
    StableDiffusionProcessingTxt2Img,
    [{"key": "sampler_index", "type": str, "default": "k_euler_a"}]
).generate_model()