Commit fbf88343 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

prevent calculating cons for second pass of hires fix when they are the same as for the first pass

parent 1ca5e76f
Loading
Loading
Loading
Loading
+13 −7
Original line number Diff line number Diff line
@@ -312,7 +312,7 @@ class StableDiffusionProcessing:
        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

    def get_conds_with_caching(self, function, required_prompts, steps, cache, extra_network_data):
    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
        """
        Returns the result of calling function(shared.sd_model, required_prompts, steps)
        using a cache to store the result if the same arguments have been used before.
@@ -321,10 +321,16 @@ class StableDiffusionProcessing:
        representing the previously used arguments, or None if no arguments
        have been used before. The second element is where the previously
        computed result is stored.

        caches is a list with items described above.
        """

        for cache in caches:
            if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
                return cache[1]

        cache = caches[0]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps)

@@ -335,8 +341,8 @@ class StableDiffusionProcessing:
        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1

        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data)
        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)

    def parse_extra_network_prompts(self):
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -1106,8 +1112,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        if self.hr_c is not None:
            return

        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_hr_uc, self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_hr_c, self.hr_extra_network_data)
        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)

    def setup_conds(self):
        super().setup_conds()