blob: 0cd46ba98609afc3dd591c6b52352eec88b42100 [file] [log] [blame]
# Owner(s): ["oncall: package/deploy"]
from io import BytesIO
from torch.package import (
PackageExporter,
PackageImporter,
sys_importer,
)
from torch.package.package_exporter_no_torch import (
PackageExporter as PackageExporterNoTorch,
)
from torch.package.package_importer_no_torch import (
PackageImporter as PackageImporterNoTorch,
)
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase
class TestRepackage(PackageTestCase):
    """Tests for repackaging."""

    def __init__(self, *args, **kwargs):
        """Store the exporter/importer classes the tests use on the instance.

        The tests reach the classes through ``self`` instead of the module
        globals, so a subclass can rebind these attributes to run the same
        tests against a different implementation.
        """
        super().__init__(*args, **kwargs)
        self.PackageExporter = PackageExporter
        self.PackageImporter = PackageImporter

    def test_repackage_import_indirectly_via_parent_module(self):
        """Export one model, load it back, then export a second model.

        The second export is given an importer chain consisting of the
        importer created for the first package followed by ``sys_importer``.
        There are no explicit assertions: the test passes when every step
        completes without raising.
        """
        from package_d.imports_directly import ImportsDirectlyFromSubSubPackage
        from package_d.imports_indirectly import ImportsIndirectlyFromSubPackage

        # Round 1: package the first model into an in-memory archive.
        stream = BytesIO()
        direct_model = ImportsDirectlyFromSubSubPackage()
        with self.PackageExporter(stream) as exporter:
            exporter.intern("**")
            exporter.save_pickle("default", "model.py", direct_model)

        # Rewind so the importer reads the archive from its beginning.
        stream.seek(0)
        first_importer = self.PackageImporter(stream)
        # The result is not inspected; the binding is kept so the loaded
        # object stays referenced for the rest of the test.
        unpickled_model = first_importer.load_pickle("default", "model.py")

        # Round 2: package the second model, resolving through the first
        # package's importer before falling back to ``sys_importer``.
        indirect_model = ImportsIndirectlyFromSubPackage()
        import_chain = (first_importer, sys_importer)
        stream = BytesIO()
        with self.PackageExporter(stream, importer=import_chain) as exporter:
            exporter.intern("**")
            exporter.save_pickle("default", "model_b.py", indirect_model)
class TestRepackageNoTorch(TestRepackage):
    """Run the ``TestRepackage`` tests against the ``*_no_torch`` classes."""

    def __init__(self, *args, **kwargs):
        """Rebind the exporter/importer attributes to the no-torch variants.

        The parent initializer runs first and installs the default classes;
        the assignments below then replace them on this instance.
        """
        super().__init__(*args, **kwargs)
        self.PackageExporter = PackageExporterNoTorch
        self.PackageImporter = PackageImporterNoTorch
# Hand control to the shared test runner only when this file is executed as a
# script; importing the module must not start a test run.
if __name__ == "__main__":
    run_tests()