Commit d6d0b22e authored by v0xie's avatar v0xie
Browse files

fix: ignore calc_scale() for COFT which has very small alpha

parent 7edd50f3
Loading
Loading
Loading
Loading
+5 −11
Original line number Diff line number Diff line
@@ -99,12 +99,9 @@ class NetworkModuleOFT(network.NetworkModule):
        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]

        if not is_other_linear:
            #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
            #    orig_weight=orig_weight.permute(1, 0)

            oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)

            # without this line the results are significantly worse / less accurate
            # ensure skew-symmetric matrix
            oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)

            R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -118,9 +115,6 @@ class NetworkModuleOFT(network.NetworkModule):
            )
            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')

            #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
            #    orig_weight=orig_weight.permute(1, 0)

            updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
            output_shape = orig_weight.shape
        else:
@@ -132,10 +126,10 @@ class NetworkModuleOFT(network.NetworkModule):
        return self.finalize_updown(updown, orig_weight, output_shape)

    def calc_updown(self, orig_weight):
        multiplier = self.multiplier() * self.calc_scale()
        #if self.is_kohya:
        #    return self.calc_updown_kohya(orig_weight, multiplier)
        #else:
        # if alpha is a very small number as in coft, calc_scale will return a almost zero number so we ignore it
        #multiplier = self.multiplier() * self.calc_scale()
        multiplier = self.multiplier()

        return self.calc_updown_kb(orig_weight, multiplier)

    # override to remove the multiplier/scale factor; it's already multiplied in get_weight