Commit fbf88343 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

prevent calculating cons for second pass of hires fix when they are the same as for the first pass

parent 1ca5e76f
Loading
Loading
Loading
Loading
+13 −7
Original line number Original line Diff line number Diff line
@@ -312,7 +312,7 @@ class StableDiffusionProcessing:
        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]


    def get_conds_with_caching(self, function, required_prompts, steps, cache, extra_network_data):
    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
        """
        """
        Returns the result of calling function(shared.sd_model, required_prompts, steps)
        Returns the result of calling function(shared.sd_model, required_prompts, steps)
        using a cache to store the result if the same arguments have been used before.
        using a cache to store the result if the same arguments have been used before.
@@ -321,10 +321,16 @@ class StableDiffusionProcessing:
        representing the previously used arguments, or None if no arguments
        representing the previously used arguments, or None if no arguments
        have been used before. The second element is where the previously
        have been used before. The second element is where the previously
        computed result is stored.
        computed result is stored.

        caches is a list with items described above.
        """
        """

        for cache in caches:
            if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
            if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
                return cache[1]
                return cache[1]


        cache = caches[0]

        with devices.autocast():
        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps)
            cache[1] = function(shared.sd_model, required_prompts, steps)


@@ -335,8 +341,8 @@ class StableDiffusionProcessing:
        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
        self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1


        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.extra_network_data)
        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)


    def parse_extra_network_prompts(self):
    def parse_extra_network_prompts(self):
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -1106,8 +1112,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        if self.hr_c is not None:
        if self.hr_c is not None:
            return
            return


        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_hr_uc, self.hr_extra_network_data)
        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_hr_c, self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)


    def setup_conds(self):
    def setup_conds(self):
        super().setup_conds()
        super().setup_conds()