Unverified commit c121f8c3, authored by AUTOMATIC1111 and committed by GitHub
Browse files

Merge pull request #14031 from AUTOMATIC1111/test-fp8

A big improvement for dtype casting system with fp8 storage type and manual cast
parents 60186c7b 8edb9144
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -137,7 +137,7 @@ class NetworkModule:
    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
+2 −2
Original line number Diff line number Diff line
@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):

    def calc_updown(self, orig_weight):
        output_shape = self.weight.shape
        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
        updown = self.weight.to(orig_weight.device)
        if self.ex_bias is not None:
            ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
            ex_bias = self.ex_bias.to(orig_weight.device)
        else:
            ex_bias = None

+5 −5
Original line number Diff line number Diff line
@@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
        self.w2b = weights.w["b2.weight"]

    def calc_updown(self, orig_weight):
        """Compute the GLoRA weight delta for ``orig_weight``.

        The four factor tensors are moved to ``orig_weight``'s device but keep
        their own dtype.  ``orig_weight`` is cast to that dtype for the
        product, so the dtype the layer weight is stored in never takes part
        in the arithmetic.

        Args:
            orig_weight: weight tensor of the layer being patched; only its
                device and values are used here.

        Returns:
            Whatever ``self.finalize_updown`` returns for the computed delta
            and ``output_shape``.
        """
        device = orig_weight.device
        w1a = self.w1a.to(device)
        w1b = self.w1b.to(device)
        w2a = self.w2a.to(device)
        w2b = self.w2b.to(device)

        output_shape = [w1a.size(0), w1b.size(1)]

        # matmul does not promote dtypes: multiplying the stored weight
        # directly against the factors raises when the two dtypes differ,
        # so bring the weight to the factors' dtype first.
        base = orig_weight.to(dtype=w1a.dtype)
        updown = (w2b @ w1b) + ((base @ w2a) @ w1a)

        return self.finalize_updown(updown, orig_weight, output_shape)
+6 −6
Original line number Diff line number Diff line
@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
        self.t2 = weights.w.get("hada_t2")

    def calc_updown(self, orig_weight):
        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
        w1a = self.w1a.to(orig_weight.device)
        w1b = self.w1b.to(orig_weight.device)
        w2a = self.w2a.to(orig_weight.device)
        w2b = self.w2b.to(orig_weight.device)

        output_shape = [w1a.size(0), w1b.size(1)]

        if self.t1 is not None:
            output_shape = [w1a.size(1), w1b.size(1)]
            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
            t1 = self.t1.to(orig_weight.device)
            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]
        else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)

        if self.t2 is not None:
            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
            t2 = self.t2.to(orig_weight.device)
            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
        else:
            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
+1 −1
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
        self.on_input = weights.w["on_input"].item()

    def calc_updown(self, orig_weight):
        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
        w = self.w.to(orig_weight.device)

        output_shape = [w.size(0), orig_weight.size(1)]
        if self.on_input:
Loading