Commit 80b26d2a authored by AUTOMATIC

apply Lora by altering layer's weights instead of adding more calculations in forward()

parent 69eb2a9e
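The change below replaces the per-call LoRA computation in forward() with a one-time edit of each layer's weight tensor: the low-rank product up @ down is folded into the weight, a CPU backup of the original weight is kept, and the backup is restored whenever the selected set of Loras changes. A minimal sketch of the idea for a Linear layer, with made-up helper and variable names rather than code from this commit:

import torch

def merge_lora_into_linear(layer: torch.nn.Linear, up: torch.Tensor, down: torch.Tensor, multiplier: float = 1.0):
    # hypothetical helper: fold the low-rank update into the weight once, so the
    # stock torch.nn.Linear.forward runs afterwards with no extra matmuls per call
    with torch.no_grad():
        layer.weight += (up @ down) * multiplier

layer = torch.nn.Linear(768, 768, bias=False)
rank = 4
up, down = torch.randn(768, rank), torch.randn(rank, 768)
x = torch.randn(1, 768)

old_style = layer(x) + x @ down.t() @ up.t()   # LoRA applied inside forward(): extra work every call
merge_lora_into_linear(layer, up, down)
new_style = layer(x)                           # LoRA baked into the weight: plain forward()
assert torch.allclose(old_style, new_style, atol=1e-5)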
+56 −16
@@ -131,7 +131,7 @@ def load_lora(name, filename):
        with torch.no_grad():
            module.weight.copy_(weight)

-        module.to(device=devices.device, dtype=devices.dtype)
+        module.to(device=devices.cpu, dtype=devices.dtype)

        if lora_key == "lora_up.weight":
            lora_module.up = module
@@ -177,29 +177,69 @@ def load_loras(names, multipliers=None):
        loaded_loras.append(lora)


-def lora_forward(module, input, res):
-    input = devices.cond_cast_unet(input)
-    if len(loaded_loras) == 0:
-        return res
-
-    lora_layer_name = getattr(module, 'lora_layer_name', None)
-    for lora in loaded_loras:
-        module = lora.modules.get(lora_layer_name, None)
-        if module is not None:
-            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
-                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
-            else:
-                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
-
-    return res
+def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
+    """
+    Applies the currently selected set of Loras to the weight of torch layer self.
+    If weights already have this particular set of loras applied, does nothing.
+    If not, restores orginal weights from backup and alters weights according to loras.
+    """
+
+    current_names = getattr(self, "lora_current_names", ())
+    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
+
+    weights_backup = getattr(self, "lora_weights_backup", None)
+    if weights_backup is None:
+        weights_backup = self.weight.to(devices.cpu, copy=True)
+        self.lora_weights_backup = weights_backup
+
+    if current_names != wanted_names:
+        if weights_backup is not None:
+            self.weight.copy_(weights_backup)
+
+        lora_layer_name = getattr(self, 'lora_layer_name', None)
+        for lora in loaded_loras:
+            module = lora.modules.get(lora_layer_name, None)
+            if module is None:
+                continue
+
+            with torch.no_grad():
+                up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
+                down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
+
+                if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+                    updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                else:
+                    updown = up @ down
+
+                self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+        setattr(self, "lora_current_names", wanted_names)


def lora_Linear_forward(self, input):
-    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Linear_forward_before_lora(self, input)
+
+
+def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)


def lora_Conv2d_forward(self, input):
-    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Conv2d_forward_before_lora(self, input)
+
+
+def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)


def list_available_loras():
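One detail in lora_apply_weights above: for Linear layers the weight delta is a plain matrix product of the two low-rank factors, but for 1x1 Conv2d layers the trailing spatial dimensions must be squeezed away before the matmul and restored afterwards so the result matches the conv weight's shape. A small standalone check of that reshaping, with arbitrary example shapes:

import torch

up = torch.randn(320, 4, 1, 1)    # lora_up.weight of a 1x1 conv: (out_channels, rank, 1, 1)
down = torch.randn(4, 320, 1, 1)  # lora_down.weight:             (rank, in_channels, 1, 1)

# squeeze to 2D, multiply the low-rank factors, then restore the 1x1 spatial dims
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
assert updown.shape == (320, 320, 1, 1)  # same shape as the conv layer's own weight

The module.alpha / module.up.weight.shape[1] factor in the weight update is the usual LoRA alpha-over-rank scaling, since the second dimension of the up weight is the rank.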
+10 −2
@@ -9,7 +9,9 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared

def unload():
    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora


def before_ui():
@@ -20,11 +22,19 @@ def before_ui():
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward

+if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
+    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict

if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward

+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
+    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict

torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict

script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
@@ -33,6 +43,4 @@ script_callbacks.on_before_ui(before_ui)

shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),

}))
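Because weights are now mutated in place and backed up per layer, _load_from_state_dict is patched alongside forward: loading a new checkpoint into the same modules has to reset the cached backup and the record of which Loras are applied, otherwise the next Lora application would restore weights from the previous model. Reduced to its core, the pattern looks roughly like this (a standalone sketch, not the extension's exact wiring):

import torch

original_load = torch.nn.Linear._load_from_state_dict  # keep a reference to the unpatched method

def load_and_invalidate(self, *args, **kwargs):
    # fresh weights arriving means the CPU backup and the list of currently
    # applied Loras are stale, so drop both before the real load runs
    self.lora_current_names = ()
    self.lora_weights_backup = None
    return original_load(self, *args, **kwargs)

torch.nn.Linear._load_from_state_dict = load_and_invalidate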