Commit 9bb6b650 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add postprocess call for scripts

parent 35c45df2
Loading
Loading
Loading
Loading
+9 −3
Original line number Diff line number Diff line
@@ -478,7 +478,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.run_alwayson_scripts(p)
        p.scripts.process(p)

    infotexts = []
    output_images = []
@@ -501,7 +501,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if (len(prompts) == 0):
            if len(prompts) == 0:
                break

            with devices.autocast():
@@ -590,7 +590,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()
    return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)

    res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
+21 −3
Original line number Diff line number Diff line
@@ -64,7 +64,16 @@ class Script:
    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        scripts. You can modify the processing object (p) here, inject hooks, etc.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

        pass
@@ -289,13 +298,22 @@ class ScriptRunner:

        return processed

    def run_alwayson_scripts(self, p):
    def process(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process(p, *script_args)
            except Exception:
                print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
                print(f"Error running process: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess(self, p, processed):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess(p, processed, *script_args)
            except Exception:
                print(f"Error running postprocess: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def reload_sources(self, cache):