Commit eaba3d73 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

send weights to target device instead of CPU memory

parent 57e59c14
Loading
Loading
Loading
Loading
+15 −9
Original line number Diff line number Diff line
@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
    ```
    """

    def __init__(self, state_dict, device):
    def __init__(self, state_dict, device, weight_dtype_conversion=None):
        super().__init__()
        self.state_dict = state_dict
        self.device = device
        self.weight_dtype_conversion = weight_dtype_conversion or {}
        self.default_dtype = self.weight_dtype_conversion.get('')

    def get_weight_dtype(self, key):
        key_first_term, _ = key.split('.', 1)
        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -167,24 +173,24 @@ class LoadStateDictOnMeta(ReplaceHelper):
        sd = self.state_dict
        device = self.device

        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
            used_param_keys = []

            for name, param in self._parameters.items():
            for name, param in module._parameters.items():
                if param is None:
                    continue

                key = prefix + name
                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param
                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                    used_param_keys.append(key)

                if param.is_meta:
                    dtype = sd_param.dtype if sd_param is not None else param.dtype
                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

            for name in self._buffers:
            for name in module._buffers:
                key = prefix + name

                sd_param = sd.pop(key, None)
@@ -192,12 +198,12 @@ class LoadStateDictOnMeta(ReplaceHelper):
                    state_dict[key] = sd_param
                    used_param_keys.append(key)

            original(self, state_dict, prefix, *args, **kwargs)
            original(module, state_dict, prefix, *args, **kwargs)

            for key in used_param_keys:
                state_dict.pop(key, None)

        def load_state_dict(original, self, state_dict, strict=True):
        def load_state_dict(original, module, state_dict, strict=True):
            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
            all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
@@ -212,7 +218,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
            if state_dict == sd:
                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

            original(self, state_dict, strict=strict)
            original(module, state_dict, strict=strict)

        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
        module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
+16 −1
Original line number Diff line number Diff line
@@ -518,6 +518,13 @@ def send_model_to_cpu(m):
    devices.torch_gc()


def model_target_device():
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        return devices.cpu
    else:
        return devices.device


def send_model_to_device(m):
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
@@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):

    timer.record("create model")

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
    if shared.cmd_opts.no_half:
        weight_dtype_conversion = None
    else:
        weight_dtype_conversion = {
            'first_stage_model': None,
            '': torch.float16,
        }

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")