Commit a8cbe50c authored by AUTOMATIC1111

remove duplicated code

parent 891ccb76
extensions-builtin/Lora/networks.py  +2 −29
@@ -15,7 +15,7 @@ import torch
 from typing import Union
 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
-from modules.textual_inversion.textual_inversion import Embedding
+import modules.textual_inversion.textual_inversion as textual_inversion
 
 from lora_logger import logger
 
@@ -210,34 +210,7 @@ def load_network(name, network_on_disk):
 
     embeddings = {}
     for emb_name, data in bundle_embeddings.items():
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, emb_name)
-        embedding.vectors = vectors
-        embedding.shape = shape
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
         embedding.loaded = None
         embeddings[emb_name] = embedding
 
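The parsing logic that both files previously duplicated recognizes three on-disk embedding formats. As a rough illustration (the key names follow the branches of create_embedding_from_data below; the tensor shapes and the '*' / '<concept>' terms are made up for the example):

import torch

# Classic textual inversion checkpoint: a 'string_to_param' dict with exactly
# one term mapping to a (num_vectors, dim) tensor.
ti_data = {'string_to_param': {'*': torch.zeros(2, 768)}}

# SDXL embedding: one tensor per text encoder; .shape becomes the sum of the
# last dimensions and .vectors comes from clip_g.
sdxl_data = {'clip_g': torch.zeros(2, 1280), 'clip_l': torch.zeros(2, 768)}

# Diffusers concept: a single token name mapping directly to a tensor; a 1-D
# tensor is unsqueezed to (1, dim).
concept_data = {'<concept>': torch.zeros(768)}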
modules/textual_inversion/textual_inversion.py  +40 −34
@@ -181,40 +181,7 @@ class EmbeddingDatabase:
         else:
             return
 
-
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, name)
-        embedding.step = data.get('step', None)
-        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
-        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vectors
-        embedding.shape = shape
-        embedding.filename = path
-        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
+        embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
 
         if self.expected_shape == -1 or self.expected_shape == embedding.shape:
             self.register_embedding(embedding, shared.sd_model)
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     return fn
 
 
+def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
+    if 'string_to_param' in data:  # textual inversion embeddings
+        param_dict = data['string_to_param']
+        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
+        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+        emb = next(iter(param_dict.items()))[1]
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+        vectors = data['clip_g'].shape[0]
+    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
+        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+        emb = next(iter(data.values()))
+        if len(emb.shape) == 1:
+            emb = emb.unsqueeze(0)
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    else:
+        raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+    embedding = Embedding(vec, name)
+    embedding.step = data.get('step', None)
+    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+    embedding.vectors = vectors
+    embedding.shape = shape
+
+    if filepath:
+        embedding.filename = filepath
+        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
+
+    return embedding
+
+
 def write_loss(log_directory, filename, step, epoch_len, values):
     if shared.opts.training_write_csv_every == 0:
         return
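For reference, a minimal sketch of calling the consolidated helper outside the two call sites above, assuming the webui modules are importable and devices.device has been initialised (the payload and the embedding name are hypothetical):

import torch

import modules.textual_inversion.textual_inversion as textual_inversion

# Hypothetical single-term payload in the classic 'string_to_param' format.
data = {'string_to_param': {'*': torch.zeros(2, 768)}}

# filename is only used in the error message; passing filepath would also set
# embedding.filename and compute its sha256 hash, so both are omitted here.
embedding = textual_inversion.create_embedding_from_data(data, 'my-embedding')

print(embedding.vectors)  # 2   - number of token vectors
print(embedding.shape)    # 768 - width of each vector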