Unverified Commit c3777777 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #5327 from smirkingface/master

Fixed safety checker for ckpt files written with pytorch >=1.13
parents 4b3c5bc2 e4614778
Loading
Loading
Loading
Loading
+11 −7
Original line number Diff line number Diff line
@@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler):
        raise Exception(f"global '{module}/{name}' is forbidden")


allowed_zip_names = ["archive/data.pkl", "archive/version"]
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")

# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")

def check_zip_filenames(filename, names):
    for name in names:
        if name in allowed_zip_names:
            continue
        if allowed_zip_names_re.match(name):
            continue

@@ -83,7 +81,13 @@ def check_pt(filename, extra_handler):
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())
            
            with z.open('archive/data.pkl') as file:
            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()