Add docstring to torch.serialization.register_package (#104046)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104046
Approved by: https://github.com/albanD
diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst
index ffb266f..df24e3e 100644
--- a/docs/source/notes/serialization.rst
+++ b/docs/source/notes/serialization.rst
@@ -341,3 +341,14 @@
 integer fill values. ScriptModules using :func:`torch.full` and serialized on PyTorch 1.6
 and later cannot be loaded in earlier versions of PyTorch, however, since those
 earlier versions do not understand the new behavior.
+
+.. _utility functions:
+
+Utility functions
+-----------------
+
+The following utility functions are related to serialization:
+
+.. currentmodule:: torch.serialization
+
+.. autofunction:: register_package
diff --git a/torch/serialization.py b/torch/serialization.py
index 6f32dc00..d55fa88 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -32,6 +32,7 @@
 
 FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
 MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]]
+STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
 
 __all__ = [
     'SourceChangeWarning',
@@ -90,7 +91,46 @@
     return read_bytes == local_header_magic_number
 
 
-def register_package(priority, tagger, deserializer):
+def register_package(
+    priority: int,
+    tagger: Callable[[STORAGE], Optional[str]],
+    deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
+):
+    '''
+    Registers callables for tagging and deserializing storage objects with an associated priority.
+    Tagging associates a device with a storage object at save time while deserializing moves a
+    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
+    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
+    value that is not `None`.
+
+    To override the deserialization behavior for a device in the global registry, one can register a
+    tagger with a higher priority than the existing tagger.
+
+    This function can also be used to register a tagger and deserializer for new devices.
+
+    Args:
+        priority: Indicates the priority associated with the tagger and deserializer, where a lower
+            value indicates higher priority.
+        tagger: Callable that takes in a storage object and returns its tagged device as a string
+            or None.
+        deserializer: Callable that takes in storage object and a device string and returns a storage
+            object on the appropriate device or None.
+
+    Returns:
+        `None`
+
+    Example:
+        >>> def ipu_tag(obj):
+        >>>     if obj.device.type == 'ipu':
+        >>>         return 'ipu'
+        >>> def ipu_deserialize(obj, location):
+        >>>     if location.startswith('ipu'):
+        >>>         ipu = getattr(torch, "ipu", None)
+        >>>         assert ipu is not None, "IPU device module is not loaded"
+        >>>         assert torch.ipu.is_available(), "ipu is not available"
+        >>>         return obj.ipu(location)
+        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
+    '''
     queue_elem = (priority, tagger, deserializer)
     _package_registry.append(queue_elem)
     _package_registry.sort()