Commit b75b004f authored by AUTOMATIC1111

lora extension rework to include other types of networks

parent 7d26c479
+9 −9
from modules import extra_networks, shared
-import lora
+import networks


class ExtraNetworkLora(extra_networks.ExtraNetwork):
@@ -9,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
    def activate(self, p, params_list):
        additional = shared.opts.sd_lora

-        if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
+        if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

@@ -21,12 +21,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
            names.append(params.items[0])
            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

-        lora.load_loras(names, multipliers)
+        networks.load_networks(names, multipliers)

        if shared.opts.lora_add_hashes_to_infotext:
-            lora_hashes = []
-            for item in lora.loaded_loras:
-                shorthash = item.lora_on_disk.shorthash
+            network_hashes = []
+            for item in networks.loaded_networks:
+                shorthash = item.network_on_disk.shorthash
                if not shorthash:
                    continue

@@ -36,10 +36,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):

                alias = alias.replace(":", "").replace(",", "")

-                lora_hashes.append(f"{alias}: {shorthash}")
+                network_hashes.append(f"{alias}: {shorthash}")

-            if lora_hashes:
-                p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+            if network_hashes:
+                p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

    def deactivate(self, p):
        pass
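
The hunks above only swap the old lora module for the new networks module; the activation flow itself is unchanged: each <lora:name:multiplier> tag parsed from the prompt arrives as an ExtraNetworkParams whose items hold the name and an optional multiplier, and those are collected into the two lists handed to networks.load_networks. A small sketch of that collection step, with hypothetical item values:

# Illustrative sketch only; "my_style" and "detail_tweak" are made-up names,
# standing in for params.items of two parsed prompt tags.
params_items = [["my_style", "0.8"], ["detail_tweak"]]

names, multipliers = [], []
for items in params_items:
    names.append(items[0])
    multipliers.append(float(items[1]) if len(items) > 1 else 1.0)

assert names == ["my_style", "detail_tweak"]
assert multipliers == [0.8, 1.0]
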
+15 −0
import torch


def make_weight_cp(t, wa, wb):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)


def rebuild_conventional(up, down, shape, dyn_dim=None):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]
    return (up @ down).reshape(shape)
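
These two helpers reconstruct a full weight delta from factored weights: rebuild_conventional flattens the up and down factors, multiplies them (optionally truncated to dyn_dim) and reshapes to the target weight's shape, while make_weight_cp contracts a 4-D CP core with the two factor matrices via einsum. A quick sanity check with made-up shapes, assuming the extension directory is on sys.path so this file imports as lyco_helpers (as network_hada does below):

# Sanity-check sketch only; shapes are illustrative, not from a real checkpoint.
import torch
import lyco_helpers  # assumes the extension directory is importable

up = torch.randn(64, 8)    # (out_features, rank)
down = torch.randn(8, 32)  # (rank, in_features)
delta = lyco_helpers.rebuild_conventional(up, down, (64, 32))
assert delta.shape == (64, 32)
assert torch.allclose(delta, up @ down)

t = torch.randn(8, 8, 3, 3)   # CP core: (rank, rank, kernel_h, kernel_w)
wa = torch.randn(8, 64)       # maps rank -> out channels
wb = torch.randn(8, 32)       # maps rank -> in channels
w = lyco_helpers.make_weight_cp(t, wa, wb)
assert w.shape == (64, 32, 3, 3)
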
+98 −0
import os
from collections import namedtuple

import torch

from modules import devices, sd_models, cache, errors, hashes, shared

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)
            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text

            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias


class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.multiplier = 1.0
        self.modules = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""


class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> "NetworkModule | None":
        return None


class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

    def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
        raise NotImplementedError()
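
The two base classes above are the extension points of the rework: each weight format registers a ModuleType whose create_module inspects the keys in weights.w and returns a NetworkModule subclass when they match, or None so the next type can try (network_lora and network_hada below follow exactly this pattern). A hedged sketch of what a new format would add; the example_w key and both class names are hypothetical, not part of the commit:

# Hypothetical format plugging into the ModuleType/NetworkModule API.
import network


class ModuleTypeExample(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if "example_w" in weights.w:
            return NetworkModuleExample(net, weights)

        return None


class NetworkModuleExample(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)
        self.w = weights.w["example_w"]

    def calc_updown(self, target):
        # return a weight delta matching the target weight's shape, dtype and device
        return self.w.to(target.device, dtype=target.dtype).reshape(target.shape)
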
+59 −0
import lyco_helpers
import network
import network_lyco


class ModuleTypeHada(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
            return NetworkModuleHada(net, weights)

        return None


class NetworkModuleHada(network_lyco.NetworkModuleLyco):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        self.w1a = weights.w["hada_w1_a"]
        self.w1b = weights.w["hada_w1_b"]
        self.dim = self.w1b.shape[0]
        self.w2a = weights.w["hada_w2_a"]
        self.w2b = weights.w["hada_w2_b"]

        self.t1 = weights.w.get("hada_t1")
        self.t2 = weights.w.get("hada_t2")

        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

    def calc_updown(self, orig_weight):
        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)

        output_shape = [w1a.size(0), w1b.size(1)]

        if self.t1 is not None:
            output_shape = [w1a.size(1), w1b.size(1)]
            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]
        else:
            if len(w1b.shape) == 4:
                output_shape += w1b.shape[2:]
            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)

        if self.t2 is not None:
            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
        else:
            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)

        updown = updown1 * updown2

        return self.finalize_updown(updown, orig_weight, output_shape)
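
NetworkModuleHada rebuilds the delta as the element-wise (Hadamard) product of two low-rank reconstructions, (w1a @ w1b) * (w2a @ w2b), switching to the CP cores t1/t2 via make_weight_cp when they are present (the convolutional case); multiplier and alpha scaling are left to finalize_updown, inherited from network_lyco.NetworkModuleLyco (not shown in this view). A minimal numeric sketch of the plain case with made-up shapes:

# Illustrative sketch of the no-core Hadamard reconstruction; shapes are made up.
import torch

out_dim, in_dim, rank = 64, 32, 8
w1a, w1b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
w2a, w2b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)

# weight delta = element-wise product of two low-rank products
updown = (w1a @ w1b) * (w2a @ w2b)
assert updown.shape == (out_dim, in_dim)
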
+70 −0
import torch

import network
from modules import devices


class ModuleTypeLora(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
            return NetworkModuleLora(net, weights)

        return None


class NetworkModuleLora(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.up = self.create_module(weights.w["lora_up.weight"])
        self.down = self.create_module(weights.w["lora_down.weight"])
        self.alpha = weights.w["alpha"] if "alpha" in weights.w else None

    def create_module(self, weight, none_ok=False):
        if weight is None and none_ok:
            return None

        if type(self.sd_module) == torch.nn.Linear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(self.sd_module) == torch.nn.MultiheadAttention:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
        else:
            print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
            return None

        with torch.no_grad():
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)

        return module

    def calc_updown(self, target):
        up = self.up.weight.to(target.device, dtype=target.dtype)
        down = self.down.weight.to(target.device, dtype=target.dtype)

        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
        else:
            updown = up @ down

        updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)

        return updown

    def forward(self, x, y):
        self.up.to(device=devices.device)
        self.down.to(device=devices.device)

        return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
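
For a plain Linear layer, calc_updown above reduces to the standard LoRA update: the delta is up @ down, scaled by the network multiplier and by alpha divided by the rank (up.weight.shape[1]) so that changing the rank does not change the effective strength. A standalone sketch with made-up shapes:

# Standalone sketch of the conventional LoRA delta; shapes and values are illustrative.
import torch

out_features, in_features, rank = 64, 32, 8
up = torch.randn(out_features, rank)      # lora_up.weight
down = torch.randn(rank, in_features)     # lora_down.weight
alpha, multiplier = 4.0, 1.0

updown = (up @ down) * multiplier * (alpha / up.shape[1])

# applying the delta to a layer weight W: W' = W + updown
W = torch.randn(out_features, in_features)
W_patched = W + updown
assert W_patched.shape == W.shape
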
