Commit cdc8020d authored by AUTOMATIC's avatar AUTOMATIC
Browse files

change StableDiffusionProcessing to internally use sampler name instead of sampler index

parent d9fd4525
Loading
Loading
Loading
Loading
+9 −17
Original line number Diff line number Diff line
@@ -6,9 +6,9 @@ from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
from modules import sd_samplers
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list
@@ -25,8 +25,12 @@ def upscaler_to_index(name: str):
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")


sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
def validate_sampler_name(name):
    config = sd_samplers.all_samplers_map.get(name, None)
    if config is None:
        raise HTTPException(status_code=404, detail="Sampler not found")

    return name

def setUpscalers(req: dict):
    reqDict = vars(req)
@@ -82,14 +86,9 @@ class Api:
        self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)

        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        populate = txt2imgreq.copy(update={ # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True
            }
@@ -109,12 +108,6 @@ class Api:
        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        sampler_index = sampler_to_index(img2imgreq.sampler_index)

        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")


        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")
@@ -123,10 +116,9 @@ class Api:
        if mask:
            mask = decode_base64_to_image(mask)


        populate = img2imgreq.copy(update={ # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "sampler_name": validate_sampler_name(img2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "mask": mask
@@ -272,7 +264,7 @@ class Api:
        return vars(shared.cmd_opts)

    def get_samplers(self):
        return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
        return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]

    def get_upscalers(self):
        upscalers = []
+2 −2
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared
from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                p.prompt = preview_prompt
                p.negative_prompt = preview_negative_prompt
                p.steps = preview_steps
                p.sampler_index = preview_sampler_index
                p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                p.cfg_scale = preview_cfg_scale
                p.seed = preview_seed
                p.width = preview_width
+1 −1
Original line number Diff line number Diff line
@@ -303,7 +303,7 @@ class FilenameGenerator:
        'width': lambda self: self.image.width,
        'height': lambda self: self.image.height,
        'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
        'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
        'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
        'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
        'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
        'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
+2 −2
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ import traceback
import numpy as np
from PIL import Image, ImageOps, ImageChops

from modules import devices
from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
@@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
        seed_enable_extras=seed_enable_extras,
        sampler_index=sampler_index,
        sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
+12 −17
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ import json
import math
import os
import sys
import warnings

import torch
import numpy as np
@@ -66,19 +67,15 @@ def apply_overlay(image, paste_loc, index, overlays):

    return image

def get_correct_sampler(p):
    if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
        return sd_samplers.samplers
    elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
        return sd_samplers.samplers_for_img2img
    elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
        return sd_samplers.samplers

class StableDiffusionProcessing():
    """
    The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
    """
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
        if sampler_index is not None:
            warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name")

        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
@@ -91,7 +88,7 @@ class StableDiffusionProcessing():
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.sampler_name: str = sampler_name
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
@@ -210,8 +207,7 @@ class Processed:
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler_index = p.sampler_index
        self.sampler = sd_samplers.samplers[p.sampler_index].name
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
@@ -256,8 +252,7 @@ class Processed:
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_index": self.sampler_index,
            "sampler": self.sampler,
            "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
@@ -384,7 +379,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration

    generation_params = {
        "Steps": p.steps,
        "Sampler": get_correct_sampler(p)[p.sampler_index].name,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
@@ -645,7 +640,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        if not self.enable_hr:
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -706,7 +701,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

        shared.state.nextjob()

        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

@@ -743,7 +738,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        self.image_conditioning = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        crop_region = None

        if self.image_mask is not None:
Loading