Commit 64311faa authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

put refiner into main UI, into the new accordions section

add VAE from main model into infotext, not from refiner model
option to make scripts UI without gr.Group
fix inconsistencies with refiner when usings samplers that do more denoising than steps
parent 26c92f05
Loading
Loading
Loading
Loading
+15 −7
Original line number Diff line number Diff line
@@ -373,9 +373,10 @@ class StableDiffusionProcessing:
        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
        total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
        self.step_multiplier = total_steps // self.steps
        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)

    def get_conds(self):
        return self.c, self.uc
@@ -579,8 +580,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
        "VAE hash": sd_vae.get_loaded_vae_hash() if opts.add_model_hash_to_info else None,
        "VAE": sd_vae.get_loaded_vae_name() if opts.add_model_name_to_info else None,
        "VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None,
        "VAE": p.loaded_vae_name if opts.add_model_name_to_info else None,
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -669,6 +670,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    if p.tiling is None:
        p.tiling = opts.tiling

    p.loaded_vae_name = sd_vae.get_loaded_vae_name()
    p.loaded_vae_hash = sd_vae.get_loaded_vae_hash()

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

@@ -1188,8 +1192,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)

        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
        sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
        steps = self.hr_second_pass_steps or self.steps
        total_steps = sampler_config.total_steps(steps) if sampler_config else steps

        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)

    def setup_conds(self):
        super().setup_conds()
+55 −0
Original line number Diff line number Diff line
import gradio as gr

from modules import scripts, sd_models
from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion


class ScriptRefiner(scripts.Script):
    section = "accordions"
    create_group = False

    def __init__(self):
        pass

    def title(self):
        return "Refiner"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
            with gr.Row():
                refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
                create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))

                refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")

        def lookup_checkpoint(title):
            info = sd_models.get_closet_checkpoint_match(title)
            return None if info is None else info.title

        self.infotext_fields = [
            (enable_refiner, lambda d: 'Refiner' in d),
            (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
            (refiner_switch_at, 'Refiner switch at'),
        ]

        return enable_refiner, refiner_checkpoint, refiner_switch_at

    def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
        # the actual implementation is in sd_samplers_common.py, apply_refiner

        p.refiner_checkpoint_info = None
        p.refiner_switch_at = None

        if not enable_refiner or refiner_checkpoint in (None, "", "None"):
            return

        refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint)
        if refiner_checkpoint_info is None:
            raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}')

        p.refiner_checkpoint_info = refiner_checkpoint_info
        p.refiner_switch_at = refiner_switch_at
+16 −8
Original line number Diff line number Diff line
@@ -37,7 +37,10 @@ class Script:
    is_img2img = False

    group = None
    """A gr.Group component that has all script's UI inside it"""
    """A gr.Group component that has all script's UI inside it."""

    create_group = True
    """If False, for alwayson scripts, a group component will not be created."""

    infotext_fields = None
    """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
@@ -232,6 +235,7 @@ class Script:
        """
        pass


current_basedir = paths.script_path


@@ -250,7 +254,7 @@ postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])


def list_scripts(scriptdirname, extension):
def list_scripts(scriptdirname, extension, *, include_extensions=True):
    scripts_list = []

    basedir = os.path.join(paths.script_path, scriptdirname)
@@ -258,6 +262,7 @@ def list_scripts(scriptdirname, extension):
        for filename in sorted(os.listdir(basedir)):
            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))

    if include_extensions:
        for ext in extensions.active():
            scripts_list += ext.list_files(scriptdirname, extension)

@@ -288,7 +293,7 @@ def load_scripts():
    postprocessing_scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py")
    scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)

    syspath = sys.path

@@ -429,10 +434,13 @@ class ScriptRunner:
            if script.alwayson and script.section != section:
                continue

            if script.create_group:
                with gr.Group(visible=script.alwayson) as group:
                    self.create_script_ui(script)

                script.group = group
            else:
                self.create_script_ui(script)

    def prepare_ui(self):
        self.inputs = [None]
+3 −0
Original line number Diff line number Diff line
@@ -147,6 +147,9 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")


def get_closet_checkpoint_match(search_string):
    if not search_string:
        return None

    checkpoint_info = checkpoint_aliases.get(search_string, None)
    if checkpoint_info is not None:
        return checkpoint_info
+5 −1
Original line number Diff line number Diff line
@@ -45,6 +45,11 @@ class CFGDenoiser(torch.nn.Module):
        self.nmask = None
        self.init_latent = None
        self.steps = None
        """number of steps as specified by user in UI"""

        self.total_steps = None
        """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""

        self.step = 0
        self.image_cfg_scale = None
        self.padded_cond_uncond = False
@@ -56,7 +61,6 @@ class CFGDenoiser(torch.nn.Module):
    def inner_model(self):
        raise NotImplementedError()


    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)
Loading