Commit bbf00a96 authored by v0xie's avatar v0xie
Browse files

refactor: remove unused function

parent 329c8bac
Loading
Loading
Loading
Loading
+0 −47
Original line number Diff line number Diff line
@@ -2,7 +2,6 @@ import torch
import network
from lyco_helpers import factorization
from einops import rearrange
from modules import devices


class ModuleTypeOFT(network.ModuleType):
@@ -54,58 +53,12 @@ class NetworkModuleOFT(network.NetworkModule):
            raise ValueError("sd_module must be Linear or Conv")

        if self.is_kohya:
            #self.num_blocks = self.dim
            #self.block_size = self.out_dim // self.num_blocks
            #self.block_size = self.dim
            #self.num_blocks = self.out_dim // self.block_size
            self.constraint = self.alpha * self.out_dim
            self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
        else:
            self.constraint = None
            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)

        if is_other_linear:
            self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True)


    def create_module(self, weights, key, none_ok=False):
        weight = weights.get(key)

        if weight is None and none_ok:
            return None

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]

        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif is_conv and key == "lora_down.weight" or key == "dyn_up":
            if len(weight.shape) == 2:
                weight = weight.reshape(weight.shape[0], -1, 1, 1)

            if weight.shape[2] != 1 or weight.shape[3] != 1:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
            else:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif is_conv and key == "lora_mid.weight":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
        elif is_conv and key == "lora_up.weight" or key == "dyn_down":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

        with torch.no_grad():
            if weight.shape != module.weight.shape:
                weight = weight.reshape(module.weight.shape)
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)

        return module


    def merge_weight(self, R_weight, org_weight):
        R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
        if org_weight.dim() == 4: