Unverified Commit f510a227 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #3086 from discus0434/master

Add settings for multi-layer structure hypernetworks
parents f894dd55 42fbda83
Loading
Loading
Loading
Loading
+63 −16
Original line number Diff line number Diff line
@@ -22,45 +22,86 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
    multiplier = 1.0

    def __init__(self, dim, state_dict=None):
    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
        super().__init__()
        if layer_structure is not None:
            assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
            assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
        else:
            layer_structure = parse_layer_structure(dim, state_dict)

        linears = []
        for i in range(len(layer_structure) - 1):
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)
        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            self.load_state_dict(state_dict, strict=True)
            try:
                self.load_state_dict(state_dict)
            except RuntimeError:
                self.try_load_previous(state_dict)
        else:

            self.linear1.weight.data.normal_(mean=0.0, std=0.01)
            self.linear1.bias.data.zero_()
            self.linear2.weight.data.normal_(mean=0.0, std=0.01)
            self.linear2.bias.data.zero_()
            for layer in self.linear:
                layer.weight.data.normal_(mean = 0.0, std = 0.01)
                layer.bias.data.zero_()

        self.to(devices.device)

    def try_load_previous(self, state_dict):
        states = self.state_dict()
        states['linear.0.bias'].copy_(state_dict['linear1.bias'])
        states['linear.0.weight'].copy_(state_dict['linear1.weight'])
        states['linear.1.bias'].copy_(state_dict['linear2.bias'])
        states['linear.1.weight'].copy_(state_dict['linear2.weight'])

    def forward(self, x):
        return x + (self.linear2(self.linear1(x))) * self.multiplier
        return x + self.linear(x) * self.multiplier

    def trainables(self):
        layer_structure = []
        for layer in self.linear:
            layer_structure += [layer.weight, layer.bias]
        return layer_structure


def apply_strength(value=None):
    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength


def parse_layer_structure(dim, state_dict):
    i = 0
    layer_structure = [1]

    while (key := "linear.{}.weight".format(i)) in state_dict:
        weight = state_dict[key]
        layer_structure.append(len(weight) // dim)
        i += 1

    return layer_structure


class Hypernetwork:
    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None):
    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.add_layer_norm = add_layer_norm

        for size in enable_sizes or []:
            self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
            )

    def weights(self):
        res = []
@@ -68,7 +109,7 @@ class Hypernetwork:
        for k, layers in self.layers.items():
            for layer in layers:
                layer.train()
                res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
                res += layer.trainables()

        return res

@@ -80,6 +121,8 @@ class Hypernetwork:

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['layer_structure'] = self.layer_structure
        state_dict['is_layer_norm'] = self.add_layer_norm
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name

@@ -94,10 +137,15 @@ class Hypernetwork:

        for size, sd in state_dict.items():
            if type(size) == int:
                self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
                    HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.layer_structure = state_dict.get('layer_structure', None)
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)

@@ -226,7 +274,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)

    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)
+7 −2
Original line number Diff line number Diff line
@@ -9,11 +9,16 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork


def create_hypernetwork(name, enable_sizes):
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
    assert not os.path.exists(fn), f"file {fn} already exists"

    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
        name=name,
        enable_sizes=[int(x) for x in enable_sizes],
        layer_structure=layer_structure,
        add_layer_norm=add_layer_norm,
    )
    hypernet.save(fn)

    shared.reload_hypernetworks()
+6 −2
Original line number Diff line number Diff line
@@ -1217,6 +1217,8 @@ def create_ui(wrap_gradio_gpu_call):
                with gr.Tab(label="Create hypernetwork"):
                    new_hypernetwork_name = gr.Textbox(label="Name")
                    new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
                    new_hypernetwork_layer_structure = gr.Dropdown(label="Hypernetwork layer structure", choices=[(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)])
                    new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")

                    with gr.Row():
                        with gr.Column(scale=3):
@@ -1299,6 +1301,8 @@ def create_ui(wrap_gradio_gpu_call):
            inputs=[
                new_hypernetwork_name,
                new_hypernetwork_sizes,
                new_hypernetwork_layer_structure,
                new_hypernetwork_add_layer_norm,
            ],
            outputs=[
                train_hypernetwork_name,
+1 −1

File changed.

Contains only whitespace changes.

+2 −2

File changed.

Contains only whitespace changes.