Commit a5121e7a authored by AUTOMATIC's avatar AUTOMATIC
Browse files

fixes for B007

parent 550256db
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -88,7 +88,7 @@ class LDSR:

        x_t = None
        logs = None
        for n in range(n_runs):
        for _ in range(n_runs):
            if custom_shape is not None:
                x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
                x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
+1 −1
Original line number Diff line number Diff line
@@ -418,7 +418,7 @@ def infotext_pasted(infotext, params):

    added = []

    for k, v in params.items():
    for k in params:
        if not k.startswith("AddNet Model "):
            continue

+1 −1
Original line number Diff line number Diff line
@@ -132,7 +132,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
        model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
        model.load_state_dict(torch.load(filename), strict=True)
        model.eval()
        for k, v in model.named_parameters():
        for _, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)

+1 −1
Original line number Diff line number Diff line
@@ -848,7 +848,7 @@ class SwinIR(nn.Module):
        H, W = self.patches_resolution
        flops += H * W * 3 * self.embed_dim * 9
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
        for layer in self.layers:
            flops += layer.flops()
        flops += H * W * 3 * self.embed_dim * self.embed_dim
        flops += self.upsample.flops()
+1 −1
Original line number Diff line number Diff line
@@ -1001,7 +1001,7 @@ class Swin2SR(nn.Module):
        H, W = self.patches_resolution
        flops += H * W * 3 * self.embed_dim * 9
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
        for layer in self.layers:
            flops += layer.flops()
        flops += H * W * 3 * self.embed_dim * self.embed_dim
        flops += self.upsample.flops()
Loading