Unverified Commit 448d6bef authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #12599 from AUTOMATIC1111/ram_optim

RAM optimization round 2
parents 7056fdf2 0dc74545
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -304,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None:
    if weights_backup is None and wanted_names != ():
        if current_names != ():
            raise RuntimeError("no backup weights found and current weights are not unchanged")

        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
+53 −10
Original line number Diff line number Diff line
@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
    ```
    """

    def __init__(self, state_dict, device):
    def __init__(self, state_dict, device, weight_dtype_conversion=None):
        super().__init__()
        self.state_dict = state_dict
        self.device = device
        self.weight_dtype_conversion = weight_dtype_conversion or {}
        self.default_dtype = self.weight_dtype_conversion.get('')

    def get_weight_dtype(self, key):
        key_first_term, _ = key.split('.', 1)
        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -167,23 +173,60 @@ class LoadStateDictOnMeta(ReplaceHelper):
        sd = self.state_dict
        device = self.device

        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
            params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
            used_param_keys = []

            for name, param in params:
                if param.is_meta:
                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
            for name, param in module._parameters.items():
                if param is None:
                    continue

            original(self, state_dict, prefix, *args, **kwargs)
                key = prefix + name
                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                    used_param_keys.append(key)

            for name, _ in params:
                if param.is_meta:
                    dtype = sd_param.dtype if sd_param is not None else param.dtype
                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

            for name in module._buffers:
                key = prefix + name
                if key in sd:
                    del sd[key]

                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param
                    used_param_keys.append(key)

            original(module, state_dict, prefix, *args, **kwargs)

            for key in used_param_keys:
                state_dict.pop(key, None)

        def load_state_dict(original, module, state_dict, strict=True):
            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
            all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.

            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).

            The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
            the function and does not call the original) the state dict will just fail to load because weights
            would be on the meta device.
            """

            if state_dict == sd:
                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

            original(module, state_dict, strict=strict)

        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
        module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
        linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
        conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
        mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
        layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
        group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
+20 −2
Original line number Diff line number Diff line
@@ -343,7 +343,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if not shared.cmd_opts.no_half:
    if shared.cmd_opts.no_half:
        model.float()
        timer.record("apply float()")
    else:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

@@ -518,6 +521,13 @@ def send_model_to_cpu(m):
    devices.torch_gc()


def model_target_device():
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        return devices.cpu
    else:
        return devices.device


def send_model_to_device(m):
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
@@ -579,7 +589,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):

    timer.record("create model")

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
    if shared.cmd_opts.no_half:
        weight_dtype_conversion = None
    else:
        weight_dtype_conversion = {
            'first_stage_model': None,
            '': torch.float16,
        }

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")