Commit 6e4de5b4 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add load_with_extra function for modules to load checkpoints with extended whitelist

parent 9cd1a666
Loading
Loading
Loading
Loading
+37 −3
Original line number Diff line number Diff line
@@ -23,11 +23,18 @@ def encode(*args):


class RestrictedUnpickler(pickle.Unpickler):
    extra_handler = None

    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'
        return TypedStorage()

    def find_class(self, module, name):
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
@@ -52,7 +59,7 @@ class RestrictedUnpickler(pickle.Unpickler):
            return set

        # Forbid everything else.
        raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
        raise Exception(f"global '{module}/{name}' is forbidden")


allowed_zip_names = ["archive/data.pkl", "archive/version"]
@@ -69,7 +76,7 @@ def check_zip_filenames(filename, names):
        raise Exception(f"bad file inside {filename}: {name}")


def check_pt(filename):
def check_pt(filename, extra_handler):
    try:

        # new pytorch format is a zip file
@@ -78,6 +85,7 @@ def check_pt(filename):

            with z.open('archive/data.pkl') as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:
@@ -85,16 +93,42 @@ def check_pt(filename):
        # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for i in range(5):
                unpickler.load()


def load(filename, *args, **kwargs):
    return load_with_extra(filename, *args, **kwargs)


def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this functon is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.
    """

    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename)
            check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)