|  | # Unpickler restricted to loading only state dicts | 
|  | # Restrict constructing types to a list defined in _get_allowed_globals() | 
|  | # Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only | 
|  | # Restrict APPEND/APPENDS to `list` | 
# In the `GLOBAL` operation, do not look up classes by name; instead rely on the dictionary
# returned by `_get_allowed_globals()`, which contains:
|  | # - torch types (Storage, dtypes, Tensor, `torch.Size`), | 
|  | # - `torch._utils._rebuild` functions. | 
|  | # - `torch.nn.Parameter` | 
|  | # - `collections.OrderedDict` | 
|  |  | 
# Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
|  | # Expected to be useful for loading PyTorch model weights | 
|  | # For example: | 
|  | # data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read() | 
|  | # buf = io.BytesIO(data) | 
|  | # weights = torch.load(buf, weights_only = True) | 
|  |  | 
|  | import functools as _functools | 
|  | from collections import OrderedDict | 
|  | from pickle import ( | 
|  | APPEND, | 
|  | APPENDS, | 
|  | BINFLOAT, | 
|  | BINGET, | 
|  | BININT, | 
|  | BININT1, | 
|  | BININT2, | 
|  | BINPERSID, | 
|  | BINPUT, | 
|  | BINUNICODE, | 
|  | BUILD, | 
|  | bytes_types, | 
|  | decode_long, | 
|  | EMPTY_DICT, | 
|  | EMPTY_LIST, | 
|  | EMPTY_SET, | 
|  | EMPTY_TUPLE, | 
|  | GLOBAL, | 
|  | LONG1, | 
|  | LONG_BINGET, | 
|  | LONG_BINPUT, | 
|  | MARK, | 
|  | NEWFALSE, | 
|  | NEWOBJ, | 
|  | NEWTRUE, | 
|  | NONE, | 
|  | PROTO, | 
|  | REDUCE, | 
|  | SETITEM, | 
|  | SETITEMS, | 
|  | SHORT_BINSTRING, | 
|  | STOP, | 
|  | TUPLE, | 
|  | TUPLE1, | 
|  | TUPLE2, | 
|  | TUPLE3, | 
|  | UnpicklingError, | 
|  | ) | 
|  | from struct import unpack | 
|  | from sys import maxsize | 
|  | from typing import Any, Dict, List | 
|  |  | 
|  | import torch | 
|  |  | 
|  |  | 
|  | # Unpickling machinery | 
@_functools.lru_cache(maxsize=1)
def _get_allowed_globals():
    """Return the allow-list of globals the restricted unpickler may resolve.

    The ``GLOBAL`` opcode may only produce objects whose ``"module.name"``
    string is a key of this mapping. The ``REDUCE`` opcode may only call
    objects that are values of this mapping. The result is built once and
    cached by ``lru_cache``.

    Returns:
        Dict[str, Any]: fully qualified name -> allowed object.
    """
    rc: Dict[str, Any] = {
        "collections.OrderedDict": OrderedDict,
        "torch.nn.parameter.Parameter": torch.nn.Parameter,
        "torch.serialization._get_layout": torch.serialization._get_layout,
        "torch.Size": torch.Size,
        "torch.Tensor": torch.Tensor,
    }
    # dtype objects, keyed by str(dtype) (e.g. "torch.float32"). That string is
    # presumably the "module.name" a pickled dtype is referenced by.
    # bool/uint8/bfloat16 are included so that checkpoints storing those dtype
    # objects can be loaded too. Dtypes are inert values, so this does not widen
    # what can be executed.
    for t in [
        torch.bfloat16,
        torch.bool,
        torch.complex32,
        torch.complex64,
        torch.complex128,
        torch.float16,
        torch.float32,
        torch.float64,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
    ]:
        rc[str(t)] = t
    # Tensor classes (legacy typed tensor classes exposed by torch)
    for tt in torch._tensor_classes:
        rc[f"{tt.__module__}.{tt.__name__}"] = tt
    # Storage classes (used by persistent ids and legacy storage pickles)
    for ts in torch._storage_classes:
        rc[f"{ts.__module__}.{ts.__name__}"] = ts
    # Rebuild functions that the tensor/parameter __reduce_ex__ implementations
    # reference; they are invoked through the REDUCE opcode.
    for f in [
        torch._utils._rebuild_parameter,
        torch._utils._rebuild_tensor,
        torch._utils._rebuild_tensor_v2,
        torch._utils._rebuild_sparse_tensor,
        torch._utils._rebuild_meta_tensor_no_storage,
    ]:
        rc[f"torch._utils.{f.__name__}"] = f

    # Handles Tensor subclasses and Tensors with attributes.
    # NOTE: It calls into the rebuild functions above for regular Tensor types.
    rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
    return rc
|  |  | 
|  |  | 
class Unpickler:
    """Restricted unpickler that only reconstructs an allow-listed set of objects.

    It implements a subset of the pickle opcodes. ``GLOBAL`` only resolves names
    returned by ``_get_allowed_globals()``, and ``REDUCE`` only calls values of
    that mapping. ``NEWOBJ`` only instantiates ``torch.nn.Parameter``. ``BUILD``
    only applies to ``Tensor``, ``Parameter`` and ``OrderedDict`` instances, and
    ``APPEND``/``APPENDS`` only apply to ``list``. Any other opcode is rejected.
    """

    def __init__(self, file, *, encoding: str = "bytes"):
        """Create an unpickler that reads from ``file``.

        Args:
            file: object providing ``read(n)`` and ``readline()`` that return bytes.
            encoding: codec used to decode ``SHORT_BINSTRING`` payloads. The
                special value ``"bytes"`` keeps them as raw ``bytes``.
        """
        self.encoding = encoding
        self.readline = file.readline
        self.read = file.read
        # Memo table filled by BINPUT/LONG_BINPUT and read by BINGET/LONG_BINGET.
        self.memo: Dict[int, Any] = {}

    def load(self):
        """Read a pickled object representation from the open file.

        Return the reconstituted object hierarchy specified in the file.

        Raises:
            EOFError: if the stream ends before a ``STOP`` opcode.
            RuntimeError: if the stream uses an unsupported opcode, or tries to
                resolve, call, instantiate, build or persistent-load something
                outside the allow-list.
        """
        self.metastack = []
        self.stack: List[Any] = []
        self.append = self.stack.append
        read = self.read
        readline = self.readline
        while True:
            key = read(1)
            if not key:
                raise EOFError
            assert isinstance(key, bytes_types)
            # Risky operators
            if key[0] == GLOBAL[0]:
                module = readline()[:-1].decode("utf-8")
                name = readline()[:-1].decode("utf-8")
                full_path = f"{module}.{name}"
                allowed_globals = _get_allowed_globals()
                if full_path not in allowed_globals:
                    raise RuntimeError(f"Unsupported class {full_path}")
                self.append(allowed_globals[full_path])
            elif key[0] == NEWOBJ[0]:
                args = self.stack.pop()
                cls = self.stack.pop()
                if cls is not torch.nn.Parameter:
                    raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
                self.append(torch.nn.Parameter(*args))
            elif key[0] == REDUCE[0]:
                args = self.stack.pop()
                func = self.stack[-1]
                # Identity check: `x in d.values()` falls back to `==`, which
                # would let the stack object's own __eq__ take part in the
                # allow-list decision.
                if not any(func is allowed for allowed in _get_allowed_globals().values()):
                    raise RuntimeError(
                        f"Trying to call reduce for unrecognized function {func}"
                    )
                self.stack[-1] = func(*args)
            elif key[0] == BUILD[0]:
                state = self.stack.pop()
                inst = self.stack[-1]
                if type(inst) is torch.Tensor:
                    # Legacy unpickling
                    inst.set_(*state)
                elif type(inst) is torch.nn.Parameter:
                    inst.__setstate__(state)
                elif type(inst) is OrderedDict:
                    # Instance attributes of the dict (e.g. extra metadata)
                    inst.__dict__.update(state)
                else:
                    raise RuntimeError(
                        f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
                    )
            # Stack manipulation
            elif key[0] == APPEND[0]:
                item = self.stack.pop()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise RuntimeError(
                        f"Can only append to lists, but got {type(list_obj)}"
                    )
                list_obj.append(item)
            elif key[0] == APPENDS[0]:
                items = self.pop_mark()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise RuntimeError(
                        f"Can only extend lists, but got {type(list_obj)}"
                    )
                list_obj.extend(items)
            elif key[0] == SETITEM[0]:
                (v, k) = (self.stack.pop(), self.stack.pop())
                self.stack[-1][k] = v
            elif key[0] == SETITEMS[0]:
                items = self.pop_mark()
                # items is a flat [k0, v0, k1, v1, ...] sequence
                for i in range(0, len(items), 2):
                    self.stack[-1][items[i]] = items[i + 1]
            elif key[0] == MARK[0]:
                # Start a fresh stack; the current one is restored by pop_mark().
                self.metastack.append(self.stack)
                self.stack = []
                self.append = self.stack.append
            elif key[0] == TUPLE[0]:
                items = self.pop_mark()
                self.append(tuple(items))
            elif key[0] == TUPLE1[0]:
                self.stack[-1] = (self.stack[-1],)
            elif key[0] == TUPLE2[0]:
                self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
            elif key[0] == TUPLE3[0]:
                self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
            # Basic types construction
            elif key[0] == NONE[0]:
                self.append(None)
            elif key[0] == NEWFALSE[0]:
                self.append(False)
            elif key[0] == NEWTRUE[0]:
                self.append(True)
            elif key[0] == EMPTY_TUPLE[0]:
                self.append(())
            elif key[0] == EMPTY_LIST[0]:
                self.append([])
            elif key[0] == EMPTY_DICT[0]:
                self.append({})
            elif key[0] == EMPTY_SET[0]:
                self.append(set())
            elif key[0] == BININT[0]:
                self.append(unpack("<i", read(4))[0])
            elif key[0] == BININT1[0]:
                self.append(read(1)[0])
            elif key[0] == BININT2[0]:
                self.append(unpack("<H", read(2))[0])
            elif key[0] == BINFLOAT[0]:
                self.append(unpack(">d", read(8))[0])
            elif key[0] == BINUNICODE[0]:
                strlen = unpack("<I", read(4))[0]
                if strlen > maxsize:
                    raise RuntimeError("String is too long")
                strval = str(read(strlen), "utf-8", "surrogatepass")
                self.append(strval)
            elif key[0] == SHORT_BINSTRING[0]:
                strlen = read(1)[0]
                strdata = read(strlen)
                if self.encoding != "bytes":
                    strdata = strdata.decode(self.encoding, "strict")
                self.append(strdata)
            elif key[0] == BINPERSID[0]:
                pid = self.stack.pop()
                # Only allow persistent load of storage. The id must be a tuple
                # (whose first element must name "storage") or a plain int.
                if type(pid) is not tuple and type(pid) is not int:
                    raise RuntimeError(
                        f"persistent_load id must be tuple or int, but got {type(pid)}"
                    )
                if (
                    type(pid) is tuple
                    and len(pid) > 0
                    and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
                ):
                    raise RuntimeError(
                        f"Only persistent_load of storage is allowed, but got {pid[0]}"
                    )
                self.append(self.persistent_load(pid))
            elif key[0] in [BINGET[0], LONG_BINGET[0]]:
                idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
                self.append(self.memo[idx])
            elif key[0] in [BINPUT[0], LONG_BINPUT[0]]:
                # Both encodings are unsigned, so the index is never negative.
                i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0]
                self.memo[i] = self.stack[-1]
            elif key[0] == LONG1[0]:
                n = read(1)[0]
                data = read(n)
                self.append(decode_long(data))
            # First and last deserializer ops
            elif key[0] == PROTO[0]:
                # Read and ignore proto version
                read(1)
            elif key[0] == STOP[0]:
                rc = self.stack.pop()
                return rc
            else:
                raise RuntimeError(f"Unsupported operand {key[0]}")

    # Return a list of items pushed in the stack after last MARK instruction.
    def pop_mark(self):
        """Pop and return the items pushed since the last MARK; restore the outer stack."""
        items = self.stack
        self.stack = self.metastack.pop()
        self.append = self.stack.append
        return items

    def persistent_load(self, pid):
        """Resolve a persistent id. Subclasses must override; the base always raises."""
        raise UnpicklingError("unsupported persistent id encountered")
|  |  | 
|  |  | 
def load(file, *, encoding: str = "ASCII"):
    """Deserialize a single object from ``file`` with the restricted ``Unpickler``.

    Args:
        file: binary file-like object providing ``read`` and ``readline``.
        encoding: codec for ``SHORT_BINSTRING`` payloads (``"bytes"`` keeps raw bytes).

    Returns:
        The object reconstructed from the pickle stream.
    """
    unpickler = Unpickler(file, encoding=encoding)
    return unpickler.load()