Commit 5eaddb67 authored by w-e-w's avatar w-e-w
Browse files

use partial function

parent f783f43b
Loading
Loading
Loading
Loading
+10 −12
Original line number Original line Diff line number Diff line
import hashlib
import hashlib
import os.path
import os.path
from functools import partial


from modules import shared
from modules import shared
import modules.cache
import modules.cache
@@ -38,19 +39,12 @@ def sha256_from_cache(filename, title, use_addnet_hash=False):
def sha256(filename, title, use_addnet_hash=False):
def sha256(filename, title, use_addnet_hash=False):
    if shared.opts.experimental_sqlite_cache:
    if shared.opts.experimental_sqlite_cache:
        if use_addnet_hash:
        if use_addnet_hash:
            def calculate_addnet_hash_sqlite3():
            subsection = "hashes-addnet"
                with open(filename, "rb") as file:
            calculate_hash = partial(calculate_addnet_hash, filename)
                    return addnet_hash_safetensors(file)
            return modules.cache.cached_data_for_file("hashes-addnet", title, filename, calculate_addnet_hash_sqlite3)
        else:
        else:
            def calculate_sha256_sqlite3():
            subsection = "hashes"
                hash_sha256 = hashlib.sha256()
            calculate_hash = partial(calculate_sha256, filename)
                blksize = 1024 * 1024
        return modules.cache.cached_data_for_file(subsection, title, filename, calculate_hash)
                with open(filename, "rb") as f:
                    for chunk in iter(lambda: f.read(blksize), b""):
                        hash_sha256.update(chunk)
                return hash_sha256.hexdigest()
            return modules.cache.cached_data_for_file("hashes", title, filename, calculate_sha256_sqlite3)


    hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
    hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")


@@ -95,3 +89,7 @@ def addnet_hash_safetensors(b):


    return hash_sha256.hexdigest()
    return hash_sha256.hexdigest()



def calculate_addnet_hash(filename):
    with open(filename, "rb") as f:
        return addnet_hash_safetensors(f)