Commit c6e9fed5 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

fix for #3086 failing to load any previous hypernet

parent c664b231
Loading
Loading
Loading
Loading
+28 −32
Original line number Diff line number Diff line
@@ -24,11 +24,10 @@ class HypernetworkModule(torch.nn.Module):

    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
        super().__init__()
        if layer_structure is not None:

        assert layer_structure is not None, "layer_structure mut not be None"
        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
        else:
            layer_structure = parse_layer_structure(dim, state_dict)

        linears = []
        for i in range(len(layer_structure) - 1):
@@ -39,10 +38,8 @@ class HypernetworkModule(torch.nn.Module):
        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            try:
            self.fix_old_state_dict(state_dict)
            self.load_state_dict(state_dict)
            except RuntimeError:
                self.try_load_previous(state_dict)
        else:
            for layer in self.linear:
                layer.weight.data.normal_(mean=0.0, std=0.01)
@@ -50,12 +47,21 @@ class HypernetworkModule(torch.nn.Module):

        self.to(devices.device)

    def try_load_previous(self, state_dict):
        states = self.state_dict()
        states['linear.0.bias'].copy_(state_dict['linear1.bias'])
        states['linear.0.weight'].copy_(state_dict['linear1.weight'])
        states['linear.1.bias'].copy_(state_dict['linear2.bias'])
        states['linear.1.weight'].copy_(state_dict['linear2.weight'])
    def fix_old_state_dict(self, state_dict):
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }

        for fr, to in changes.items():
            x = state_dict.get(fr, None)
            if x is None:
                continue

            del state_dict[fr]
            state_dict[to] = x

    def forward(self, x):
        return x + self.linear(x) * self.multiplier
@@ -71,18 +77,6 @@ def apply_strength(value=None):
    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength


def parse_layer_structure(dim, state_dict):
    i = 0
    layer_structure = [1]

    while (key := "linear.{}.weight".format(i)) in state_dict:
        weight = state_dict[key]
        layer_structure.append(len(weight) // dim)
        i += 1

    return layer_structure


class Hypernetwork:
    filename = None
    name = None
@@ -135,17 +129,18 @@ class Hypernetwork:

        state_dict = torch.load(filename, map_location='cpu')

        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
        self.add_layer_norm = state_dict.get('is_layer_norm', False)

        for size, sd in state_dict.items():
            if type(size) == int:
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
                    HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
                    HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.layer_structure = state_dict.get('layer_structure', None)
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)

@@ -244,6 +239,7 @@ def stack_conds(conds):

    return torch.stack(conds)


def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    assert hypernetwork_name, 'hypernetwork not selected'