Commit 2d8c894b authored by v0xie's avatar v0xie
Browse files

refactor: use forward hook instead of custom forward

parent 0550659c
Loading
Loading
Loading
Loading
+24 −9
Original line number Diff line number Diff line
@@ -36,9 +36,11 @@ class NetworkModuleOFT(network.NetworkModule):
    # how do we revert this to unload the weights?
    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward
        #self.org_module[0].forward = self.forward
        self.org_module[0].register_forward_hook(self.forward_hook)

    def get_weight(self, oft_blocks, multiplier=None):
        self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype)
        block_Q = oft_blocks - oft_blocks.transpose(1, 2)
        norm_Q = torch.norm(block_Q.flatten())
        new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
@@ -67,13 +69,9 @@ class NetworkModuleOFT(network.NetworkModule):

        return self.finalize_updown(updown, orig_weight, output_shape)
    
    def forward(self, x, y=None):
        x = self.org_forward(x)
        if self.multiplier() == 0.0:
            return x

        # calculating R here is excruciatingly slow
        #R = self.get_weight().to(x.device, dtype=x.dtype)
    def forward_hook(self, module, args, output):
        #print(f'Forward hook in {self.network_key} called')
        x = output
        R = self.R.to(x.device, dtype=x.dtype)

        if x.dim() == 4:
@@ -83,3 +81,20 @@ class NetworkModuleOFT(network.NetworkModule):
        else:
            x = torch.matmul(x, R)
        return x

    # def forward(self, x, y=None):
    #     x = self.org_forward(x)
    #     if self.multiplier() == 0.0:
    #         return x

    #     # calculating R here is excruciatingly slow
    #     #R = self.get_weight().to(x.device, dtype=x.dtype)
    #     R = self.R.to(x.device, dtype=x.dtype)

    #     if x.dim() == 4:
    #         x = x.permute(0, 2, 3, 1)
    #         x = torch.matmul(x, R)
    #         x = x.permute(0, 3, 1, 2)
    #     else:
    #         x = torch.matmul(x, R)
    #     return x