Commit 7a14c8ab authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add an option to enable sections from extras tab in txt2img/img2img

fix some style inconsistenices
parent 645f4e7e
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -658,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

                image = Image.fromarray(x_sample)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+28 −4
Original line number Diff line number Diff line
@@ -6,12 +6,16 @@ from collections import namedtuple

import gradio as gr

from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing

AlwaysVisible = object()


class PostprocessImageArgs:
    def __init__(self, image):
        self.image = image


class Script:
    filename = None
    args_from = None
@@ -65,7 +69,7 @@ class Script:
        args contains all values returned by components from ui()
        """

        raise NotImplementedError()
        pass

    def process(self, p, *args):
        """
@@ -100,6 +104,13 @@ class Script:

        pass

    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
@@ -247,11 +258,15 @@ class ScriptRunner:
        self.infotext_fields = []

    def initialize_scripts(self, is_img2img):
        from modules import scripts_auto_postprocessing

        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        for script_class, path, basedir, script_module in scripts_data:
        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

        for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
            script = script_class()
            script.filename = path
            script.is_txt2img = not is_img2img
@@ -332,7 +347,7 @@ class ScriptRunner:

        return inputs

    def run(self, p: StableDiffusionProcessing, *args):
    def run(self, p, *args):
        script_index = args[0]

        if script_index == 0:
@@ -386,6 +401,15 @@ class ScriptRunner:
                print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess_image(self, p, pp: PostprocessImageArgs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_image(p, pp, *script_args)
            except Exception:
                print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def before_component(self, component, **kwargs):
        for script in self.scripts:
            try:
+42 −0
Original line number Diff line number Diff line
from modules import scripts, scripts_postprocessing, shared


class ScriptPostprocessingForMainUI(scripts.Script):
    def __init__(self, script_postproc):
        self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
        self.postprocessing_controls = None

    def title(self):
        return self.script.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        self.postprocessing_controls = self.script.ui()
        return self.postprocessing_controls.values()

    def postprocess_image(self, p, script_pp, *args):
        args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}

        pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
        pp.info = {}
        self.script.process(pp, **args_dict)
        p.extra_generation_params.update(pp.info)
        script_pp.image = pp.image


def create_auto_preprocessing_script_data():
    from modules import scripts

    res = []

    for name in shared.opts.postprocessing_enable_in_main_ui:
        script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
        if script is None:
            continue

        constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
        res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))

    return res
+8 −3
Original line number Diff line number Diff line
@@ -46,6 +46,8 @@ class ScriptPostprocessing:
        pass




def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    try:
        res = func(*args, **kwargs)
@@ -68,6 +70,9 @@ class ScriptPostprocessingRunner:
            script: ScriptPostprocessing = script_class()
            script.filename = path

            if script.name == "Simple Upscale":
                continue

            self.scripts.append(script)

    def create_script_ui(self, script, inputs):
@@ -87,12 +92,11 @@ class ScriptPostprocessingRunner:
            import modules.scripts
            self.initialize_scripts(modules.scripts.postprocessing_scripts_data)

        scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
        scripts_order = shared.opts.postprocessing_operation_order

        def script_score(name):
            name = name.lower()
            for i, possible_match in enumerate(scripts_order):
                if possible_match in name:
                if possible_match == name:
                    return i

            return len(self.scripts)
@@ -145,3 +149,4 @@ class ScriptPostprocessingRunner:
    def image_changed(self):
        for script in self.scripts_in_preferred_order():
            script.image_changed()
+5 −10
Original line number Diff line number Diff line
@@ -13,8 +13,8 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
from modules.paths import models_path, script_path, sd_path
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items
from modules.paths import models_path, script_path


demo = None
@@ -264,12 +264,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")

face_restorers = []


def realesrgan_models_names():
    import modules.realesrgan_model
    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]


class OptionInfo:
    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
        self.default = default
@@ -360,7 +354,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
options_templates.update(options_section(('upscaling', "Upscaling"), {
    "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
    "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
    "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))

@@ -483,7 +477,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
}))

options_templates.update(options_section(('postprocessing', "Postprocessing"), {
    'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
    'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
    'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
    'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
}))

Loading