Commit 1ca5e76f authored by AUTOMATIC's avatar AUTOMATIC
Browse files

fix for conds of second hires fox pass being calculated using first pass's...

fix for conds of second hires fox pass being calculated using first pass's networks, and add an option to revert to old behavior
parent 1c6dca93
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -15,6 +15,8 @@ def send_everything_to_cpu():


def setup_for_low_vram(sd_model, use_medvram):
    sd_model.lowvram = True

    parents = {}

    def send_me_to_gpu(module, _):
@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram):
        diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.output_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)


def is_enabled(sd_model):
    return getattr(sd_model, 'lowvram', False)
+29 −3
Original line number Diff line number Diff line
@@ -739,7 +739,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

            del samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()
@@ -894,6 +894,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        self.hr_negative_prompts = None
        self.hr_extra_network_data = None

        self.cached_hr_uc = [None, None]
        self.cached_hr_c = [None, None]
        self.hr_c = None
        self.hr_uc = None

@@ -1056,6 +1058,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            with devices.autocast():
                extra_networks.activate(self, self.hr_extra_network_data)

        with devices.autocast():
            self.calculate_hr_conds()

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
@@ -1067,6 +1072,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        return samples

    def close(self):
        self.cached_hr_uc = [None, None]
        self.cached_hr_c = [None, None]
        self.hr_c = None
        self.hr_uc = None

@@ -1095,12 +1102,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
        self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]

    def calculate_hr_conds(self):
        if self.hr_c is not None:
            return

        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_hr_uc, self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_hr_c, self.hr_extra_network_data)

    def setup_conds(self):
        super().setup_conds()

        self.hr_uc = None
        self.hr_c = None

        if self.enable_hr:
            self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data)
            self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data)
            if shared.opts.hires_fix_use_firstpass_conds:
                self.calculate_hr_conds()

            elif lowvram.is_enabled(shared.sd_model):  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
                with devices.autocast():
                    extra_networks.activate(self, self.hr_extra_network_data)

                self.calculate_hr_conds()

                with devices.autocast():
                    extra_networks.activate(self, self.extra_network_data)

    def parse_extra_network_prompts(self):
        res = super().parse_extra_network_prompts()
+1 −0
Original line number Diff line number Diff line
@@ -429,6 +429,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
    "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
    "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
    "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
    "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
}))

options_templates.update(options_section(('interrogate', "Interrogate Options"), {