Commit 50a21cb0 authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf
Browse files

Ensure the cached weight will not be affected

parent 110485d5
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -435,9 +435,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                if shared.opts.cache_fp16_weight:
                    module.fp16_weight = module.weight.clone().half()
                    module.fp16_weight = module.weight.data.clone().cpu().half()
                    if module.bias is not None:
                        module.fp16_bias = module.bias.clone().half()
                        module.fp16_bias = module.bias.data.clone().cpu().half()
                module.to(torch.float8_e4m3fn)
        model.first_stage_model = first_stage
        timer.record("apply fp8")