Commit f2a5cbe6 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

fix #3986 breaking --no-half-vae

parent 675b51eb
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -183,11 +183,20 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
            model.to(memory_format=torch.channels_last)

        if not shared.cmd_opts.no_half:
            vae = model.first_stage_model

            # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
            if shared.cmd_opts.no_half_vae:
                model.first_stage_model = None

            model.half()
            model.first_stage_model = vae

        devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
        devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16

        model.first_stage_model.to(devices.dtype_vae)

        if shared.opts.sd_checkpoint_cache > 0:
            # if PR #4035 were to get merged, restore base VAE first before caching
            checkpoints_loaded[checkpoint_key] = model.state_dict().copy()