Commit ca45ff1a authored by ljleb's avatar ljleb
Browse files

add postprocess_batch_list callback

parent f4519940
Loading
Loading
Loading
Loading
+23 −1
Original line number Diff line number Diff line
@@ -717,7 +717,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
        all_prompts = p.all_prompts[:]
        all_seeds = p.all_seeds[:]
        all_subseeds = p.all_subseeds[:]

        # apply changes to generation data
        all_prompts[n * p.batch_size:(n + 1) * p.batch_size] = p.prompts
        all_seeds[n * p.batch_size:(n + 1) * p.batch_size] = p.seeds
        all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] = p.subseeds

        # update p.all_negative_prompts in case extensions changed the size of the batch
        # create_infotext below uses it
        old_negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
        p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] = p.negative_prompts

        try:
            return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
        finally:
            # restore p.all_negative_prompts in case extensions changed the size of the batch
            p.all_negative_prompts[n * p.batch_size:n * p.batch_size + len(p.negative_prompts)] = old_negative_prompts

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -806,6 +824,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                postprocess_batch_list_args = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, postprocess_batch_list_args, batch_number=n)
                x_samples_ddim = postprocess_batch_list_args.images

            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

+32 −0
Original line number Diff line number Diff line
@@ -16,6 +16,11 @@ class PostprocessImageArgs:
        self.image = image


class PostprocessBatchListArgs:
    def __init__(self, images):
        self.images = images


class Script:
    name = None
    """script's internal name derived from title"""
@@ -156,6 +161,25 @@ class Script:

        pass

    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
        """
        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
        This is useful when you want to update the entire batch instead of individual images.

        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
        If the number of images is different from the batch size when returning,
        then the script has the responsibility to also update the following attributes in the processing object (p):
          - p.prompts
          - p.negative_prompts
          - p.seeds
          - p.subseeds

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
        """

        pass

    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
@@ -536,6 +560,14 @@ class ScriptRunner:
            except Exception:
                errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)

    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch_list(p, pp, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)

    def postprocess_image(self, p, pp: PostprocessImageArgs):
        for script in self.alwayson_scripts:
            try: