Commit a2d635ad authored by space-nuko's avatar space-nuko
Browse files

Add before_process_batch script callback

parent 0cc0ee1b
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)

            if len(prompts) == 0:
                break

+23 −0
Original line number Diff line number Diff line
@@ -80,6 +80,20 @@ class Script:

        pass

    def before_process_batch(self, p, *args, **kwargs):
        """
        Called before extra networks are parsed from the prompt, so you can add
        new extra network keywords to the prompt with this callback.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.
@@ -388,6 +402,15 @@ class ScriptRunner:
                print(f"Error running process: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def before_process_batch(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_process_batch(p, *script_args, **kwargs)
            except Exception:
                print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def process_batch(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try: