Commit 3ec7b705 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

suggestions and fixes from the PR

parent d25219b7
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)


shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(lora.available_loras)}, refresh=lora.list_available_loras),
    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
}))


+1 −5
Original line number Diff line number Diff line
@@ -644,17 +644,13 @@ class SwinIR(nn.Module):
    """

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=None, num_heads=None,
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                 **kwargs):
        super(SwinIR, self).__init__()

-        depths = depths or [6, 6, 6, 6]
-        num_heads = num_heads or [6, 6, 6, 6]
-
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
+2 −9
Original line number Diff line number Diff line
@@ -74,12 +74,9 @@ class WindowAttention(nn.Module):
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
-                 pretrained_window_size=None):
+                 pretrained_window_size=(0, 0)):

        super().__init__()

-        pretrained_window_size = pretrained_window_size or [0, 0]
-
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
@@ -701,17 +698,13 @@ class Swin2SR(nn.Module):
    """

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=None, num_heads=None,
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                 window_size=7, mlp_ratio=4., qkv_bias=True, 
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                 **kwargs):
        super(Swin2SR, self).__init__()

-        depths = depths or [6, 6, 6, 6]
-        num_heads = num_heads or [6, 6, 6, 6]
-
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
+2 −5
Original line number Diff line number Diff line
@@ -161,13 +161,10 @@ class Fuse_sft_block(nn.Module):
class CodeFormer(VQAutoEncoder):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9, 
                codebook_size=1024, latent_size=256,
-                connect_list=None,
-                fix_modules=None):
+                connect_list=('32', '64', '128', '256'),
+                fix_modules=('quantize', 'generator')):
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

-        connect_list = connect_list or ['32', '64', '128', '256']
-        fix_modules = fix_modules or ['quantize', 'generator']
-
        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
+2 −2
Original line number Diff line number Diff line
@@ -5,13 +5,13 @@ import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared

not_available = ["hardswish", "multiheadattention"]
-keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available]
+keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]


def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)

-    return gr.Dropdown.update(choices=sorted(shared.hypernetworks.keys())), f"Created: {filename}", ""
+    return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""


def train_hypernetwork(*args):
Loading