Commit 598da5cd authored by Kohaku-Blueleaf

Use options instead of cmd_args

parent b60e1088
modules/cmd_args.py +0 −2

@@ -118,5 +118,3 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
 parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
 parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
 parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
-parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
-parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
modules/devices.py +14 −11

@@ -116,8 +116,7 @@ patch_module_list = [
     torch.nn.LayerNorm,
 ]
 
-@contextlib.contextmanager
-def manual_autocast():
+def manual_cast_forward(self, *args, **kwargs):
     org_dtype = next(self.parameters()).dtype
     self.to(dtype)
@@ -126,6 +125,10 @@ def manual_autocast():
     result = self.org_forward(*args, **kwargs)
     self.to(org_dtype)
     return result
+
+
+@contextlib.contextmanager
+def manual_autocast():
     for module_type in patch_module_list:
         org_forward = module_type.forward
         module_type.forward = manual_cast_forward
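Note: for reference, a minimal self-contained sketch of the patching pattern above (stand-alone names; dtype and patch_module_list stand in for the module-level globals in modules/devices.py). Each class in patch_module_list temporarily has its forward swapped for manual_cast_forward, which moves the weights and any tensor arguments to the target dtype for the call, then restores the stored weight dtype:

import contextlib
import torch

dtype = torch.bfloat16                 # target compute dtype; stands in for devices.dtype
patch_module_list = [torch.nn.Linear]  # trimmed-down stand-in for the real list

def manual_cast_forward(self, *args, **kwargs):
    org_dtype = next(self.parameters()).dtype
    self.to(dtype)
    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    result = self.org_forward(*args, **kwargs)
    self.to(org_dtype)                 # weights go back to their stored dtype
    return result

@contextlib.contextmanager
def manual_autocast():
    for module_type in patch_module_list:
        module_type.org_forward = module_type.forward   # keep the original forward
        module_type.forward = manual_cast_forward       # patch the class, not instances
    try:
        yield
    finally:
        for module_type in patch_module_list:
            module_type.forward = module_type.org_forward

layer = torch.nn.Linear(4, 4)          # fp32 parameters
with manual_autocast():
    out = layer(torch.ones(1, 4))      # this call runs in bfloat16
assert next(layer.parameters()).dtype == torch.float32  # restored after the call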
modules/initialize_util.py +1 −0

@@ -177,6 +177,7 @@ def configure_opts_onchange():
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
     startup_timer.record("opts onchange")
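Note: what this registration buys, as a toy sketch (a simplified stand-in for webui's Options class, not its real implementation): changing the setting invokes the registered callback, so flipping "FP8 weight" in the UI triggers a model reload; call=False only suppresses the callback at registration time.

class Options:
    def __init__(self):
        self.data, self.callbacks = {}, {}

    def onchange(self, key, func, call=True):
        self.callbacks[key] = func
        if call:
            func()                    # call=False above: don't reload during startup

    def set(self, key, value):
        self.data[key] = value
        if key in self.callbacks:
            self.callbacks[key]()     # a change to "fp8_storage" reloads the model

opts = Options()
opts.onchange("fp8_storage", lambda: print("reload_model_weights()"), call=False)
opts.set("fp8_storage", "Enable")     # prints: reload_model_weights()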


modules/sd_models.py +32 −29

@@ -339,10 +339,28 @@ class SkipWritingToConfig:
         SkipWritingToConfig.skip = self.previous
 
 
+def check_fp8(model):
+    if model is None:
+        return None
+    if devices.get_optimal_device_name() == "mps":
+        enable_fp8 = False
+    elif shared.opts.fp8_storage == "Enable":
+        enable_fp8 = True
+    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+        enable_fp8 = True
+    else:
+        enable_fp8 = False
+    return enable_fp8
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
+    if not check_fp8(model) and devices.fp8:
+        # prevent model to load state dict in fp8
+        model.half()
+
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
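Note: check_fp8 centralizes the decision that the removed command-line flags used to encode, and returns None (rather than False) when no model is loaded, so callers can tell "no model" apart from "fp8 disabled". A hypothetical restatement of its truth table, using SimpleNamespace stand-ins instead of real webui model objects:

from types import SimpleNamespace

def check_fp8_like(model, fp8_storage, device_name):
    # mirrors check_fp8 above, with the option value and device passed in explicitly
    if model is None:
        return None
    if device_name == "mps":
        return False                  # fp8 storage is never used on MPS
    if fp8_storage == "Enable":
        return True
    if getattr(model, "is_sdxl", False) and fp8_storage == "Enable for SDXL":
        return True
    return False

sd15 = SimpleNamespace(is_sdxl=False)
sdxl = SimpleNamespace(is_sdxl=True)
assert check_fp8_like(None, "Enable", "cuda") is None
assert check_fp8_like(sdxl, "Enable", "mps") is False
assert check_fp8_like(sd15, "Enable for SDXL", "cuda") is False
assert check_fp8_like(sdxl, "Enable for SDXL", "cuda") is True
assert check_fp8_like(sd15, "Enable", "cuda") is True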

@@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
         devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
-    if devices.get_optimal_device_name() == "mps":
-        enable_fp8 = False
-    elif shared.cmd_opts.opt_unet_fp8_storage:
-        enable_fp8 = True
-    elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
-        enable_fp8 = True
-    else:
-        enable_fp8 = False
-
-    if enable_fp8:
+    if check_fp8(model):
         devices.fp8 = True
-        if model.is_sdxl:
-            cond_stage = model.conditioner
-        else:
-            cond_stage = model.cond_stage_model
-
-        for module in cond_stage.modules():
-            if isinstance(module, torch.nn.Linear):
-                module.to(torch.float8_e4m3fn)
-
-        if devices.device == devices.cpu:
-            for module in model.model.diffusion_model.modules():
-                if isinstance(module, torch.nn.Conv2d):
-                    module.to(torch.float8_e4m3fn)
-                elif isinstance(module, torch.nn.Linear):
-                    module.to(torch.float8_e4m3fn)
-        else:
-            model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
+        first_stage = model.first_stage_model
+        model.first_stage_model = None
+        for module in model.modules():
+            if isinstance(module, torch.nn.Conv2d):
+                module.to(torch.float8_e4m3fn)
+            elif isinstance(module, torch.nn.Linear):
+                module.to(torch.float8_e4m3fn)
+        model.first_stage_model = first_stage
         timer.record("apply fp8")
     else:
         devices.fp8 = False
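Note: torch.float8_e4m3fn serves purely as a storage dtype here; manual_cast_forward in modules/devices.py upcasts the weights again at call time, since general fp8 compute kernels are not available. Detaching model.first_stage_model before walking model.modules() keeps the VAE's Conv/Linear layers out of the conversion; it is reattached right after. A quick stand-alone illustration of the storage saving (requires PyTorch >= 2.1):

import torch

linear = torch.nn.Linear(1024, 1024)
print(linear.weight.element_size())   # 4 bytes per weight (fp32)
linear.to(torch.float8_e4m3fn)        # the same per-module call made above
print(linear.weight.element_size())   # 1 byte per weight: 4x smaller storage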
@@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
         return None
 
 
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
     checkpoint_info = info or select_checkpoint()
 
     timer = Timer()
@@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None):
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
-        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+        if check_fp8(sd_model) != devices.fp8:
+            # load from state dict again to prevent extra numerical errors
+            forced_reload = True
+        elif sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return sd_model
 
     sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
-    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
         return sd_model
 
     if sd_model is not None:
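Note: the forced_reload path backs the in-code comment about "extra numerical errors": converting weights in memory is lossy, so toggling fp8 off must re-read the checkpoint from disk rather than upcast the fp8 copies already in memory. A small stand-alone sketch of why (requires PyTorch >= 2.1):

import torch

w = torch.randn(64, 64, dtype=torch.float16)
roundtrip = w.to(torch.float8_e4m3fn).to(torch.float16)
print((w - roundtrip).abs().max())    # nonzero: fp16 -> fp8 -> fp16 loses precision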
modules/shared_options.py +1 −0

@@ -200,6 +200,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
     "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
     "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+    "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {