Unverified Commit 0d07cbfa authored by AngelBottomless's avatar AngelBottomless Committed by GitHub
Browse files

I blame code autocomplete

parent 0abb39f4
Loading
Loading
Loading
Loading
+27 −49
Original line number Diff line number Diff line
@@ -33,12 +33,9 @@ class HypernetworkModule(torch.nn.Module):
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    activation_dict.update(
        {cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
         inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
                 add_layer_norm=False, use_dropout=False):
    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
        super().__init__()

        assert layer_structure is not None, "layer_structure must not be None"
@@ -130,8 +127,7 @@ class Hypernetwork:
    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
                 add_layer_norm=False, use_dropout=False):
    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
        self.filename = None
        self.name = name
        self.layers = {}
@@ -146,10 +142,8 @@ class Hypernetwork:

        for size in enable_sizes or []:
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.use_dropout),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.use_dropout),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
            )

    def weights(self):
@@ -201,10 +195,8 @@ class Hypernetwork:
        for size, sd in state_dict.items():
            if type(size) == int:
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.use_dropout),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.use_dropout),
                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
                )

        self.name = state_dict.get('name', self.name)
@@ -338,18 +330,14 @@ def report_statistics(loss_info: dict):
            print(e)


def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
                       training_height, steps, create_image_every, save_hypernetwork_every, template_file,
                       preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
                       preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):

def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps,
                                            save_hypernetwork_every, create_image_every, log_directory,
                                            name="hypernetwork")
    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")

    path = shared.hypernetworks.get(hypernetwork_name, None)
    shared.loaded_hypernetwork = Hypernetwork()
@@ -388,13 +376,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
                                                                height=training_height,
                                                                repeats=shared.opts.training_image_repeats_per_epoch,
                                                                placeholder_token=hypernetwork_name,
                                                                model=shared.sd_model, device=devices.device,
                                                                template_file=template_file, include_cond=True,
                                                                batch_size=batch_size)
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)

    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -518,10 +500,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

            if image is not None:
                shared.state.current_image = image
                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
                                                                     shared.opts.samples_format, processed.infotexts[0],
                                                                     p=p, forced_filename=forced_filename,
                                                                     save_to_dirs=False)
                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                last_saved_image += f", prompt: {preview_text}"

        shared.state.job_no = hypernetwork.step
@@ -543,7 +522,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>

    return hypernetwork, filename


def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
    old_hypernetwork_name = hypernetwork.name
    old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None