Commit 81e94de3 authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf
Browse files

Add warning when meet emb name conflicting

Choose standalone embedding (in /embeddings folder) first
parent 2282eb8d
Loading
Loading
Loading
Loading
+33 −0
Original line number Diff line number Diff line
import sys
import copy
import logging


class ColoredFormatter(logging.Formatter):
    COLORS = {
        "DEBUG": "\033[0;36m",  # CYAN
        "INFO": "\033[0;32m",  # GREEN
        "WARNING": "\033[0;33m",  # YELLOW
        "ERROR": "\033[0;31m",  # RED
        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
        "RESET": "\033[0m",  # RESET COLOR
    }

    def format(self, record):
        colored_record = copy.copy(record)
        levelname = colored_record.levelname
        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
        return super().format(colored_record)


logger = logging.getLogger("lora")
logger.propagate = False


if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
    )
    logger.addHandler(handler)
 No newline at end of file
+48 −32
Original line number Diff line number Diff line
@@ -17,6 +17,8 @@ from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
from modules.textual_inversion.textual_inversion import Embedding

from lora_logger import logger

module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
@@ -206,7 +208,40 @@ def load_network(name, network_on_disk):

        net.modules[key] = net_module

    net.bundle_embeddings = bundle_embeddings
    embeddings = {}
    for emb_name, data in bundle_embeddings.items():
        # textual inversion embeddings
        if 'string_to_param' in data:
            param_dict = data['string_to_param']
            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
            vec = emb.detach().to(devices.device, dtype=torch.float32)
            shape = vec.shape[-1]
            vectors = vec.shape[0]
        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
            vectors = data['clip_g'].shape[0]
        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
            vec = emb.detach().to(devices.device, dtype=torch.float32)
            shape = vec.shape[-1]
            vectors = vec.shape[0]
        else:
            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")

        embedding = Embedding(vec, emb_name)
        embedding.vectors = vectors
        embedding.shape = shape
        embedding.loaded = None
        embeddings[emb_name] = embedding

    net.bundle_embeddings = embeddings

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
@@ -229,7 +264,8 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net
        for emb_name in net.bundle_embeddings:
        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded:
                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    loaded_networks.clear()
@@ -273,37 +309,17 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)

        for emb_name, data in net.bundle_embeddings.items():
            # textual inversion embeddings
            if 'string_to_param' in data:
                param_dict = data['string_to_param']
                param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
                emb = next(iter(param_dict.items()))[1]
                vec = emb.detach().to(devices.device, dtype=torch.float32)
                shape = vec.shape[-1]
                vectors = vec.shape[0]
            elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
                vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
                shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
                vectors = data['clip_g'].shape[0]
            elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

                emb = next(iter(data.values()))
                if len(emb.shape) == 1:
                    emb = emb.unsqueeze(0)
                vec = emb.detach().to(devices.device, dtype=torch.float32)
                shape = vec.shape[-1]
                vectors = vec.shape[0]
            else:
                raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")

            embedding = Embedding(vec, emb_name)
            embedding.vectors = vectors
            embedding.shape = shape
        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
                logger.warning(
                    f'Skip bundle embedding: "{emb_name}"'
                    ' as it was already loaded from embeddings folder'
                )
                continue

            embedding.loaded = False
            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
                embedding.loaded = True
                emb_db.register_embedding(embedding, shared.sd_model)
            else:
                emb_db.skipped_embeddings[name] = embedding