Unverified Commit 336c341a authored by Maiko Tan's avatar Maiko Tan
Browse files

Merge branch 'master' into api-authorization

parents 8f2ff861 84a6f211
Loading
Loading
Loading
Loading
+17 −25
Original line number Diff line number Diff line
@@ -9,9 +9,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest

import modules.shared as shared
from modules import sd_samplers
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list
@@ -28,8 +28,12 @@ def upscaler_to_index(name: str):
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")


sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
def validate_sampler_name(name):
    config = sd_samplers.all_samplers_map.get(name, None)
    if config is None:
        raise HTTPException(status_code=404, detail="Sampler not found")

    return name

def setUpscalers(req: dict):
    reqDict = vars(req)
@@ -77,6 +81,7 @@ class Api:
        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
@@ -103,14 +108,9 @@ class Api:
        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)

        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        populate = txt2imgreq.copy(update={ # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True
            }
@@ -130,12 +130,6 @@ class Api:
        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        sampler_index = sampler_to_index(img2imgreq.sampler_index)

        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")


        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")
@@ -144,10 +138,9 @@ class Api:
        if mask:
            mask = decode_base64_to_image(mask)


        populate = img2imgreq.copy(update={ # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "sampler_name": validate_sampler_name(img2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "mask": mask
@@ -266,6 +259,9 @@ class Api:

        return {}

    def skip(self):
        shared.state.skip()

    def get_config(self):
        options = {}
        for key in shared.opts.data.keys():
@@ -277,14 +273,10 @@ class Api:

        return options

    def set_config(self, req: OptionsModel):
        # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
        # overwrite all options with default values.
        raise RuntimeError('Setting options via API is not supported')
    def set_config(self, req: Dict[str, Any]):
       
        reqDict = vars(req)
        for o in reqDict:
            setattr(shared.opts, o, reqDict[o])
        for o in req:
            setattr(shared.opts, o, req[o])

        shared.opts.save(shared.config_filename)
        return
@@ -293,7 +285,7 @@ class Api:
        return vars(shared.cmd_opts)

    def get_samplers(self):
        return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
        return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]

    def get_upscalers(self):
        upscalers = []
+3 −3
Original line number Diff line number Diff line
@@ -176,9 +176,9 @@ class InterrogateResponse(BaseModel):
    caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")

fields = {}
for key, value in opts.data.items():
    metadata = opts.data_labels.get(key)
    optType = opts.typemap.get(type(value), type(value))
for key, metadata in opts.data_labels.items():
    value = opts.data.get(key)
    optType = opts.typemap.get(type(metadata.default), type(value))

    if (metadata is not None):
        fields.update({key: (Optional[optType], Field(
+5 −2
Original line number Diff line number Diff line
@@ -65,9 +65,12 @@ class Extension:
        self.can_update = False
        self.status = "latest"

    def pull(self):
    def fetch_and_reset_hard(self):
        repo = git.Repo(self.path)
        repo.remotes.origin.pull()
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
        repo.git.fetch('--all')
        repo.git.reset('--hard', 'origin')


def list_extensions():
+1 −0
Original line number Diff line number Diff line
@@ -73,6 +73,7 @@ def integrate_settings_paste_fields(component_dict):
        'sd_hypernetwork': 'Hypernet',
        'sd_hypernetwork_strength': 'Hypernet strength',
        'CLIP_stop_at_last_layers': 'Clip skip',
        'inpainting_mask_weight': 'Conditional mask weight',
        'sd_model_checkpoint': 'Model hash',
    }
    settings_paste_fields = [
+2 −2
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared
from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                p.prompt = preview_prompt
                p.negative_prompt = preview_negative_prompt
                p.steps = preview_steps
                p.sampler_index = preview_sampler_index
                p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                p.cfg_scale = preview_cfg_scale
                p.seed = preview_seed
                p.width = preview_width
Loading