Commit 9d5948e5 authored by space-nuko's avatar space-nuko
Browse files

Correctly restore hypernetwork from hash

parent 70774282
Loading
Loading
Loading
Loading
+30 −0
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())
paste_fields = {}
bind_list = []
@@ -139,6 +140,30 @@ def run_bind():
            )


def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
    """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.

    Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
    parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.

    If the infotext has no hash, then a hypernet with the same name and the most steps will be selected instead.
    """
    hypernet_name = hypernet_name.lower()
    if hypernet_hash is not None:
        # Try to match the hash in the name
        for hypernet_key in shared.hypernetworks.keys():
            result = re_hypernet_hash.search(hypernet_key)
            if result is not None and result[1] == hypernet_hash:
                return hypernet_key
    else:
        # Fall back to a hypernet with the same name
        for hypernet_key in shared.hypernetworks.keys():
            if hypernet_key.lower().startswith(hypernet_name):
                return hypernet_key

    return None


def parse_generation_parameters(x: str):
    """parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -188,6 +213,11 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
    if "Clip skip" not in res:
        res["Clip skip"] = "1"

    if "Hypernet" in res:
        hypernet_name = res["Hypernet"]
        hypernet_hash = res.get("Hypernet hash", None)
        res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)

    return res