Commit d56a9cfe authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf
Browse files

Merge branch 'dev' into efficient-vae-methods

parents a6b245e4 a32f270a
Loading
Loading
Loading
Loading
+4 −6
Original line number Original line Diff line number Diff line
@@ -2,16 +2,14 @@ function setupExtraNetworksForTab(tabname) {
    gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
    gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');


    var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
    var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
    var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
    var searchDiv = gradioApp().getElementById(tabname + '_extra_search');
    var search = searchDiv.querySelector('textarea');
    var sort = gradioApp().getElementById(tabname + '_extra_sort');
    var sort = gradioApp().getElementById(tabname + '_extra_sort');
    var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
    var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
    var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
    var refresh = gradioApp().getElementById(tabname + '_extra_refresh');


    search.classList.add('search');
    sort.classList.add('sort');
    sortOrder.classList.add('sortorder');
    sort.dataset.sortkey = 'sortDefault';
    sort.dataset.sortkey = 'sortDefault';
    tabs.appendChild(search);
    tabs.appendChild(searchDiv);
    tabs.appendChild(sort);
    tabs.appendChild(sort);
    tabs.appendChild(sortOrder);
    tabs.appendChild(sortOrder);
    tabs.appendChild(refresh);
    tabs.appendChild(refresh);
@@ -179,7 +177,7 @@ function saveCardPreview(event, tabname, filename) {
}
}


function extraNetworksSearchButton(tabs_id, event) {
function extraNetworksSearchButton(tabs_id, event) {
    var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea');
    var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea');
    var button = event.target;
    var button = event.target;
    var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
    var text = button.classList.contains("search-all") ? "" : button.textContent.trim();


+3 −0
Original line number Original line Diff line number Diff line
@@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
    if "Hires sampler" not in res:
    if "Hires sampler" not in res:
        res["Hires sampler"] = "Use same sampler"
        res["Hires sampler"] = "Use same sampler"


    if "Hires checkpoint" not in res:
        res["Hires checkpoint"] = "Use same checkpoint"

    if "Hires prompt" not in res:
    if "Hires prompt" not in res:
        res["Hires prompt"] = ""
        res["Hires prompt"] = ""


+1 −1
Original line number Original line Diff line number Diff line
@@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
    return res
    return res




invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_chars = '<>:"/\\|?*\n\r\t'
invalid_filename_prefix = ' '
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
+3 −0
Original line number Original line Diff line number Diff line
@@ -15,6 +15,9 @@ def send_everything_to_cpu():




def setup_for_low_vram(sd_model, use_medvram):
def setup_for_low_vram(sd_model, use_medvram):
    if getattr(sd_model, 'lowvram', False):
        return

    sd_model.lowvram = True
    sd_model.lowvram = True


    parents = {}
    parents = {}
+66 −32
Original line number Original line Diff line number Diff line
@@ -539,8 +539,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
    return x
    return x




class DecodedSamples(list):
    """List subclass that tags its contents as already-decoded samples.

    Consumers are expected to probe the flag with
    ``getattr(obj, 'already_decoded', False)``, so any object that lacks the
    attribute (for example a raw latent tensor) reads as "not decoded yet"
    and can be sent through the decoder, while an instance of this class is
    used as-is.
    """

    # Class-level marker shared by every instance; it is never set per object.
    already_decoded: bool = True


def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
    samples = []
    samples = DecodedSamples()


    for i in range(batch.shape[0]):
    for i in range(batch.shape[0]):
        sample = decode_first_stage(model, batch[i:i + 1])[0]
        sample = decode_first_stage(model, batch[i:i + 1])[0]
@@ -788,8 +792,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)


            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)


@@ -931,7 +939,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    cached_hr_uc = [None, None]
    cached_hr_uc = [None, None]
    cached_hr_c = [None, None]
    cached_hr_c = [None, None]


    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
        super().__init__(**kwargs)
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.denoising_strength = denoising_strength
@@ -942,11 +950,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        self.hr_resize_y = hr_resize_y
        self.hr_resize_y = hr_resize_y
        self.hr_upscale_to_x = hr_resize_x
        self.hr_upscale_to_x = hr_resize_x
        self.hr_upscale_to_y = hr_resize_y
        self.hr_upscale_to_y = hr_resize_y
        self.hr_checkpoint_name = hr_checkpoint_name
        self.hr_checkpoint_info = None
        self.hr_sampler_name = hr_sampler_name
        self.hr_sampler_name = hr_sampler_name
        self.hr_prompt = hr_prompt
        self.hr_prompt = hr_prompt
        self.hr_negative_prompt = hr_negative_prompt
        self.hr_negative_prompt = hr_negative_prompt
        self.all_hr_prompts = None
        self.all_hr_prompts = None
        self.all_hr_negative_prompts = None
        self.all_hr_negative_prompts = None
        self.latent_scale_mode = None


        if firstphase_width != 0 or firstphase_height != 0:
        if firstphase_width != 0 or firstphase_height != 0:
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_x = self.width
@@ -969,6 +980,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):


    def init(self, all_prompts, all_seeds, all_subseeds):
    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
        if self.enable_hr:
            if self.hr_checkpoint_name:
                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

                if self.hr_checkpoint_info is None:
                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')

                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title

            if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
            if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
                self.extra_generation_params["Hires sampler"] = self.hr_sampler_name


@@ -978,6 +997,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
            if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
                self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
                self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt


            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
            if self.enable_hr and self.latent_scale_mode is None:
                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")

            if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
            if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
                self.hr_resize_x = self.width
                self.hr_resize_x = self.width
                self.hr_resize_y = self.height
                self.hr_resize_y = self.height
@@ -1016,14 +1040,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                    self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                    self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                    self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
                    self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f


            # special case: the user has chosen to do nothing
            if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
                self.enable_hr = False
                self.denoising_strength = None
                self.extra_generation_params.pop("Hires upscale", None)
                self.extra_generation_params.pop("Hires resize", None)
                return

            if not state.processing_has_refined_job_count:
            if not state.processing_has_refined_job_count:
                if state.job_count == -1:
                if state.job_count == -1:
                    state.job_count = self.n_iter
                    state.job_count = self.n_iter
@@ -1041,17 +1057,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)


        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
        if self.enable_hr and latent_scale_mode is None:
            if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
                raise Exception(f"could not find upscaler named {self.hr_upscaler}")

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
        del x


        if not self.enable_hr:
        if not self.enable_hr:
            return samples
            return samples


        if self.latent_scale_mode is None:
            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
        else:
            decoded_samples = None

        current = shared.sd_model.sd_checkpoint_info
        try:
            if self.hr_checkpoint_info is not None:
                self.sampler = None
                sd_models.reload_model_weights(info=self.hr_checkpoint_info)
                devices.torch_gc()

            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
        finally:
            self.sampler = None
            sd_models.reload_model_weights(info=current)
            devices.torch_gc()

    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
        self.is_hr_pass = True
        self.is_hr_pass = True


        target_width = self.hr_upscale_to_x
        target_width = self.hr_upscale_to_x
@@ -1069,11 +1100,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")


        if latent_scale_mode is not None:
        img2img_sampler_name = self.hr_sampler_name or self.sampler_name

        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
            img2img_sampler_name = 'DDIM'

        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

        if self.latent_scale_mode is not None:
            for i in range(samples.shape[0]):
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)
                save_intermediate(samples, i)


            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])


            # Avoid making the inpainting conditioning unless necessary as
            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
            # this does need some extra compute to decode / encode the image again.
@@ -1082,7 +1120,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            else:
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
        else:
            decoded_samples = decode_first_stage(self.sd_model, samples)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)


            batch_images = []
            batch_images = []
@@ -1108,19 +1145,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):


        shared.state.nextjob()
        shared.state.nextjob()


        img2img_sampler_name = self.hr_sampler_name or self.sampler_name

        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
            img2img_sampler_name = 'DDIM'

        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]


        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)


        # GC now before running the next img2img to prevent running out of memory
        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()
        devices.torch_gc()


        if not self.disable_extra_networks:
        if not self.disable_extra_networks:
@@ -1139,9 +1168,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):


        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())


        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

        self.is_hr_pass = False
        self.is_hr_pass = False


        return samples
        return decoded_samples


    def close(self):
    def close(self):
        super().close()
        super().close()
@@ -1180,8 +1211,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        if self.hr_c is not None:
        if self.hr_c is not None:
            return
            return


        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)

        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)


    def setup_conds(self):
    def setup_conds(self):
        super().setup_conds()
        super().setup_conds()
@@ -1189,7 +1223,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        self.hr_uc = None
        self.hr_uc = None
        self.hr_c = None
        self.hr_c = None


        if self.enable_hr:
        if self.enable_hr and self.hr_checkpoint_info is None:
            if shared.opts.hires_fix_use_firstpass_conds:
            if shared.opts.hires_fix_use_firstpass_conds:
                self.calculate_hr_conds()
                self.calculate_hr_conds()


Loading