Add docstring to torch.serialization.register_package (#104046)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104046
Approved by: https://github.com/albanD
diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst
index ffb266f..df24e3e 100644
--- a/docs/source/notes/serialization.rst
+++ b/docs/source/notes/serialization.rst
@@ -341,3 +341,14 @@
integer fill values. ScriptModules using :func:`torch.full` and serialized on PyTorch 1.6
and later cannot be loaded in earlier versions of PyTorch, however, since those
earlier versions do not understand the new behavior.
+
+.. _utility functions:
+
+Utility functions
+-----------------
+
+The following utility functions are related to serialization:
+
+.. currentmodule:: torch.serialization
+
+.. autofunction:: register_package
diff --git a/torch/serialization.py b/torch/serialization.py
index 6f32dc00..d55fa88 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -32,6 +32,7 @@
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]]
+STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
__all__ = [
'SourceChangeWarning',
@@ -90,7 +91,46 @@
return read_bytes == local_header_magic_number
-def register_package(priority, tagger, deserializer):
+def register_package(
+ priority: int,
+ tagger: Callable[[STORAGE], Optional[str]],
+ deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
+):
+ '''
+ Registers callables for tagging and deserializing storage objects with an associated priority.
+ Tagging associates a device with a storage object at save time while deserializing moves a
+ storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
+ are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
+ value that is not `None`.
+
+ To override the deserialization behavior for a device in the global registry, one can register a
+ tagger with a higher priority than the existing tagger.
+
+ This function can also be used to register a tagger and deserializer for new devices.
+
+ Args:
+ priority: Indicates the priority associated with the tagger and deserializer, where a lower
+ value indicates higher priority.
+ tagger: Callable that takes in a storage object and returns its tagged device as a string
+ or None.
+ deserializer: Callable that takes in storage object and a device string and returns a storage
+ object on the appropriate device or None.
+
+ Returns:
+ `None`
+
+ Example:
+ >>> def ipu_tag(obj):
+ >>> if obj.device.type == 'ipu':
+ >>> return 'ipu'
+ >>> def ipu_deserialize(obj, location):
+ >>> if location.startswith('ipu'):
+ >>> ipu = getattr(torch, "ipu", None)
+ >>> assert ipu is not None, "IPU device module is not loaded"
+ >>> assert torch.ipu.is_available(), "ipu is not available"
+ >>> return obj.ipu(location)
+ >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
+ '''
queue_elem = (priority, tagger, deserializer)
_package_registry.append(queue_elem)
_package_registry.sort()