Unverified Commit f7c0a963 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #11957 from ljleb/pp-batch-list

Add postprocess_batch_list script callback
parents f4519940 5b066074
Loading
Loading
Loading
Loading
+25 −1
Original line number Diff line number Diff line
@@ -717,7 +717,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
        all_prompts = p.all_prompts[:]
        all_negative_prompts = p.all_negative_prompts[:]
        all_seeds = p.all_seeds[:]
        all_subseeds = p.all_subseeds[:]

        # apply changes to generation data
        all_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.prompts
        all_negative_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.negative_prompts
        all_seeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.seeds
        all_subseeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.subseeds

        # update p.all_negative_prompts in case extensions changed the size of the batch
        # create_infotext below uses it
        old_negative_prompts = p.all_negative_prompts
        p.all_negative_prompts = all_negative_prompts

        try:
            return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
        finally:
            # restore p.all_negative_prompts in case extensions changed the size of the batch
            p.all_negative_prompts = old_negative_prompts

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -806,6 +826,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                postprocess_batch_list_args = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, postprocess_batch_list_args, batch_number=n)
                x_samples_ddim = postprocess_batch_list_args.images

            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

+32 −0
Original line number Diff line number Diff line
@@ -16,6 +16,11 @@ class PostprocessImageArgs:
        self.image = image


class PostprocessBatchListArgs:
    def __init__(self, images):
        self.images = images


class Script:
    name = None
    """script's internal name derived from title"""
@@ -156,6 +161,25 @@ class Script:

        pass

    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
        """
        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
        This is useful when you want to update the entire batch instead of individual images.

        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
        If the number of images is different from the batch size when returning,
        then the script has the responsibility to also update the following attributes in the processing object (p):
          - p.prompts
          - p.negative_prompts
          - p.seeds
          - p.subseeds

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
        """

        pass

    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
@@ -536,6 +560,14 @@ class ScriptRunner:
            except Exception:
                errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)

    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch_list(p, pp, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)

    def postprocess_image(self, p, pp: PostprocessImageArgs):
        for script in self.alwayson_scripts:
            try: