Commit 11d23e8c authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

remove Train/Preprocessing tab and put all its functionality into extras batch images mode

parent 4a666381
Loading
Loading
Loading
Loading
+17 −0
Original line number Diff line number Diff line
@@ -170,6 +170,23 @@ function submit_img2img() {
    return res;
}

function submit_extras() {
    showSubmitButtons('extras', false);

    var id = randomId();

    requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() {
        showSubmitButtons('extras', true);
    });

    var res = create_submit_args(arguments);

    res[0] = id;

    console.log(res);
    return res;
}

function restoreProgressTxt2img() {
    showRestoreProgressButton("txt2img", false);
    var id = localGet("txt2img_task_id");
+0 −15
Original line number Diff line number Diff line
@@ -22,7 +22,6 @@ from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin, Image
from modules.sd_models_config import find_checkpoint_config_near_filename
@@ -235,7 +234,6 @@ class Api:
        self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
@@ -675,19 +673,6 @@ class Api:
        finally:
            shared.state.end()

    def preprocess(self, args: dict):
        try:
            shared.state.begin(job="preprocess")
            preprocess(**args) # quick operation unless blip/booru interrogation is enabled
            shared.state.end()
            return models.PreprocessResponse(info='preprocess complete')
        except KeyError as e:
            return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
        except Exception as e:
            return models.PreprocessResponse(info=f"preprocess error: {e}")
        finally:
            shared.state.end()

    def train_embedding(self, args: dict):
        try:
            shared.state.begin(job="train_embedding")
+0 −3
Original line number Diff line number Diff line
@@ -202,9 +202,6 @@ class TrainResponse(BaseModel):
class CreateResponse(BaseModel):
    info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")

class PreprocessResponse(BaseModel):
    info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")

fields = {}
for key, metadata in opts.data_labels.items():
    value = opts.data.get(key)
+69 −23
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from modules import shared, images, devices, scripts, scripts_postprocessing, ui
from modules.shared import opts


def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
    devices.torch_gc()

    shared.state.begin(job="extras")
@@ -29,11 +29,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,

            image_list = shared.listfiles(input_dir)
            for filename in image_list:
                try:
                    image = Image.open(filename)
                except Exception:
                    continue
                yield image, filename
                yield filename, filename
        else:
            assert image, 'image not selected'
            yield image, None
@@ -45,22 +41,47 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,

    infotext = ''

    for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
    data_to_process = list(get_images(extras_mode, image, image_folder, input_dir))
    shared.state.job_count = len(data_to_process)

    for image_placeholder, name in data_to_process:
        image_data: Image.Image

        shared.state.nextjob()
        shared.state.textinfo = name
        shared.state.skipped = False

        if shared.state.interrupted:
            break

        if isinstance(image_placeholder, str):
            try:
                image_data = Image.open(image_placeholder)
            except Exception:
                continue
        else:
            image_data = image_placeholder

        shared.state.assign_current_image(image_data)

        parameters, existing_pnginfo = images.read_info_from_image(image_data)
        if parameters:
            existing_pnginfo["parameters"] = parameters

        pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
        initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))

        scripts.scripts_postproc.run(initial_pp, args)

        if shared.state.skipped:
            continue

        scripts.scripts_postproc.run(pp, args)
        used_suffixes = {}
        for pp in [initial_pp, *initial_pp.extra_images]:
            suffix = pp.get_suffix(used_suffixes)

            if opts.use_original_name_batch and name is not None:
                basename = os.path.splitext(os.path.basename(name))[0]
            forced_filename = basename
                forced_filename = basename + suffix
            else:
                basename = ''
                forced_filename = None
@@ -72,7 +93,30 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
                pp.image.info["postprocessing"] = infotext

            if save_output:
            images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename)
                fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)

                if pp.caption:
                    caption_filename = os.path.splitext(fullfn)[0] + ".txt"
                    if os.path.isfile(caption_filename):
                        with open(caption_filename, encoding="utf8") as file:
                            existing_caption = file.read().strip()
                    else:
                        existing_caption = ""

                    action = shared.opts.postprocessing_existing_caption_action
                    if action == 'Prepend' and existing_caption:
                        caption = f"{existing_caption} {pp.caption}"
                    elif action == 'Append' and existing_caption:
                        caption = f"{pp.caption} {existing_caption}"
                    elif action == 'Keep' and existing_caption:
                        caption = existing_caption
                    else:
                        caption = pp.caption

                    caption = caption.strip()
                    if caption:
                        with open(caption_filename, "w", encoding="utf8") as file:
                            file.write(caption)

            if extras_mode != 2 or show_extras_results:
                outputs.append(pp.image)
@@ -99,9 +143,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
            "upscaler_2_visibility": extras_upscaler_2_visibility,
        },
        "GFPGAN": {
            "enable": True,
            "gfpgan_visibility": gfpgan_visibility,
        },
        "CodeFormer": {
            "enable": True,
            "codeformer_visibility": codeformer_visibility,
            "codeformer_weight": codeformer_weight,
        },
+81 −5
Original line number Diff line number Diff line
import dataclasses
import os
import gradio as gr

from modules import errors, shared


@dataclasses.dataclass
class PostprocessedImageSharedInfo:
    target_width: int = None
    target_height: int = None


class PostprocessedImage:
    def __init__(self, image):
        self.image = image
        self.info = {}
        self.shared = PostprocessedImageSharedInfo()
        self.extra_images = []
        self.nametags = []
        self.disable_processing = False
        self.caption = None

    def get_suffix(self, used_suffixes=None):
        used_suffixes = {} if used_suffixes is None else used_suffixes
        suffix = "-".join(self.nametags)
        if suffix:
            suffix = "-" + suffix

        if suffix not in used_suffixes:
            used_suffixes[suffix] = 1
            return suffix

        for i in range(1, 100):
            proposed_suffix = suffix + "-" + str(i)

            if proposed_suffix not in used_suffixes:
                used_suffixes[proposed_suffix] = 1
                return proposed_suffix

        return suffix

    def create_copy(self, new_image, *, nametags=None, disable_processing=False):
        pp = PostprocessedImage(new_image)
        pp.shared = self.shared
        pp.nametags = self.nametags.copy()
        pp.info = self.info.copy()
        pp.disable_processing = disable_processing

        if nametags is not None:
            pp.nametags += nametags

        return pp


class ScriptPostprocessing:
@@ -42,10 +85,17 @@ class ScriptPostprocessing:

        pass

    def image_changed(self):
        pass
    def process_firstpass(self, pp: PostprocessedImage, **args):
        """
        Called for all scripts before calling process(). Scripts can examine the image here and set fields
        of the pp object to communicate things to other scripts.
        args contains a dictionary with all values returned by components from ui()
        """

        pass

    def image_changed(self):
        pass


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
@@ -118,16 +168,42 @@ class ScriptPostprocessingRunner:
        return inputs

    def run(self, pp: PostprocessedImage, args):
        for script in self.scripts_in_preferred_order():
            shared.state.job = script.name
        scripts = []

        for script in self.scripts_in_preferred_order():
            script_args = args[script.args_from:script.args_to]

            process_args = {}
            for (name, _component), value in zip(script.controls.items(), script_args):
                process_args[name] = value

            script.process(pp, **process_args)
            scripts.append((script, process_args))

        for script, process_args in scripts:
            script.process_firstpass(pp, **process_args)

        all_images = [pp]

        for script, process_args in scripts:
            if shared.state.skipped:
                break

            shared.state.job = script.name

            for single_image in all_images.copy():

                if not single_image.disable_processing:
                    script.process(single_image, **process_args)

                for extra_image in single_image.extra_images:
                    if not isinstance(extra_image, PostprocessedImage):
                        extra_image = single_image.create_copy(extra_image)

                    all_images.append(extra_image)

                single_image.extra_images.clear()

        pp.extra_images = all_images[1:]

    def create_args_for_run(self, scripts_args):
        if not self.ui_created:
Loading