# blob: 2e81cfb82c7dd87edc85be5a764f4c888da08141 [file] [log] [blame]
from io import BytesIO
from textwrap import dedent
import torch
from torch.package import (
PackageExporter,
PackageImporter,
)
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase
class TestPackageScript(PackageTestCase):
    """Tests for compatibility with TorchScript."""

    def test_package_interface(self):
        """Packaging an interface class should work correctly.

        Scripts ``UsesInterface`` in the loading environment, saves the same
        module object into a package, loads it back, scripts the loaded copy,
        and checks that both scripted modules return the same result.
        """
        import package_a.fake_interface as fake

        uses_interface = fake.UsesInterface()
        scripted = torch.jit.script(uses_interface)
        scripted.proxy_mod = torch.jit.script(fake.NewModule())

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", uses_interface)
        # Rewind only after the exporter context has exited, so the importer
        # reads the finished archive from the start of the buffer.
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.load_pickle("model", "model.pkl")

        scripted_loaded = torch.jit.script(loaded)
        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())

        # Named ``inp`` rather than ``input`` to avoid shadowing the builtin.
        inp = torch.tensor(1)
        self.assertTrue(torch.allclose(scripted(inp), scripted_loaded(inp)))

    def test_different_package_interface(self):
        """Test a case where the interface defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        # Import one version of the interface
        import package_a.fake_interface as fake

        # Simulate a package that contains a different version of the
        # interface, with the exact same name.
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    from torch import Tensor

                    @torch.jit.interface
                    class ModuleInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            pass

                    class ImplementsInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            return inp1 + 1

                    class UsesInterface(torch.nn.Module):
                        proxy_mod: ModuleInterface

                        def __init__(self):
                            super().__init__()
                            self.proxy_mod = ImplementsInterface()

                        def forward(self, input: Tensor) -> Tensor:
                            return self.proxy_mod.one(input)
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)

        # We should be able to script successfully.
        torch.jit.script(diff_fake.UsesInterface())

    def test_package_script_class(self):
        """A function using a script class should give the same result when
        its module is imported from a package as when imported directly.
        """
        import package_a.fake_script_class as fake

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_module(fake.__name__)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.import_module(fake.__name__)

        inp = torch.tensor(1)
        self.assertTrue(
            torch.allclose(fake.uses_script_class(inp), loaded.uses_script_class(inp))
        )

    def test_different_package_script_class(self):
        """Test a case where the script class defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        import package_a.fake_script_class as fake

        # Simulate a package that contains a different version of the
        # script class, with the attribute `bar` instead of `foo`.
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe2:
            pe2.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch

                    @torch.jit.script
                    class MyScriptClass:
                        def __init__(self, x):
                            self.bar = x
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)

        inp = torch.rand(2, 3)
        loaded_script_class = diff_fake.MyScriptClass(inp)
        orig_script_class = fake.MyScriptClass(inp)
        # Each class stores its constructor argument under a different
        # attribute name; reading both proves the two definitions coexist.
        self.assertTrue(torch.allclose(loaded_script_class.bar, orig_script_class.foo))
if __name__ == "__main__":
    # Run the tests in this module when the file is executed directly.
    run_tests()