Commit c1928cdd authored by AUTOMATIC's avatar AUTOMATIC
Browse files

bring back short hashes to sd checkpoint selection

parent d1ea518d
Loading
Loading
Loading
Loading
+11 −4
Original line number Diff line number Diff line
@@ -41,14 +41,16 @@ class CheckpointInfo:
        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]

        self.title = name
        self.name = name
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
        self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'

        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])

    def register(self):
        checkpoints_list[self.title] = self
@@ -56,13 +58,15 @@ class CheckpointInfo:
            checkpoint_alisases[id] = self

    def calculate_shorthash(self):
        self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
        self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
        self.shorthash = self.sha256[0:10]

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256]
            self.register()

        self.title = f'{self.name} [{self.shorthash}]'

        return self.shorthash


@@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None


def load_model_weights(model, checkpoint_info: CheckpointInfo):
    title = checkpoint_info.title
    sd_model_hash = checkpoint_info.calculate_shorthash()
    if checkpoint_info.title != title:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    cache_enabled = shared.opts.sd_checkpoint_cache > 0

+12 −11
Original line number Diff line number Diff line
@@ -439,7 +439,7 @@ def apply_setting(key, value):
        opts.data_labels[key].onchange()

    opts.save(shared.config_filename)
    return value
    return getattr(opts, key)


def update_generation_info(generation_info, html_info, img_index):
@@ -597,6 +597,16 @@ def ordered_ui_categories():
        yield category


def get_value_for_setting(key):
    value = getattr(opts, key)

    info = opts.data_labels[key]
    args = info.component_args() if callable(info.component_args) else info.component_args or {}
    args = {k: v for k, v in args.items() if k not in {'precision'}}

    return gr.update(value=value, **args)


def create_ui():
    import modules.img2img
    import modules.txt2img
@@ -1600,7 +1610,7 @@ def create_ui():

        opts.save(shared.config_filename)

        return gr.update(value=value), opts.dumpjson()
        return get_value_for_setting(key), opts.dumpjson()

    with gr.Blocks(analytics_enabled=False) as settings_interface:
        with gr.Row():
@@ -1771,15 +1781,6 @@ def create_ui():

        component_keys = [k for k in opts.data_labels.keys() if k in component_dict]

        def get_value_for_setting(key):
            value = getattr(opts, key)

            info = opts.data_labels[key]
            args = info.component_args() if callable(info.component_args) else info.component_args or {}
            args = {k: v for k, v in args.items() if k not in {'precision'}}

            return gr.update(value=value, **args)

        def get_settings_values():
            return [get_value_for_setting(key) for key in component_keys]