Commit 42fbda83 authored by discus0434's avatar discus0434
Browse files

layer options moves into create hnet ui

parent 7f8670c4
Loading
Loading
Loading
Loading
+32 −32
Original line number Diff line number Diff line
@@ -19,37 +19,21 @@ from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler


def parse_layer_structure(dim, state_dict):
    i = 0
    res = [1]
    while (key := "linear.{}.weight".format(i)) in state_dict:
        weight = state_dict[key]
        res.append(len(weight) // dim)
        i += 1
    return res


class HypernetworkModule(torch.nn.Module):
    multiplier = 1.0
    layer_structure = None
    add_layer_norm = False

    def __init__(self, dim, state_dict=None):
    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
        super().__init__()
        if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None:
            layer_structure = (1, 2, 1)
        else:
            if self.layer_structure is not None:
                assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
                assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
                layer_structure = self.layer_structure
        if layer_structure is not None:
            assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
            assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
        else:
            layer_structure = parse_layer_structure(dim, state_dict)

        linears = []
        for i in range(len(layer_structure) - 1):
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
            if self.add_layer_norm:
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

        self.linear = torch.nn.Sequential(*linears)
@@ -77,38 +61,47 @@ class HypernetworkModule(torch.nn.Module):
        return x + self.linear(x) * self.multiplier

    def trainables(self):
        res = []
        layer_structure = []
        for layer in self.linear:
            res += [layer.weight, layer.bias]
        return res
            layer_structure += [layer.weight, layer.bias]
        return layer_structure


def apply_strength(value=None):
    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength


def apply_layer_structure(value=None):
    HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure
def parse_layer_structure(dim, state_dict):
    i = 0
    layer_structure = [1]

    while (key := "linear.{}.weight".format(i)) in state_dict:
        weight = state_dict[key]
        layer_structure.append(len(weight) // dim)
        i += 1

def apply_layer_norm(value=None):
    HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm
    return layer_structure


class Hypernetwork:
    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None):
    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.add_layer_norm = add_layer_norm

        for size in enable_sizes or []:
            self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
            )

    def weights(self):
        res = []
@@ -128,6 +121,8 @@ class Hypernetwork:

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['layer_structure'] = self.layer_structure
        state_dict['is_layer_norm'] = self.add_layer_norm
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name

@@ -142,10 +137,15 @@ class Hypernetwork:

        for size, sd in state_dict.items():
            if type(size) == int:
                self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
                    HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.layer_structure = state_dict.get('layer_structure', None)
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)

+7 −2
Original line number Diff line number Diff line
@@ -9,11 +9,16 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork


def create_hypernetwork(name, enable_sizes):
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
    assert not os.path.exists(fn), f"file {fn} already exists"

    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
        name=name,
        enable_sizes=[int(x) for x in enable_sizes],
        layer_structure=layer_structure,
        add_layer_norm=add_layer_norm,
    )
    hypernet.save(fn)

    shared.reload_hypernetworks()
+0 −2
Original line number Diff line number Diff line
@@ -260,8 +260,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
    "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
    "sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}),
    "sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."),
    "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
    "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
+6 −2
Original line number Diff line number Diff line
@@ -1198,6 +1198,8 @@ def create_ui(wrap_gradio_gpu_call):
                with gr.Tab(label="Create hypernetwork"):
                    new_hypernetwork_name = gr.Textbox(label="Name")
                    new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
                    new_hypernetwork_layer_structure = gr.Dropdown(label="Hypernetwork layer structure", choices=[(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)])
                    new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")

                    with gr.Row():
                        with gr.Column(scale=3):
@@ -1280,6 +1282,8 @@ def create_ui(wrap_gradio_gpu_call):
            inputs=[
                new_hypernetwork_name,
                new_hypernetwork_sizes,
                new_hypernetwork_layer_structure,
                new_hypernetwork_add_layer_norm,
            ],
            outputs=[
                train_hypernetwork_name,
+3 −5
Original line number Diff line number Diff line
@@ -85,8 +85,6 @@ def initialize():
    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
    shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
    shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure)
    shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm)

    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):