Commit 8eef9d8e authored by AUTOMATIC's avatar AUTOMATIC
Browse files

a way to add an exception to unpickler without explicitly calling load_with_extra

parent c5bdba20
Loading
Loading
Loading
Loading
+38 −1
Original line number Original line Diff line number Diff line
@@ -103,7 +103,7 @@ def check_pt(filename, extra_handler):




def load(filename, *args, **kwargs):
def load(filename, *args, **kwargs):
    return load_with_extra(filename, *args, **kwargs)
    return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)




def load_with_extra(filename, extra_handler=None, *args, **kwargs):
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
@@ -151,5 +151,42 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    return unsafe_torch_load(filename, *args, **kwargs)
    return unsafe_torch_load(filename, *args, **kwargs)




class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

```
import torch
from modules import safe

def handler(module, name):
    if module == 'torch' and name in ['float64', 'float16']:
        return getattr(torch, name)

    return None

with safe.Extra(handler):
    x = torch.load('model.pt')
```
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None


unsafe_torch_load = torch.load
unsafe_torch_load = torch.load
torch.load = load
torch.load = load
global_extra_handler = None