Commit 600cc034 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

added support for setting hires fix from pasted prompts

added more robust detection of last line with parameters for pasted prompts
parent 53be15c2
Loading
Loading
Loading
Loading
+19 −15
Original line number Diff line number Diff line
from collections import namedtuple
import re
import gradio as gr

re_param = re.compile(r"\s*([\w ]+):\s*([^,]+)(?:,|$)")
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())


def parse_generation_parameters(x: str):
@@ -25,6 +27,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
    done_with_prompt = False

    *lines, lastline = x.strip().split("\n")
    if not re_params.match(lastline):
        lines.append(lastline)
        lastline = ''

    for i, line in enumerate(lines):
        line = line.strip()
        if line.startswith("Negative prompt:"):
@@ -32,9 +38,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
            line = line[16:].strip()

        if done_with_prompt:
            negative_prompt += line
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += line
            prompt += ("" if prompt == "" else "\n") + line

    if len(prompt) > 0:
        res["Prompt"] = prompt
@@ -53,19 +59,21 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
    return res


def connect_paste(button, d, input_comp, js=None):
    items = []
    outputs = []

def connect_paste(button, paste_fields, input_comp, js=None):
    def paste_func(prompt):
        params = parse_generation_parameters(prompt)
        res = []

        for key, output in zip(items, outputs):
        for output, key in paste_fields:
            if callable(key):
                v = key(params)
            else:
                v = params.get(key, None)

            if v is None:
                res.append(gr.update())
            elif isinstance(v, type_of_gr_update):
                res.append(v)
            else:
                try:
                    valtype = type(output.value)
@@ -76,13 +84,9 @@ def connect_paste(button, d, input_comp, js=None):

        return res

    for k, v in d.items():
        items.append(k)
        outputs.append(v)

    button.click(
        fn=paste_func,
        _js=js,
        inputs=[input_comp],
        outputs=outputs,
        outputs=[x[0] for x in paste_fields],
    )
+36 −34
Original line number Diff line number Diff line
@@ -521,23 +521,25 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
                ]
            )

            txt2img_paste_fields = {
                "Prompt": txt2img_prompt,
                "Negative prompt": txt2img_negative_prompt,
                "Steps": steps,
                "Sampler": sampler_index,
                "Face restoration": restore_faces,
                "CFG scale": cfg_scale,
                "Seed": seed,
                "Size-1": width,
                "Size-2": height,
                "Batch size": batch_size,
                "Variation seed": subseed,
                "Variation seed strength": subseed_strength,
                "Seed resize from-1": seed_resize_from_w,
                "Seed resize from-2": seed_resize_from_h,
                "Denoising strength": denoising_strength,
            }
            txt2img_paste_fields = [
                (txt2img_prompt, "Prompt"),
                (txt2img_negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (cfg_scale, "CFG scale"),
                (seed, "Seed"),
                (width, "Size-1"),
                (height, "Size-2"),
                (batch_size, "Batch size"),
                (subseed, "Variation seed"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (denoising_strength, "Denoising strength"),
                (enable_hr, lambda d: "Denoising strength" in d),
                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
            ]
            modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)

    with gr.Blocks(analytics_enabled=False) as img2img_interface:
@@ -741,23 +743,23 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
                    outputs=[prompt, negative_prompt, style1, style2],
                )

            img2img_paste_fields = {
                "Prompt": img2img_prompt,
                "Negative prompt": img2img_negative_prompt,
                "Steps": steps,
                "Sampler": sampler_index,
                "Face restoration": restore_faces,
                "CFG scale": cfg_scale,
                "Seed": seed,
                "Size-1": width,
                "Size-2": height,
                "Batch size": batch_size,
                "Variation seed": subseed,
                "Variation seed strength": subseed_strength,
                "Seed resize from-1": seed_resize_from_w,
                "Seed resize from-2": seed_resize_from_h,
                "Denoising strength": denoising_strength,
            }
            img2img_paste_fields = [
                (img2img_prompt, "Prompt"),
                (img2img_negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (cfg_scale, "CFG scale"),
                (seed, "Seed"),
                (width, "Size-1"),
                (height, "Size-2"),
                (batch_size, "Batch size"),
                (subseed, "Variation seed"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (denoising_strength, "Denoising strength"),
            ]
            modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)

    with gr.Blocks(analytics_enabled=False) as extras_interface: