Commit ef27a18b authored by AUTOMATIC's avatar AUTOMATIC
Browse files

Hires fix rework

parent fd4461d4
Loading
Loading
Loading
Loading
+32 −0
Original line number Diff line number Diff line
import base64
import io
import math
import os
import re
from pathlib import Path
@@ -164,6 +165,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
    return None


def restore_old_hires_fix_params(res):
    """Convert the old "First pass size" infotext fields into width, height and hr scale.

    The function rewrites ``res`` in place. When both ``First pass size-1`` and
    ``First pass size-2`` are present, ``Size-1``/``Size-2`` are replaced with the
    first pass dimensions and ``Hires upscale`` is set to the ratio between the
    previously recorded size and the first pass size.

    Args:
        res: dict of parsed infotext fields. Numeric values may be strings or
            ints; they are converted with ``int()``. ``Size-1``/``Size-2``
            default to 512 when absent.

    Returns:
        None. ``res`` is left untouched when either first pass field is missing,
        or when the recorded size is not positive (such a size cannot be
        converted and used to raise ``ZeroDivisionError``).

    Raises:
        ValueError: if a present size field cannot be converted by ``int()``.
    """

    firstpass_width = res.get('First pass size-1', None)
    firstpass_height = res.get('First pass size-2', None)

    if firstpass_width is None or firstpass_height is None:
        return

    firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
    width = int(res.get("Size-1", 512))
    height = int(res.get("Size-2", 512))

    if width <= 0 or height <= 0:
        # A degenerate size (e.g. "Size: 0x512") has no meaningful first pass
        # size or scale; bail out instead of dividing by a zero pixel count.
        return

    if firstpass_width == 0 or firstpass_height == 0:
        # old algorithm for auto-calculating first pass size: scale the image
        # to roughly 512x512 pixels, rounding each side up to a multiple of 64
        desired_pixel_count = 512 * 512
        actual_pixel_count = width * height
        scale = math.sqrt(desired_pixel_count / actual_pixel_count)
        firstpass_width = math.ceil(scale * width / 64) * 64
        firstpass_height = math.ceil(scale * height / 64) * 64

    # The scale is derived from the width; the height ratio is only a fallback
    # for a non-positive first pass width.
    hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height

    res['Size-1'] = firstpass_width
    res['Size-2'] = firstpass_height
    res['Hires upscale'] = hr_scale


def parse_generation_parameters(x: str):
    """parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -221,6 +251,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
        hypernet_hash = res.get("Hypernet hash", None)
        res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)

    restore_old_hires_fix_params(res)

    return res


+20 −4
Original line number Diff line number Diff line
@@ -230,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
    return draw_grid_annotations(im, width, height, hor_texts, ver_texts)


def resize_image(resize_mode, im, width, height):
def resize_image(resize_mode, im, width, height, upscaler_name=None):
    """
    Resizes an image with the specified resize_mode, width, and height.

    Args:
        resize_mode: The mode to use when resizing the image.
            0: Resize the image to the specified width and height.
            1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
        im: The image to resize.
        width: The width to resize the image to.
        height: The height to resize the image to.
        upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
    """

    upscaler_name = upscaler_name or opts.upscaler_for_img2img

    def resize(im, w, h):
        if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
        if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
            return im.resize((w, h), resample=LANCZOS)

        scale = max(w / im.width, h / im.height)

        if scale > 1.0:
            upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
            assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
            upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
            assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"

            upscaler = upscalers[0]
            im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
+27 −41
Original line number Diff line number Diff line
@@ -658,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.firstphase_width = firstphase_width
        self.firstphase_height = firstphase_height
        self.truncate_x = 0
        self.truncate_y = 0
        self.hr_scale = hr_scale
        self.hr_upscaler = hr_upscaler

        if firstphase_width != 0 or firstphase_height != 0:
            print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
            self.hr_scale = self.width / firstphase_width
            self.width = firstphase_width
            self.height = firstphase_height

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
@@ -674,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            else:
                state.job_count = state.job_count * 2

            self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"

            if self.firstphase_width == 0 or self.firstphase_height == 0:
                desired_pixel_count = 512 * 512
                actual_pixel_count = self.width * self.height
                scale = math.sqrt(desired_pixel_count / actual_pixel_count)
                self.firstphase_width = math.ceil(scale * self.width / 64) * 64
                self.firstphase_height = math.ceil(scale * self.height / 64) * 64
                firstphase_width_truncated = int(scale * self.width)
                firstphase_height_truncated = int(scale * self.height)

            else:

                width_ratio = self.width / self.firstphase_width
                height_ratio = self.height / self.firstphase_height

                if width_ratio > height_ratio:
                    firstphase_width_truncated = self.firstphase_width
                    firstphase_height_truncated = self.firstphase_width * self.height / self.width
                else:
                    firstphase_width_truncated = self.firstphase_height * self.width / self.height
                    firstphase_height_truncated = self.firstphase_height

            self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
            self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
            self.extra_generation_params["Hires upscale"] = self.hr_scale
            if self.hr_upscaler is not None:
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        if not self.enable_hr:
        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode
        if self.enable_hr and latent_scale_mode is None:
            assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
            return samples

        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
        if not self.enable_hr:
            return samples

        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
        target_width = int(self.width * self.hr_scale)
        target_height = int(self.height * self.hr_scale)

        """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
        def save_intermediate(image, index):
            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

            if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
                return

@@ -723,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")

        if opts.use_scale_latent_for_hires_fix:
        if latent_scale_mode is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
@@ -747,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

                save_intermediate(image, i)

                image = images.resize_image(0, image, self.width, self.height)
                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)
@@ -764,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
+6 −1
Original line number Diff line number Diff line
@@ -327,7 +327,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
    "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
    "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
    "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
}))

options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -545,6 +544,12 @@ opts = Options()
if os.path.exists(config_filename):
    opts.load(config_filename)

# Name of the latent upscale mode used when no hires upscaler has been selected.
# NOTE(review): this is a key of latent_upscale_modes, not an interpolation mode
# string; presumably consumers should resolve it through the dict before
# passing it on as an interpolation mode - confirm against the callers.
latent_upscale_default_mode = "Latent"
# Maps the upscaler names that select latent-space upscaling to the `mode`
# value handed to torch.nn.functional.interpolate. A name missing from this
# dict is treated by the hires fix code as a regular (image-space) upscaler.
latent_upscale_modes = {
    "Latent": "bilinear",
    "Latent (nearest)": "nearest",
}

sd_upscalers = []

sd_model = None
+3 −3
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html


def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
    p = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
        tiling=tiling,
        enable_hr=enable_hr,
        denoising_strength=denoising_strength if enable_hr else None,
        firstphase_width=firstphase_width if enable_hr else None,
        firstphase_height=firstphase_height if enable_hr else None,
        hr_scale=hr_scale,
        hr_upscaler=hr_upscaler,
    )

    p.scripts = modules.scripts.scripts_txt2img
Loading