Commit 6021f7a7 authored by discus0434's avatar discus0434
Browse files

add options to custom hypernetwork layer structure

parent c1093b80
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -27,3 +27,4 @@ __pycache__
notification.mp3
/SwinIR
/textual_inversion
/hypernetwork
+67 −21
Original line number Diff line number Diff line
import csv
import datetime
import glob
import html
import os
import sys
import traceback
import tqdm
import csv

import torch

from ldm.util import default
from modules import devices, shared, processing, sd_models
import modules.textual_inversion.dataset
import torch
from torch import einsum
import tqdm
from einops import rearrange, repeat
import modules.textual_inversion.dataset
from ldm.util import default
from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum


def parse_layer_structure(dim, state_dict):
    i = 0
    res = [1]
    while (key := "linear.{}.weight".format(i)) in state_dict:
        weight = state_dict[key]
        res.append(len(weight) // dim)
        i += 1
    return res


class HypernetworkModule(torch.nn.Module):
    multiplier = 1.0
    layer_structure = None
    add_layer_norm = False

    def __init__(self, dim, state_dict=None):
        super().__init__()
        if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None:
            layer_structure = (1, 2, 1)
        else:
            if self.layer_structure is not None:
                assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
                assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
                layer_structure = self.layer_structure
            else:
                layer_structure = parse_layer_structure(dim, state_dict)

        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)
        linears = []
        for i in range(len(layer_structure) - 1):
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
            if self.add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            self.load_state_dict(state_dict, strict=True)
            try:
                self.load_state_dict(state_dict)
            except RuntimeError:
                self.try_load_previous(state_dict)
        else:

            self.linear1.weight.data.normal_(mean=0.0, std=0.01)
            self.linear1.bias.data.zero_()
            self.linear2.weight.data.normal_(mean=0.0, std=0.01)
            self.linear2.bias.data.zero_()
            for layer in self.linear:
                layer.weight.data.normal_(mean = 0.0, std = 0.01)
                layer.bias.data.zero_()

        self.to(devices.device)

    def try_load_previous(self, state_dict):
        states = self.state_dict()
        states['linear.0.bias'].copy_(state_dict['linear1.bias'])
        states['linear.0.weight'].copy_(state_dict['linear1.weight'])
        states['linear.1.bias'].copy_(state_dict['linear2.bias'])
        states['linear.1.weight'].copy_(state_dict['linear2.weight'])

    def forward(self, x):
        return x + (self.linear2(self.linear1(x))) * self.multiplier
        return x + self.linear(x) * self.multiplier

    def trainables(self):
        res = []
        for layer in self.linear:
            res += [layer.weight, layer.bias]
        return res


def apply_strength(value=None):
    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength


def apply_layer_structure(value=None):
    HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure


def apply_layer_norm(value=None):
    HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm


class Hypernetwork:
    filename = None
    name = None
@@ -68,7 +114,7 @@ class Hypernetwork:
        for k, layers in self.layers.items():
            for layer in layers:
                layer.train()
                res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
                res += layer.trainables()

        return res

@@ -226,7 +272,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)

        assert ds.length > 1, "Dataset should contain more than 1 images"
    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)
@@ -261,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

        with torch.autocast("cuda"):
            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
#            c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
            c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
            loss = shared.sd_model(x, c)[0]
            del x
@@ -283,7 +329,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
            "loss": f"{mean_loss:.7f}",
            "learn_rate": scheduler.learn_rate
            "learn_rate": f"{scheduler.learn_rate:.7f}"
        })

        if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
+3 −1
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
from modules import sd_samplers, sd_models, localization
from modules import sd_models, sd_samplers, localization
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path

@@ -258,6 +258,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
    "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
    "sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}),
    "sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."),
    "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
    "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
+4 −2
Original line number Diff line number Diff line
@@ -86,6 +86,8 @@ def initialize():
    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
    shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
    shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure)
    shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm)


def webui():