Unverified Commit 0198eaec authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #11757 from AUTOMATIC1111/sdxl

SD XL support
parents 9d3dd64f 14cf434b
Loading
Loading
Loading
Loading
+37 −11
Original line number Diff line number Diff line
@@ -68,6 +68,14 @@ def convert_diffusers_name_to_compvis(key, is_sd2):

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key


@@ -147,6 +155,16 @@ class LoraUpDownModule:
def assign_lora_names_to_compvis_modules(sd_model):
    lora_layer_mapping = {}

    if shared.sd_model.is_sdxl:
        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                lora_name = f'{i}_{name.replace(".", "_")}'
                lora_layer_mapping[lora_name] = module
                module.lora_layer_name = lora_name
    else:
        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
            lora_name = name.replace(".", "_")
            lora_layer_mapping[lora_name] = module
@@ -173,10 +191,10 @@ def load_lora(name, lora_on_disk):
    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping

    for key_diffusers, weight in sd.items():
        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
    for key_lora, weight in sd.items():
        key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)

        key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)

        if sd_module is None:
@@ -184,8 +202,16 @@ def load_lora(name, lora_on_disk):
            if m:
                sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)

        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
            key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
            key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.lora_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_diffusers] = key
            keys_failed_to_match[key_lora] = key
            continue

        lora_module = lora.modules.get(key, None)
@@ -208,9 +234,9 @@ def load_lora(name, lora_on_disk):
        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
        else:
            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
            print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')
            continue
            raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
            raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}")

        with torch.no_grad():
            module.weight.copy_(weight)
@@ -222,7 +248,7 @@ def load_lora(name, lora_on_disk):
        elif lora_key == "lora_down.weight":
            lora_module.down = module
        else:
            raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
            raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha")

    if keys_failed_to_match:
        print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
+1 −1
Original line number Diff line number Diff line
@@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None):
    return context_k, context_v


def attention_CrossAttention_forward(self, x, context=None, mask=None):
def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q = self.to_q(x)
+4 −0
Original line number Diff line number Diff line
@@ -237,11 +237,13 @@ def prepare_environment():
    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

    stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
    stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
    k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
    codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
    blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
    stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
    codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
    blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -299,6 +301,7 @@ def prepare_environment():
    os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

    git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
    git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
    git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
    git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
    git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
@@ -323,6 +326,7 @@ def prepare_environment():
        exit(0)



def configure_for_tests():
    if "--api" not in sys.argv:
        sys.argv.append("--api")
+39 −14
Original line number Diff line number Diff line
@@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_decode(z)

    # for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's open clip, and it's in model field
    if hasattr(sd_model.cond_stage_model, 'model'):
        sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model

    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
    # send the model to GPU. Then put modules back. the modules will be in CPU.
    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
    to_remain_in_cpu = [
        (sd_model, 'first_stage_model'),
        (sd_model, 'depth_model'),
        (sd_model, 'embedder'),
        (sd_model, 'model'),
        (sd_model, 'embedder'),
    ]

    is_sdxl = hasattr(sd_model, 'conditioner')
    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')

    if is_sdxl:
        to_remain_in_cpu.append((sd_model, 'conditioner'))
    elif is_sd2:
        to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
    else:
        to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))

    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
    stored = []
    for obj, field in to_remain_in_cpu:
        module = getattr(obj, field, None)
        stored.append(module)
        setattr(obj, field, None)

    # send the model to GPU.
    sd_model.to(devices.device)
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored

    # put modules back. the modules will be in CPU.
    for (obj, field), module in zip(to_remain_in_cpu, stored):
        setattr(obj, field, module)

    # register hooks for the first three models
    if is_sdxl:
        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
    elif is_sd2:
        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
    else:
        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)

    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
@@ -73,11 +100,9 @@ def setup_for_low_vram(sd_model, use_medvram):
        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
    if sd_model.embedder:
        sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
    parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    if hasattr(sd_model.cond_stage_model, 'model'):
        sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
        del sd_model.cond_stage_model.transformer
    if hasattr(sd_model, 'cond_stage_model'):
        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    if use_medvram:
        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
+25 −0
Original line number Diff line number Diff line
@@ -5,6 +5,21 @@ from modules.paths_internal import models_path, script_path, data_path, extensio
import modules.safe  # noqa: F401


def mute_sdxl_imports():
    """Install placeholder modules for imports that SDXL requests but we never actually use.

    For each (module name, attribute name) pair below, a bare placeholder
    object carrying that single attribute (set to None) is stored in
    ``sys.modules`` under the module name.
    """

    class Dummy:
        pass

    # (name to register in sys.modules, attribute exposed on the placeholder)
    placeholders = (
        ('taming.modules.losses.lpips', 'LPIPS'),
        ('sgm.data', 'StableDataModuleFromConfig'),
    )

    for module_name, attr_name in placeholders:
        stub = Dummy()
        setattr(stub, attr_name, None)
        sys.modules[module_name] = stub


# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)

@@ -18,8 +33,11 @@ for possible_sd_path in possible_sd_paths:

assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"

mute_sdxl_imports()

path_dirs = [
    (sd_path, 'ldm', 'Stable Diffusion', []),
    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
    (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
    (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
    (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -35,6 +53,13 @@ for d, must_exist, what, options in path_dirs:
        d = os.path.abspath(d)
        if "atstart" in options:
            sys.path.insert(0, d)
        elif "sgm" in options:
            # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
            # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.

            sys.path.insert(0, d)
            import sgm  # noqa: F401
            sys.path.pop(0)
        else:
            sys.path.append(d)
        paths[what] = d
Loading