Commit 0abb39f4 authored by aria1th

resolve conflict - first revert

parent 1764ac3c
+52 −71
@@ -21,7 +21,6 @@ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_norm
 from collections import defaultdict, deque
 from statistics import stdev, mean
 
-optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
 
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
@@ -34,9 +33,12 @@ class HypernetworkModule(torch.nn.Module):
         "tanh": torch.nn.Tanh,
         "sigmoid": torch.nn.Sigmoid,
     }
-    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+    activation_dict.update(
+        {cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
+         inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -128,7 +130,8 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
+                 add_layer_norm=False, use_dropout=False):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -140,13 +143,13 @@ class Hypernetwork:
         self.weight_init = weight_init
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
-        self.optimizer_name = None
-        self.optimizer_state_dict = None
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout),
             )
 
     def weights(self):
@@ -161,7 +164,6 @@ class Hypernetwork:
 
     def save(self, filename):
         state_dict = {}
-        optimizer_saved_dict = {}
 
         for k, v in self.layers.items():
             state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -175,14 +177,8 @@ class Hypernetwork:
         state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
-        if self.optimizer_name is not None:
-            optimizer_saved_dict['optimizer_name'] = self.optimizer_name
 
         torch.save(state_dict, filename)
-        if self.optimizer_state_dict:
-            optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
-            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
-            torch.save(optimizer_saved_dict, filename + '.optim')
 
     def load(self, filename):
         self.filename = filename
@@ -202,23 +198,13 @@ class Hypernetwork:
         self.use_dropout = state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}")
 
-        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
-        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
-        print(f"Optimizer name is {self.optimizer_name}")
-        if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
-            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
-        else:
-            self.optimizer_state_dict = None
-        if self.optimizer_state_dict:
-            print("Loaded existing optimizer from checkpoint")
-        else:
-            print("No saved optimizer exists in checkpoint")
-
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)
@@ -233,7 +219,7 @@ def list_hypernetworks(path):
         name = os.path.splitext(os.path.basename(filename))[0]
         # Prevent a hypothetical "None.pt" from being listed.
         if name != "None":
-            res[name + f"({sd_models.model_hash(filename)})"] = filename
+            res[name] = filename
     return res
 
 
@@ -352,14 +338,18 @@ def report_statistics(loss_info:dict):
             print(e)
 
 
-
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
+                       training_height, steps, create_image_every, save_hypernetwork_every, template_file,
+                       preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
+                       preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
 
     save_hypernetwork_every = save_hypernetwork_every or 0
     create_image_every = create_image_every or 0
-    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps,
+                                            save_hypernetwork_every, create_image_every, log_directory,
+                                            name="hypernetwork")
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()
@@ -379,7 +369,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     else:
         hypernetwork_dir = None
 
-    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
@@ -399,7 +388,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
+                                                                height=training_height,
+                                                                repeats=shared.opts.training_image_repeats_per_epoch,
+                                                                placeholder_token=hypernetwork_name,
+                                                                model=shared.sd_model, device=devices.device,
+                                                                template_file=template_file, include_cond=True,
+                                                                batch_size=batch_size)
 
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -415,19 +410,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
-    # Here we use optimizer from saved HN, or we can specify as UI option.
-    if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
-        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
-    else:
-        print(f"Optimizer type {optimizer_name} is not defined!")
-        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
-        optimizer_name = 'AdamW'
-    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
-        try:
-            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
-        except RuntimeError as e:
-            print("Cannot resume from saved optimizer!")
-            print(e)
+    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
+    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
 
     steps_without_grad = 0
 
@@ -489,11 +473,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             # Before saving, change name to match current checkpoint.
             hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
             last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
-            hypernetwork.optimizer_name = optimizer_name
-            if shared.opts.save_optimizer_state:
-                hypernetwork.optimizer_state_dict = optimizer.state_dict()
             save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
-            hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
+
         textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
             "loss": f"{previous_mean_loss:.7f}",
             "learn_rate": scheduler.learn_rate
@@ -537,7 +518,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
             if image is not None:
                 shared.state.current_image = image
-                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
+                                                                     shared.opts.samples_format, processed.infotexts[0],
+                                                                     p=p, forced_filename=forced_filename,
+                                                                     save_to_dirs=False)
                 last_saved_image += f", prompt: {preview_text}"
 
         shared.state.job_no = hypernetwork.step
@@ -551,15 +535,12 @@ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
 
     report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-    hypernetwork.optimizer_name = optimizer_name
-    if shared.opts.save_optimizer_state:
-        hypernetwork.optimizer_state_dict = optimizer.state_dict()
     save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
+    del optimizer
 
     return hypernetwork, filename