Commit d9cc27cb authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf
Browse files

Fix MHA updown err and support ex-bias for no-bias layer

parent 5881dcb8
Loading
Loading
Loading
Loading
+29 −8
Original line number Diff line number Diff line
@@ -277,7 +277,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
            self.weight.copy_(weights_backup)

    if bias_backup is not None:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.out_proj.bias.copy_(bias_backup)
        else:
            self.bias.copy_(bias_backup)
    else:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.out_proj.bias = None
        else:
            self.bias = None


def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
@@ -305,7 +313,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn

    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None and getattr(self, 'bias', None) is not None:
        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = self.bias.to(devices.cpu, copy=True)
        else:
            bias_backup = None
        self.network_bias_backup = bias_backup

    if current_names != wanted_names:
@@ -322,7 +335,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                        updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                    self.weight += updown
                    if ex_bias is not None and getattr(self, 'bias', None) is not None:
                    if ex_bias is not None and hasattr(self, 'bias'):
                        if self.bias is None:
                            self.bias = torch.nn.Parameter(ex_bias)
                        else:
                            self.bias += ex_bias
                    continue

@@ -333,14 +349,19 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                with torch.no_grad():
                    updown_q = module_q.calc_updown(self.in_proj_weight)
                    updown_k = module_k.calc_updown(self.in_proj_weight)
                    updown_v = module_v.calc_updown(self.in_proj_weight)
                    updown_q, _ = module_q.calc_updown(self.in_proj_weight)
                    updown_k, _ = module_k.calc_updown(self.in_proj_weight)
                    updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                    updown_out = module_out.calc_updown(self.out_proj.weight)
                    updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

                    self.in_proj_weight += updown_qkv
                    self.out_proj.weight += updown_out
                    if ex_bias is not None:
                        if self.out_proj.bias is None:
                            self.out_proj.bias = torch.nn.Parameter(ex_bias)
                        else:
                            self.out_proj.bias += ex_bias
                    continue

            if module is None: