Unverified Commit 94450b88 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #5589 from MrCheeze/better-special-model-support

Better support for 2.0-inpainting and 2.0-depth special models
parents 9441c28c ec0a4882
Loading
Loading
Loading
Loading
+7 −5
Original line number Diff line number Diff line
@@ -55,18 +55,20 @@ def setup_for_low_vram(sd_model, use_medvram):
    if hasattr(sd_model.cond_stage_model, 'model'):
        sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model

    # remove three big modules, cond, first_stage, and unet from the model and then
    # remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
    # send the model to GPU. Then put modules back. the modules will be in CPU.
    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
    sd_model.to(devices.device)
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored

    # register hooks for those the first two models
    # register hooks for those the first three models
    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
    if sd_model.depth_model:
        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
    parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    if hasattr(sd_model.cond_stage_model, 'model'):
+1 −2
Original line number Diff line number Diff line
@@ -324,12 +324,11 @@ def should_hijack_inpainting(checkpoint_info):

def do_inpainting_hijack():
    # most of this stuff seems to no longer be needed because it is already included into SD2.0
    # LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
    # p_sample_plms is needed because PLMS can't work with dicts as conditionings
    # this file should be cleaned up later if weverything tuens out to work fine

    # ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
    ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
    # ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion

    # ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
    # ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
+4 −1
Original line number Diff line number Diff line
@@ -293,13 +293,16 @@ def load_model(checkpoint_info=None):
    if should_hijack_inpainting(checkpoint_info):
        # Hardcoded config for now...
        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
        sd_config.model.params.use_ema = False
        sd_config.model.params.conditioning_key = "hybrid"
        sd_config.model.params.unet_config.params.in_channels = 9
        sd_config.model.params.finetune_keys = None

        # Create a "fake" config with a different name so that we know to unload it when switching models.
        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))

    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False

    do_inpainting_hijack()

    if shared.cmd_opts.no_half: