Commit fa9370b7 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

add refiner to StableDiffusionProcessing class

write out correct model name in infotext, rather than the refiner model
parent b2080756
Loading
Loading
Loading
Loading
+30 −8
Original line number Diff line number Diff line
@@ -111,7 +111,7 @@ class StableDiffusionProcessing:
    cached_uc = [None, None]
    cached_c = [None, None]

    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, refiner_checkpoint: str = None, refiner_switch_at: float = None, script_args: list = None):
        if sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

@@ -153,10 +153,14 @@ class StableDiffusionProcessing:
        self.s_noise = s_noise if s_noise is not None else opts.s_noise
        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
        self.override_settings_restore_afterwards = override_settings_restore_afterwards
        self.refiner_checkpoint = refiner_checkpoint
        self.refiner_switch_at = refiner_switch_at

        self.is_using_inpainting_conditioning = False
        self.disable_extra_networks = False
        self.token_merging_ratio = 0
        self.token_merging_ratio_hr = 0
        self.refiner_checkpoint_info = None

        if not seed_enable_extras:
            self.subseed = -1
@@ -191,6 +195,11 @@ class StableDiffusionProcessing:

        self.user = None

        self.sd_model_name = None
        self.sd_model_hash = None
        self.sd_vae_name = None
        self.sd_vae_hash = None

    @property
    def sd_model(self):
        return shared.sd_model
@@ -408,7 +417,10 @@ class Processed:
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.sd_model_name = p.sd_model_name
        self.sd_model_hash = p.sd_model_hash
        self.sd_vae_name = p.sd_vae_name
        self.sd_vae_hash = p.sd_vae_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
@@ -459,7 +471,10 @@ class Processed:
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_name": self.sd_model_name,
            "sd_model_hash": self.sd_model_hash,
            "sd_vae_name": self.sd_vae_name,
            "sd_vae_hash": self.sd_vae_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
@@ -578,10 +593,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
        "Face restoration": opts.face_restoration_model if p.restore_faces else None,
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
        "VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None,
        "VAE": p.loaded_vae_name if opts.add_model_name_to_info else None,
        "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
        "Model": p.sd_model_name if opts.add_model_name_to_info else None,
        "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
        "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -670,8 +685,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    if p.tiling is None:
        p.tiling = opts.tiling

    p.loaded_vae_name = sd_vae.get_loaded_vae_name()
    p.loaded_vae_hash = sd_vae.get_loaded_vae_hash()
    if p.refiner_checkpoint not in (None, "", "None"):
        p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
        if p.refiner_checkpoint_info is None:
            raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')

    p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
    p.sd_model_hash = shared.sd_model.sd_model_hash
    p.sd_vae_name = sd_vae.get_loaded_vae_name()
    p.sd_vae_hash = sd_vae.get_loaded_vae_hash()

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()
+5 −11
Original line number Diff line number Diff line
@@ -41,15 +41,9 @@ class ScriptRefiner(scripts.Script):
    def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
        # the actual implementation is in sd_samplers_common.py, apply_refiner

        if not enable_refiner or refiner_checkpoint in (None, "", "None"):
            p.refiner_checkpoint_info = None
            p.refiner_switch_at = None

        if not enable_refiner or refiner_checkpoint in (None, "", "None"):
            return

        refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint)
        if refiner_checkpoint_info is None:
            raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}')

        p.refiner_checkpoint_info = refiner_checkpoint_info
        else:
            p.refiner_checkpoint = refiner_checkpoint
            p.refiner_switch_at = refiner_switch_at
+1 −1
Original line number Diff line number Diff line
@@ -145,7 +145,7 @@ def apply_refiner(cfg_denoiser):
    refiner_switch_at = cfg_denoiser.p.refiner_switch_at
    refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info

    if refiner_switch_at is not None and completed_ratio <= refiner_switch_at:
    if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
        return False

    if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info: