Commit 5c8e53d5 authored by papuSpartan's avatar papuSpartan
Browse files

Allow different merge ratios to be used for each pass. Make toggle cmd flag...

Allow different merge ratios to be used for each pass. Make the toggle cmd flag work again. Remove the ratio flag. Remove the warning about controlnet being incompatible.
parent c707b7df
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -103,5 +103,4 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)

# token merging / tomesd
parser.add_argument("--token-merging", action='store_true', help="Provides generation speedup by merging redundant tokens. (compatible with --xformers)", default=False)
parser.add_argument("--token-merging-ratio", type=float, help="Adjusts ratio of merged to untouched tokens. Range: (0.0-1.0], Defaults to 0.5", default=0.5)
parser.add_argument("--token-merging", action='store_true', help="Provides speed and memory improvements by merging redundant tokens. This has a more pronounced effect on higher resolutions.", default=False)
+15 −29
Original line number Diff line number Diff line
@@ -501,26 +501,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        if opts.token_merging and not opts.token_merging_hr_only:
            print("applying token merging to all passes")
            tomesd.apply_patch(
                p.sd_model,
                ratio=opts.token_merging_ratio,
                max_downsample=opts.token_merging_maximum_down_sampling,
                sx=opts.token_merging_stride_x,
                sy=opts.token_merging_stride_y,
                use_rand=opts.token_merging_random,
                merge_attn=opts.token_merging_merge_attention,
                merge_crossattn=opts.token_merging_merge_cross_attention,
                merge_mlp=opts.token_merging_merge_mlp
            )
        if (opts.token_merging or cmd_opts.token_merging) and not opts.token_merging_hr_only:
            print("\nApplying token merging\n")
            sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)

        res = process_images_inner(p)

    finally:
        # undo model optimizations made by tomesd
        if opts.token_merging:
            print('removing token merging model optimizations')
        if opts.token_merging or cmd_opts.token_merging:
            print('\nRemoving token merging model optimizations\n')
            tomesd.remove_patch(p.sd_model)

        # restore opts to original state
@@ -959,20 +949,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        devices.torch_gc()

        # apply token merging optimizations from tomesd for high-res pass
        # check if hr_only so we don't redundantly apply patch
        if opts.token_merging and opts.token_merging_hr_only:
            print("applying token merging for high-res pass")
            tomesd.apply_patch(
                self.sd_model,
                ratio=opts.token_merging_ratio,
                max_downsample=opts.token_merging_maximum_down_sampling,
                sx=opts.token_merging_stride_x,
                sy=opts.token_merging_stride_y,
                use_rand=opts.token_merging_random,
                merge_attn=opts.token_merging_merge_attention,
                merge_crossattn=opts.token_merging_merge_cross_attention,
                merge_mlp=opts.token_merging_merge_mlp
            )
        # check if hr_only so we are not redundantly patching
        if (cmd_opts.token_merging or opts.token_merging) and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio):
            # case where user wants to use separate merge ratios
            if not opts.token_merging_hr_only:
                # clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive)
                print('Temporarily reverting token merging optimizations in preparation for next pass')
                tomesd.remove_patch(self.sd_model)

            print("\nApplying token merging for high-res pass\n")
            sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

+28 −1
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ from modules import paths, shared, modelloader, devices, script_callbacks, sd_va
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -546,3 +547,29 @@ def unload_model_weights(sd_model=None, info=None):
    print(f"Unloaded weights {timer.summary()}.")

    return sd_model


def apply_token_merging(sd_model, hr: bool):
    """
    Patch the given model with tomesd's token-merging speed/memory optimizations.

    All tuning knobs are read from ``shared.opts``; which merge ratio is used
    depends on the pass being run.

    Args:
        sd_model: the Stable Diffusion model to patch in place.
        hr (bool): True if called in the context of a high-res pass, in which
            case the dedicated high-res merge ratio is applied instead.
    """

    opts = shared.opts
    # High-res pass gets its own ratio so each pass can merge differently.
    ratio = opts.token_merging_ratio_hr if hr else opts.token_merging_ratio
    if hr:
        print("effective hr pass merge ratio is " + str(ratio))

    tomesd.apply_patch(
        sd_model,
        ratio=ratio,
        max_downsample=opts.token_merging_maximum_down_sampling,
        sx=opts.token_merging_stride_x,
        sy=opts.token_merging_stride_y,
        use_rand=opts.token_merging_random,
        merge_attn=opts.token_merging_merge_attention,
        merge_crossattn=opts.token_merging_merge_cross_attention,
        merge_mlp=opts.token_merging_merge_mlp,
    )
+5 −1
Original line number Diff line number Diff line
@@ -429,7 +429,7 @@ options_templates.update(options_section((None, "Hidden options"), {

options_templates.update(options_section(('token_merging', 'Token Merging'), {
    "token_merging": OptionInfo(
        False, "Enable redundant token merging via tomesd. (currently incompatible with controlnet extension)",
        0.5, "Enable redundant token merging via tomesd. This can provide significant speed and memory improvements.",
        gr.Checkbox
    ),
    "token_merging_ratio": OptionInfo(
@@ -440,6 +440,10 @@ options_templates.update(options_section(('token_merging', 'Token Merging'), {
        True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
        gr.Checkbox
    ),
    "token_merging_ratio_hr": OptionInfo(
        0.5, "Merging Ratio (high-res pass) - If 'Apply only to high-res' is enabled, this will always be the ratio used.",
        gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
    ),
    # More advanced/niche settings:
    "token_merging_random": OptionInfo(
        True, "Use random perturbations - Disabling might help with certain samplers",