Commit ac83627a authored by papuSpartan's avatar papuSpartan
Browse files

heavily simplify

parent 55e52c87
Loading
Loading
Loading
Loading
+0 −36
Original line number Diff line number Diff line
@@ -282,33 +282,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    # Infer additional override settings for token merging
    token_merging_ratio = res.get("Token merging ratio", None)
    token_merging_ratio_hr = res.get("Token merging ratio hr", None)

    if token_merging_ratio is not None or token_merging_ratio_hr is not None:
        res["Token merging"] = 'True'

        if token_merging_ratio is None:
            res["Token merging hr only"] = 'True'
        else:
            res["Token merging hr only"] = 'False'

        if res.get("Token merging random", None) is None:
            res["Token merging random"] = 'False'
        if res.get("Token merging merge attention", None) is None:
            res["Token merging merge attention"] = 'True'
        if res.get("Token merging merge cross attention", None) is None:
            res["Token merging merge cross attention"] = 'False'
        if res.get("Token merging merge mlp", None) is None:
            res["Token merging merge mlp"] = 'False'
        if res.get("Token merging stride x", None) is None:
            res["Token merging stride x"] = '2'
        if res.get("Token merging stride y", None) is None:
            res["Token merging stride y"] = '2'
        if res.get("Token merging maximum down sampling", None) is None:
            res["Token merging maximum down sampling"] = '1'

    restore_old_hires_fix_params(res)

    # Missing RNG means the default was set, which is GPU RNG
@@ -335,17 +308,8 @@ infotext_to_setting_name_mapping = [
    ('UniPC skip type', 'uni_pc_skip_type'),
    ('UniPC order', 'uni_pc_order'),
    ('UniPC lower order final', 'uni_pc_lower_order_final'),
    ('Token merging', 'token_merging'),
    ('Token merging ratio', 'token_merging_ratio'),
    ('Token merging hr only', 'token_merging_hr_only'),
    ('Token merging ratio hr', 'token_merging_ratio_hr'),
    ('Token merging random', 'token_merging_random'),
    ('Token merging merge attention', 'token_merging_merge_attention'),
    ('Token merging merge cross attention', 'token_merging_merge_cross_attention'),
    ('Token merging merge mlp', 'token_merging_merge_mlp'),
    ('Token merging maximum down sampling', 'token_merging_maximum_down_sampling'),
    ('Token merging stride x', 'token_merging_stride_x'),
    ('Token merging stride y', 'token_merging_stride_y'),
    ('RNG', 'randn_source'),
    ('NGMS', 's_min_uncond')
]
+15 −20
Original line number Diff line number Diff line
@@ -34,7 +34,7 @@ import tomesd
# add a logger for the processing module
logger = logging.getLogger(__name__)
# manually set output level here since there is no option to do so yet through launch options
# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s')
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s')


# some of those options should not be changed at all because they would break the model, so I removed them from options.
@@ -496,15 +496,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
        "Token merging ratio": None if not opts.token_merging or opts.token_merging_hr_only else opts.token_merging_ratio,
        "Token merging ratio hr": None if not opts.token_merging else opts.token_merging_ratio_hr,
        "Token merging random": None if opts.token_merging_random is False else opts.token_merging_random,
        "Token merging merge attention": None if opts.token_merging_merge_attention is True else opts.token_merging_merge_attention,
        "Token merging merge cross attention": None if opts.token_merging_merge_cross_attention is False else opts.token_merging_merge_cross_attention,
        "Token merging merge mlp": None if opts.token_merging_merge_mlp is False else opts.token_merging_merge_mlp,
        "Token merging stride x": None if opts.token_merging_stride_x == 2 else opts.token_merging_stride_x,
        "Token merging stride y": None if opts.token_merging_stride_y == 2 else opts.token_merging_stride_y,
        "Token merging maximum down sampling": None if opts.token_merging_maximum_down_sampling == 1 else opts.token_merging_maximum_down_sampling,
        "Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio,
        "Token merging ratio hr": None if not p.enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr,
        "Init image hash": getattr(p, 'init_img_hash', None),
        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
        "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
@@ -538,15 +531,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        if opts.token_merging and not opts.token_merging_hr_only:
        if opts.token_merging_ratio > 0:
            sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
            logger.debug('Token merging applied')
            logger.debug(f"Token merging applied to first pass. Ratio: '{opts.token_merging_ratio}'")

        res = process_images_inner(p)

    finally:
        # undo model optimizations made by tomesd
        if opts.token_merging:
        if opts.token_merging_ratio > 0:
            tomesd.remove_patch(p.sd_model)
            logger.debug('Token merging model optimizations removed')

@@ -1003,19 +996,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        devices.torch_gc()

        # apply token merging optimizations from tomesd for high-res pass
        # check if hr_only so we are not redundantly patching
        if opts.token_merging and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio):
            # case where user wants to use separate merge ratios
            if not opts.token_merging_hr_only:
                # clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive)
        if opts.token_merging_ratio_hr > 0:
            # in case the user has used separate merge ratios
            if opts.token_merging_ratio > 0:
                tomesd.remove_patch(self.sd_model)
                logger.debug('Temporarily removed token merging optimizations in preparation for next pass')
                logger.debug('Adjusting token merging ratio for high-res pass')

            sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
            logger.debug('Applied token merging for high-res pass')
            logger.debug(f"Applied token merging for high-res pass. Ratio: '{opts.token_merging_ratio_hr}'")

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

        if opts.token_merging_ratio_hr > 0 or opts.token_merging_ratio > 0:
            tomesd.remove_patch(self.sd_model)
            logger.debug('Removed token merging optimizations from model')

        self.is_hr_pass = False

        return samples
+4 −7
Original line number Diff line number Diff line
@@ -596,11 +596,8 @@ def apply_token_merging(sd_model, hr: bool):
    tomesd.apply_patch(
        sd_model,
        ratio=ratio,
        max_downsample=shared.opts.token_merging_maximum_down_sampling,
        sx=shared.opts.token_merging_stride_x,
        sy=shared.opts.token_merging_stride_y,
        use_rand=shared.opts.token_merging_random,
        merge_attn=shared.opts.token_merging_merge_attention,
        merge_crossattn=shared.opts.token_merging_merge_cross_attention,
        merge_mlp=shared.opts.token_merging_merge_mlp
        use_rand=False,  # can cause issues with some samplers
        merge_attn=True,
        merge_crossattn=False,
        merge_mlp=False
    )
+4 −38
Original line number Diff line number Diff line
@@ -459,47 +459,13 @@ options_templates.update(options_section((None, "Hidden options"), {
}))

options_templates.update(options_section(('token_merging', 'Token Merging'), {
    "token_merging": OptionInfo(
        False, "Enable redundant token merging via tomesd. This can provide significant speed and memory improvements.",
        gr.Checkbox
    ),
    "token_merging_ratio": OptionInfo(
        0.5, "Merging Ratio",
        gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
    ),
    "token_merging_hr_only": OptionInfo(
        True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
        gr.Checkbox
    ),
    "token_merging_ratio_hr": OptionInfo(
        0.5, "Merging Ratio (high-res pass) - If 'Apply only to high-res' is enabled, this will always be the ratio used.",
        0, "Merging Ratio (high-res pass)",
        gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
    ),
    # More advanced/niche settings:
    "token_merging_random": OptionInfo(
        False, "Use random perturbations - Can improve outputs for certain samplers. For others, it may cause visual artifacting.",
        gr.Checkbox
    ),
    "token_merging_merge_attention": OptionInfo(
        True, "Merge attention",
        gr.Checkbox
    ),
     "token_merging_merge_cross_attention": OptionInfo(
        False, "Merge cross attention",
        gr.Checkbox
    ),
    "token_merging_merge_mlp": OptionInfo(
        False, "Merge mlp",
        gr.Checkbox
    ),
    "token_merging_maximum_down_sampling": OptionInfo(1, "Maximum down sampling", gr.Radio, lambda: {"choices": [1, 2, 4, 8]}),
    "token_merging_stride_x": OptionInfo(
        2, "Stride - X",
        gr.Slider, {"minimum": 2, "maximum": 8, "step": 2}
    ),
    "token_merging_stride_y": OptionInfo(
        2, "Stride - Y",
        gr.Slider, {"minimum": 2, "maximum": 8, "step": 2}
    "token_merging_ratio": OptionInfo(
        0, "Merging Ratio",
        gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
    )
}))