Commit 847f869c authored by AUTOMATIC's avatar AUTOMATIC
Browse files

experimental optimization

parent f8d0cf6a
Loading
Loading
Loading
Loading
+25 −3
Original line number Diff line number Diff line
@@ -544,6 +544,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    infotexts = []
    output_images = []

    cached_uc = [None, None]
    cached_c = [None, None]

    def get_conds_with_caching(function, required_prompts, steps, cache):
        """
        Returns the result of calling function(shared.sd_model, required_prompts, steps)
        using a cache to store the result if the same arguments have been used before.

        cache is an array containing two elements. The first element is a tuple
        representing the previously used arguments, or None if no arguments
        have been used before. The second element is where the previously
        computed result is stored.
        """

        if cache[0] is not None and (required_prompts, steps) == cache[0]:
            return cache[1]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps)

        cache[0] = (required_prompts, steps)
        return cache[1]

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -571,9 +594,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)

            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
            uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
            c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments: