# blob: cb6d84d35cd7dc3e652d453869044780077d6865 [file] [log] [blame]
from pathlib import Path
from typing import (
cast,
BinaryIO,
Sequence,
Union,
)
import torch
from torch.serialization import location_tag, normalize_storage_type
from torch.types import Storage
from ._zip_file_torchscript import TorchScriptPackageZipFileWriter
from .importer import sys_importer, Importer
from .package_exporter_no_torch import PackageExporter as DefaultPackageExporter
class PackageExporter(DefaultPackageExporter):
    """
    A package exporter with specialized functionality for torch. Specifically, it uses
    the optimizations of torch's storage in order to not duplicate tensors.

    Storages encountered while pickling are written to the zip archive once, under
    ``.data/<storage_id>.storage``, and are referenced from the pickle stream through
    the persistent id returned by :meth:`persistent_id`.
    """

    def __init__(
        self,
        f: Union[str, Path, BinaryIO],
        importer: Union[Importer, Sequence[Importer]] = sys_importer,
    ):
        """
        Create an exporter whose zip writer is a ``TorchScriptPackageZipFileWriter``.

        Args:
            f: Forwarded unchanged to the base exporter (a path-like or a binary
                file object, per the annotation).
            importer: A single importer or a sequence of importers, forwarded
                unchanged to the base exporter. Defaults to ``sys_importer``.
        """
        super().__init__(
            f, importer, zip_file_writer_type=TorchScriptPackageZipFileWriter
        )

    def persistent_id(self, obj):
        """
        Return the persistent id for ``obj`` if it is a torch storage.

        As a side effect, the storage's raw bytes are written to the zip file the
        first time that storage is seen; later references to the same storage reuse
        the existing record.

        Args:
            obj: The object being considered for out-of-band serialization.

        Returns:
            ``("storage", storage_type, storage_id, location, storage_numel)`` when
            ``obj`` is a storage (typed or untyped), otherwise ``None``.
        """
        assert isinstance(self.zip_file, TorchScriptPackageZipFileWriter)

        is_typed_storage = isinstance(obj, torch.storage._TypedStorage)
        # needed for 'storage' typename which is a way in which torch models are saved
        if not (torch.is_storage(obj) or is_typed_storage):
            return None

        if is_typed_storage:
            # TODO: Once we decide to break serialization FC, we can
            # remove this case
            storage = obj._storage
            storage_type = getattr(torch, obj.pickle_storage_type())
            # Element count of the typed view.
            storage_numel = obj.size()
        else:
            storage = obj
            storage_type = normalize_storage_type(type(storage))
            # Untyped storage: the reported size is its byte count.
            storage_numel = storage.nbytes()

        storage = cast(Storage, storage)
        # Computed from the storage as given, before any CPU copy is made below.
        location = location_tag(storage)

        # serialize storage if not already written
        storage_context = self.zip_file.storage_context
        # Presence must be queried before get_or_add_storage registers the storage.
        already_written = storage_context.has_storage(storage)
        storage_id = storage_context.get_or_add_storage(storage)
        if not already_written:
            if storage.device.type != "cpu":
                # write_record is handed a data pointer, so use a CPU copy.
                storage = storage.cpu()
            self.zip_file.write_record(
                f".data/{storage_id}.storage", storage.data_ptr(), storage.nbytes()
            )

        return ("storage", storage_type, storage_id, location, storage_numel)