Commit fe1967a4 authored by v0xie's avatar v0xie
Browse files

skip multihead attn for now

parent d727ddfc
Loading
Loading
Loading
Loading
+37 −17
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ class NetworkModuleOFT(network.NetworkModule):
        super().__init__(net, weights)

        self.lin_module = None
        self.org_module: list[torch.Module] = [self.sd_module]
        # kohya-ss
        if "oft_blocks" in weights.w.keys():
            self.is_kohya = True
@@ -30,7 +31,7 @@ class NetworkModuleOFT(network.NetworkModule):
            # alpha is rank if alpha is 0 or None
            if self.alpha is None:
                pass
            self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
            self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
        else:
            raise ValueError("oft_blocks or oft_diag must be in weights dict")

@@ -46,6 +47,12 @@ class NetworkModuleOFT(network.NetworkModule):
            #    raise ValueError("Linear sd_module must have out_features or embed_dim")
        elif is_other_linear:
            self.out_dim = self.sd_module.embed_dim
            #self.org_weight = self.org_module[0].weight
#            if hasattr(self.sd_module, "in_proj_weight"):
#                self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
#            if hasattr(self.sd_module, "out_proj_weight"):
#                self.out_proj_dim = self.sd_module.out_proj_weight.shape[0]
#            self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
        elif is_conv:
            self.out_dim = self.sd_module.out_channels
        else:
@@ -58,10 +65,9 @@ class NetworkModuleOFT(network.NetworkModule):
            self.constraint = self.alpha * self.out_dim
        #elif is_linear or is_conv:
        else:
            self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
            self.constraint = None

        self.org_module: list[torch.Module] = [self.sd_module]

        # if is_other_linear:
        #     weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
@@ -110,25 +116,39 @@ class NetworkModuleOFT(network.NetworkModule):

    def calc_updown(self, orig_weight):
        multiplier = self.multiplier() * self.calc_scale()
        is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention]
        if self.is_kohya and not is_other_linear:
            R = self.get_weight(self.oft_blocks, multiplier)
            #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
            merged_weight = self.merge_weight(R, orig_weight)
        elif not self.is_kohya and not is_other_linear:
            if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
                orig_weight=orig_weight.permute(1, 0)
            R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
            merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
            #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks)
            merged_weight = torch.einsum(
                'k n m, k n ... -> k m ...',
                R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
                merged_weight 
            )
            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
            if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
                orig_weight=orig_weight.permute(1, 0)
                #merged_weight=merged_weight.permute(1, 0)
            updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
            #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
            output_shape = orig_weight.shape
        else:
            # skip for now
            updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
            output_shape = (orig_weight.shape[1], orig_weight.shape[1])

        #if self.lin_module is not None:
        #    R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
        #    weight = torch.mul(torch.mul(R, multiplier), orig_weight)
        #else:
        #    orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
        #    weight = torch.einsum(
        #        'k n m, k n ... -> k m ...',
        #        R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
        #        orig_weight
        #    )
        #    weight = rearrange(weight, 'k m ... -> (k m) ...')

        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
        #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
        output_shape = orig_weight.shape
        orig_weight = orig_weight

        return self.finalize_updown(updown, orig_weight, output_shape)