Commit d61e31ba authored by Robert Barron

Merge remote-tracking branch 'auto1111/dev' into shared-hires-prompt-test

parents 54f926b1 f3b96d49
extensions-builtin/Lora/extra_networks_lora.py +9 −1
@@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
    def __init__(self):
        super().__init__('lora')

+        self.errors = {}
+        """mapping of network names to the number of errors the network had during operation"""
+
    def activate(self, p, params_list):
        additional = shared.opts.sd_lora

+        self.errors.clear()
+
        if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
                p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

    def deactivate(self, p):
-        pass
+        if self.errors:
+            p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
+
+            self.errors.clear()
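The `errors` dict added above turns per-layer failures into one end-of-run summary instead of console spam: `activate()` clears the tally, the weight-apply code increments it, and `deactivate()` reports it through `p.comment()`. A minimal standalone sketch of that pattern (the `Processing` stand-in and helper names are hypothetical, not webui code):

```python
class Processing:
    """Hypothetical stand-in for webui's processing object."""
    def __init__(self):
        self.comments = []

    def comment(self, text):
        self.comments.append(text)


errors = {}  # network name -> error count; cleared at the start of each job

def record_error(name):
    errors[name] = errors.get(name, 0) + 1

def report_and_reset(p):
    # Mirrors deactivate(): one summary comment, then a clean slate.
    if errors:
        p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in errors.items()))
        errors.clear()


p = Processing()
record_error("my_lora")
record_error("my_lora")
report_and_reset(p)
print(p.comments)  # ['Networks with errors: my_lora (2)']
```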
extensions-builtin/Lora/network.py +5 −2
@@ -133,7 +133,7 @@ class NetworkModule:

        return 1.0

-    def finalize_updown(self, updown, orig_weight, output_shape):
+    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -145,7 +145,10 @@ class NetworkModule:
        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

-        return updown * self.calc_scale() * self.multiplier()
+        if ex_bias is not None:
+            ex_bias = ex_bias * self.multiplier()
+
+        return updown * self.calc_scale() * self.multiplier(), ex_bias

    def calc_updown(self, target):
        raise NotImplementedError()
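`finalize_updown` now returns a pair, so every `calc_updown` caller has to unpack `(updown, ex_bias)` even when a module carries no bias. A rough standalone sketch of the new contract using plain tensors (the `scale`/`multiplier` arguments here are simplifications standing in for the real class attributes):

```python
import torch

def finalize_updown(updown, orig_weight, output_shape, ex_bias=None,
                    scale=1.0, multiplier=1.0):
    # Toy version of NetworkModule.finalize_updown: reshape the weight delta
    # to the target shape, scale it, scale the optional bias delta by the
    # multiplier alone, and return both as a pair.
    if orig_weight.size().numel() == updown.size().numel():
        updown = updown.reshape(orig_weight.shape)

    if ex_bias is not None:
        ex_bias = ex_bias * multiplier

    return updown * scale * multiplier, ex_bias


w = torch.zeros(4, 4)
updown, ex_bias = finalize_updown(torch.ones(16), w, w.shape,
                                  ex_bias=torch.ones(4), multiplier=0.5)
print(updown.shape, ex_bias[0].item())  # torch.Size([4, 4]) 0.5
```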
extensions-builtin/Lora/network_norm.py +28 −0
import network


class ModuleTypeNorm(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["w_norm", "b_norm"]):
            return NetworkModuleNorm(net, weights)

        return None


class NetworkModuleNorm(network.NetworkModule):
    def __init__(self,  net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.w_norm = weights.w.get("w_norm")
        self.b_norm = weights.w.get("b_norm")

    def calc_updown(self, orig_weight):
        output_shape = self.w_norm.shape
        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)

        if self.b_norm is not None:
            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
        else:
            ex_bias = None

        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
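`NetworkModuleNorm` treats `w_norm`/`b_norm` as additive deltas on a norm layer's affine parameters, with `b_norm` flowing through the new `ex_bias` path. A small sketch of what applying such a delta to a plain `torch.nn.LayerNorm` looks like (toy tensors; in webui the actual application happens inside `network_apply_weights`):

```python
import torch

layer = torch.nn.LayerNorm(8)

# Pretend these tensors came from a network file's "w_norm"/"b_norm" keys.
w_norm = torch.full((8,), 0.1)
b_norm = torch.full((8,), -0.2)

with torch.no_grad():
    # Same device/dtype handling as NetworkModuleNorm.calc_updown.
    updown = w_norm.to(layer.weight.device, dtype=layer.weight.dtype)
    ex_bias = b_norm.to(layer.weight.device, dtype=layer.weight.dtype)

    layer.weight += updown   # affine weight starts at 1.0
    layer.bias += ex_bias    # affine bias starts at 0.0

print(layer.weight[0].item(), layer.bias[0].item())  # ~1.1 -0.2
```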
extensions-builtin/Lora/networks.py +111 −31
+import logging
import os
import re

@@ -7,6 +8,7 @@ import network_hada
import network_ia3
import network_lokr
import network_full
+import network_norm

import torch
from typing import Union
@@ -19,6 +21,7 @@ module_types = [
    network_ia3.ModuleTypeIa3(),
    network_lokr.ModuleTypeLokr(),
    network_full.ModuleTypeFull(),
+    network_norm.ModuleTypeNorm(),
]


@@ -31,6 +34,8 @@ suffix_conversion = {
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
@@ -190,7 +195,7 @@ def load_network(name, network_on_disk):
        net.modules[key] = net_module

    if keys_failed_to_match:
        print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

    return net

@@ -203,7 +208,6 @@ def purge_networks_from_memory():
    devices.torch_gc()

-

def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    already_loaded = {}

@@ -244,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No

        if net is None:
            failed_to_load_networks.append(name)
            print(f"Couldn't find network with name {name}")
            logging.info(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
@@ -253,25 +257,38 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
        loaded_networks.append(net)

    if failed_to_load_networks:
        sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))

    purge_networks_from_memory()


-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    weights_backup = getattr(self, "network_weights_backup", None)
+    bias_backup = getattr(self, "network_bias_backup", None)

-    if weights_backup is None:
+    if weights_backup is None and bias_backup is None:
        return

+    if weights_backup is not None:
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.in_proj_weight.copy_(weights_backup[0])
+            self.out_proj.weight.copy_(weights_backup[1])
+        else:
+            self.weight.copy_(weights_backup)
+
+    if bias_backup is not None:
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.out_proj.bias.copy_(bias_backup)
+        else:
+            self.bias.copy_(bias_backup)
+    else:
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.out_proj.bias = None
+        else:
+            self.bias = None


-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
@@ -294,20 +311,40 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn

        self.network_weights_backup = weights_backup

+    bias_backup = getattr(self, "network_bias_backup", None)
+    if bias_backup is None:
+        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
+            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+        elif getattr(self, 'bias', None) is not None:
+            bias_backup = self.bias.to(devices.cpu, copy=True)
+        else:
+            bias_backup = None
+        self.network_bias_backup = bias_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
+                try:
+                    with torch.no_grad():
-                    updown = module.calc_updown(self.weight)
+                        updown, ex_bias = module.calc_updown(self.weight)

+                        if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+                            # inpainting model. zero pad updown to make channel[1]  4 to 9
+                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

+                        self.weight += updown
+                        if ex_bias is not None and hasattr(self, 'bias'):
+                            if self.bias is None:
+                                self.bias = torch.nn.Parameter(ex_bias)
+                            else:
+                                self.bias += ex_bias
+                except RuntimeError as e:
+                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            module_q = net.modules.get(network_layer_name + "_q_proj", None)
@@ -316,21 +353,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
+                try:
+                    with torch.no_grad():
-                    updown_q = module_q.calc_updown(self.in_proj_weight)
-                    updown_k = module_k.calc_updown(self.in_proj_weight)
-                    updown_v = module_v.calc_updown(self.in_proj_weight)
+                        updown_q, _ = module_q.calc_updown(self.in_proj_weight)
+                        updown_k, _ = module_k.calc_updown(self.in_proj_weight)
+                        updown_v, _ = module_v.calc_updown(self.in_proj_weight)
+                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
-                    updown_out = module_out.calc_updown(self.out_proj.weight)
+                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

+                        self.in_proj_weight += updown_qkv
+                        self.out_proj.weight += updown_out
+                    if ex_bias is not None:
+                        if self.out_proj.bias is None:
+                            self.out_proj.bias = torch.nn.Parameter(ex_bias)
+                        else:
+                            self.out_proj.bias += ex_bias
+
+                except RuntimeError as e:
+                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if module is None:
                continue

-            print(f'failed to calculate network weights for layer {network_layer_name}')
+            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
+            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

        self.network_current_names = wanted_names

@@ -357,7 +406,7 @@ def network_forward(module, input, original_forward):
        if module is None:
            continue

-        y = module.forward(y, input)
+        y = module.forward(input, y)

    return y

@@ -397,6 +446,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs):
    return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)


+def network_GroupNorm_forward(self, input):
+    if shared.opts.lora_functional:
+        return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+
+    network_apply_weights(self)
+
+    return torch.nn.GroupNorm_forward_before_network(self, input)
+
+
+def network_GroupNorm_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)
+
+    return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_LayerNorm_forward(self, input):
+    if shared.opts.lora_functional:
+        return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+
+    network_apply_weights(self)
+
+    return torch.nn.LayerNorm_forward_before_network(self, input)
+
+
+def network_LayerNorm_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)
+
+    return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
def network_MultiheadAttention_forward(self, *args, **kwargs):
    network_apply_weights(self)

@@ -473,6 +552,7 @@ def infotext_pasted(infotext, params):
    if added:
        params["Prompt"] += "\n" + "".join(added)

+extra_network_lora = None

available_networks = {}
available_network_aliases = {}
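Supporting `ex_bias` means a network can create a bias on a module that had none, so restoring is not just copying values back: `network_bias_backup` remembers "there was no bias", and the restore path assigns `None` to delete it again. A standalone sketch of that backup/restore dance (hypothetical helper names, CPU-only):

```python
import torch

def backup_bias(module):
    # Stash the original bias (or None) exactly once, like network_bias_backup.
    if not hasattr(module, "_bias_backup"):
        module._bias_backup = module.bias.detach().clone() if module.bias is not None else None

def restore_bias(module):
    # Copy the old values back, or drop a bias that a network invented.
    with torch.no_grad():
        if module._bias_backup is not None:
            module.bias.copy_(module._bias_backup)
        else:
            module.bias = None


lin = torch.nn.Linear(4, 4, bias=False)
backup_bias(lin)                              # remembers: no bias
lin.bias = torch.nn.Parameter(torch.ones(4))  # a network applies an ex_bias
restore_bias(lin)
print(lin.bias)  # None again
```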
extensions-builtin/Lora/scripts/lora_script.py +19 −3
@@ -23,9 +23,9 @@ def unload():
def before_ui():
    ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())

-    extra_network = extra_networks_lora.ExtraNetworkLora()
-    extra_networks.register_extra_network(extra_network)
-    extra_networks.register_extra_network_alias(extra_network, "lyco")
+    networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
+    extra_networks.register_extra_network(networks.extra_network_lora)
+    extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")


if not hasattr(torch.nn, 'Linear_forward_before_network'):
@@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
    torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict

+if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
+    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
+
+if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
+    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
+
+if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
+    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
+
+if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
+    torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
+
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
    torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward

@@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
+torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
+torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
+torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
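All of these `*_before_network` attributes follow one idempotent monkey-patching pattern: stash the original method on `torch.nn` exactly once, guarded by `hasattr`, so reloading the script wraps the original rather than the wrapper. A minimal sketch for a single layer type (the pass-through hook body is hypothetical):

```python
import torch

# Stash the original forward once; the hasattr guard keeps a re-run of this
# file from saving an already-patched method.
if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward

def network_LayerNorm_forward(self, input):
    # A real hook would apply network weights here before delegating.
    return torch.nn.LayerNorm_forward_before_network(self, input)

torch.nn.LayerNorm.forward = network_LayerNorm_forward

x = torch.zeros(2, 4)
print(torch.nn.LayerNorm(4)(x).shape)  # torch.Size([2, 4]) -- behavior unchanged
```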
