Unverified commit 21645787 authored by AUTOMATIC1111, committed by GitHub

Merge pull request #11821 from AUTOMATIC1111/lora_lyco

lora extension rework to include other types of networks
parents 05d23c78 35510f75
+1 −0
@@ -168,5 +168,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
+- LyCORIS - KohakuBlueleaf
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
+26 −12
from modules import extra_networks, shared
-import lora
+import networks


class ExtraNetworkLora(extra_networks.ExtraNetwork):
@@ -9,24 +9,38 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
    def activate(self, p, params_list):
        additional = shared.opts.sd_lora

        if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
        if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

        names = []
-        multipliers = []
+        te_multipliers = []
+        unet_multipliers = []
+        dyn_dims = []
        for params in params_list:
            assert params.items

-            names.append(params.items[0])
-            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+            names.append(params.positional[0])

-        lora.load_loras(names, multipliers)
+            te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
+            te_multiplier = float(params.named.get("te", te_multiplier))
+
+            unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else 1.0
+            unet_multiplier = float(params.named.get("unet", unet_multiplier))
+
+            dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+            dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+
+            te_multipliers.append(te_multiplier)
+            unet_multipliers.append(unet_multiplier)
+            dyn_dims.append(dyn_dim)
+
+        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)

        if shared.opts.lora_add_hashes_to_infotext:
-            lora_hashes = []
-            for item in lora.loaded_loras:
-                shorthash = item.lora_on_disk.shorthash
+            network_hashes = []
+            for item in networks.loaded_networks:
+                shorthash = item.network_on_disk.shorthash
                if not shorthash:
                    continue

@@ -36,10 +50,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):

                alias = alias.replace(":", "").replace(",", "")

-                lora_hashes.append(f"{alias}: {shorthash}")
+                network_hashes.append(f"{alias}: {shorthash}")

-            if lora_hashes:
-                p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+            if network_hashes:
+                p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

    def deactivate(self, p):
        pass
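The reworked activate() above reads per-network settings either positionally or by name: a text-encoder multiplier, a UNet multiplier, and an optional dynamic-rank cap. Below is a minimal standalone sketch of that parsing, assuming a tag such as <lora:my_net:0.8:0.5:dyn=16>; the tag values and the helper name parse_network_params are illustrative, not part of the PR.

# Sketch (illustrative, not webui code) of the positional/named resolution
# that the diff above performs on params.positional / params.named.
def parse_network_params(positional, named):
    name = positional[0]

    # text-encoder multiplier: 2nd positional value, overridable via te=...
    te_multiplier = float(positional[1]) if len(positional) > 1 else 1.0
    te_multiplier = float(named.get("te", te_multiplier))

    # UNet multiplier: 3rd positional value, overridable via unet=...
    unet_multiplier = float(positional[2]) if len(positional) > 2 else 1.0
    unet_multiplier = float(named.get("unet", unet_multiplier))

    # dynamic rank cap: 4th positional value, overridable via dyn=...
    dyn_dim = int(positional[3]) if len(positional) > 3 else None
    dyn_dim = int(named["dyn"]) if "dyn" in named else dyn_dim

    return name, te_multiplier, unet_multiplier, dyn_dim


print(parse_network_params(["my_net", "0.8", "0.5"], {"dyn": "16"}))
# ('my_net', 0.8, 0.5, 16)

Named arguments win over positional ones, so a tag like <lora:my_net:te=0.2> overrides only the text-encoder multiplier and leaves the UNet multiplier at its default of 1.0.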
+7 −535

File changed; diff collapsed (preview size limit exceeded).

+21 −0
import torch


def make_weight_cp(t, wa, wb):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)


def rebuild_conventional(up, down, shape, dyn_dim=None):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]
    return (up @ down).reshape(shape)


def rebuild_cp_decomposition(up, down, mid):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
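A rough usage sketch for the three helpers above, with hypothetical shapes for a rank-4 factorization of a 3×3 convolution (8 output channels, 4 input channels); it assumes the functions are in scope, e.g. imported from the new module.

import torch

# Hypothetical shapes: 3x3 conv, 8 output channels, 4 input channels, rank 4.
out_ch, in_ch, rank, k = 8, 4, 4, 3

# Conventional LoRA-style factors: up maps rank -> out, down maps in -> rank.
up = torch.randn(out_ch, rank, 1, 1)
down = torch.randn(rank, in_ch, k, k)
shape = (out_ch, in_ch, k, k)

full = rebuild_conventional(up, down, shape)
print(full.shape)                                    # torch.Size([8, 4, 3, 3])

# dyn_dim truncates the shared rank before the matmul (dynamic-rank cap).
capped = rebuild_conventional(up, down, shape, dyn_dim=2)

# CP/Tucker-style decomposition: 1x1 up/down factors plus a mid tensor
# that carries the spatial kernel.
up_cp = torch.randn(out_ch, rank, 1, 1)
down_cp = torch.randn(rank, in_ch, 1, 1)
mid = torch.randn(rank, rank, k, k)
print(rebuild_cp_decomposition(up_cp, down_cp, mid).shape)  # [8, 4, 3, 3]

# make_weight_cp contracts a small core tensor with two factor matrices.
t = torch.randn(rank, rank, k, k)
wa = torch.randn(rank, out_ch)
wb = torch.randn(rank, in_ch)
print(make_weight_cp(t, wa, wb).shape)               # torch.Size([8, 4, 3, 3])

rebuild_conventional flattens up and down to matrices and multiplies them, so dyn_dim simply slices the shared rank dimension; the CP variants route the spatial kernel through the mid or core tensor instead of the down factor.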
+134 −0

File added; diff collapsed (preview size limit exceeded).
