Commit 58c3df32 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

IA3 support

parent ef5dac77
Loading
Loading
Loading
Loading
+32 −0
Original line number Diff line number Diff line
import lyco_helpers
import network
import network_lyco


class ModuleTypeIa3(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["weight"]):
            return NetworkModuleIa3(net, weights)

        return None


class NetworkModuleIa3(network_lyco.NetworkModuleLyco):
    def __init__(self,  net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.w = weights.w["weight"]
        self.on_input = weights.w["on_input"].item()

    def calc_updown(self, orig_weight):
        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)

        output_shape = [w.size(0), orig_weight.size(1)]
        if self.on_input:
            output_shape.reverse()
        else:
            w = w.reshape(-1, 1)

        updown = orig_weight * w

        return self.finalize_updown(updown, orig_weight, output_shape)
+2 −0
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ import re
import network
import network_lora
import network_hada
import network_ia3

import torch
from typing import Union
@@ -13,6 +14,7 @@ from modules import shared, devices, sd_models, errors, scripts, sd_hijack
module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
    network_ia3.ModuleTypeIa3(),
]