Commit 70e66e81 authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf
Browse files

Merge branch 'dev' into efficient-vae-methods

parents c134a480 f0c1063a
Loading
Loading
Loading
Loading
+2 −3
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,


def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images
    from modules import images, processing

    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
+1 −7
Original line number Diff line number Diff line
@@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType

decode_first_stage = sd_samplers_common.decode_first_stage

# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -572,13 +573,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
    return samples


def decode_first_stage(model, x):
    from modules.sd_samplers_common import samples_to_images_tensor, approximation_indexes
    x = x.to(devices.dtype_vae)
    approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
    return samples_to_images_tensor(x, approx_index, model)


def get_fixed_seed(seed):
    if seed is None or seed == '' or seed == -1:
        return int(random.randrange(4294967294))
+3 −3
Original line number Diff line number Diff line
@@ -2,7 +2,6 @@ import torch
from torch.nn.functional import silu
from types import MethodType

import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
@@ -164,12 +163,13 @@ class StableDiffusionModelHijack:
    clip = None
    optimization_method = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

    def __init__(self):
        import modules.textual_inversion.textual_inversion

        self.extra_generation_params = {}
        self.comments = []

        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
+6 −0
Original line number Diff line number Diff line
@@ -54,6 +54,12 @@ def single_sample_to_image(sample, approximation=None):
    return Image.fromarray(x_sample)


def decode_first_stage(model, x):
    x = model.decode_first_stage(x.to(devices.dtype_vae))

    return x


def sample_to_image(samples, index=0, approximation=None):
    return single_sample_to_image(samples[index], approximation)

+3 −0
Original line number Diff line number Diff line
@@ -50,6 +50,7 @@ def get_filename(filepath):


def refresh_vae_list():
    global vae_dict
    vae_dict.clear()

    paths = [
@@ -83,6 +84,8 @@ def refresh_vae_list():
        name = get_filename(filepath)
        vae_dict[name] = filepath

    vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))


def find_vae_near_checkpoint(checkpoint_file):
    checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
Loading