| # -*- coding: utf-8 -*- |
| import inspect |
| from io import BytesIO |
| from textwrap import dedent |
| |
| from torch.package import PackageExporter, PackageImporter, is_from_package |
| from torch.testing._internal.common_utils import run_tests |
| |
| try: |
| from .common import PackageTestCase |
| except ImportError: |
| # Support the case where we run this file directly. |
| from common import PackageTestCase |
| |
| |
class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    @staticmethod
    def _tree_without_root(file_structure):
        """Return the printed tree of ``file_structure`` minus its first line.

        The first (root) line is dropped because Windows/macOS/Unix represent
        the in-memory buffer differently, so it is not stable across
        platforms. The remaining lines are dedented so they can be compared
        with the ``dedent``-ed expected trees used in the tests below.
        """
        return dedent("\n".join(str(file_structure).split("\n")[1:]))

    def test_file_structure(self):
        """
        Tests package's Directory structure representation of a zip file. Ensures
        that the returned Directory prints what is expected and filters
        inputs/outputs correctly.
        """
        buffer = BytesIO()

        export_plain = dedent(
            """\
                ├── .data
                │   ├── extern_modules
                │   └── version
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """
        )
        export_include = dedent(
            """\
                ├── obj
                │   └── obj.pkl
                └── package_a
                    └── subpackage.py
            """
        )
        # The unfiltered tree above contains no "*.storage" entries, so
        # excluding that pattern must leave the tree unchanged.
        import_exclude = export_plain

        with PackageExporter(buffer, verbose=False) as pe:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            pe.intern("**")
            pe.save_module(module_a.__name__)
            pe.save_module(package_a.__name__)
            pe.save_pickle("obj", "obj.pkl", obj)
            pe.save_text("main", "main", "my string")

        buffer.seek(0)
        pi = PackageImporter(buffer)

        # No filtering: every file in the archive is listed.
        self.assertEqual(
            self._tree_without_root(pi.file_structure()),
            export_plain,
        )
        # Include filter: only paths matching the glob patterns survive.
        self.assertEqual(
            self._tree_without_root(
                pi.file_structure(include=["**/subpackage.py", "**/*.pkl"])
            ),
            export_include,
        )
        # Exclude filter that matches nothing in this archive.
        self.assertEqual(
            self._tree_without_root(pi.file_structure(exclude="**/*.storage")),
            import_exclude,
        )

    def test_file_structure_has_file(self):
        """
        Test Directory's has_file() method.
        """
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            import package_a.subpackage

            pe.intern("**")
            obj = package_a.subpackage.PackageASubpackageObject()
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)

        pi = PackageImporter(buffer)
        file_structure = pi.file_structure()
        # has_file matches the full file path, extension included.
        self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
        self.assertFalse(file_structure.has_file("package_a/subpackage"))

    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        # Modules: the regular import is not packaged; the importer's copy is.
        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        # Objects: same distinction for instances.
        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        # inspect must return the same (lines, start_lineno) for both copies.
        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)

    def test_dunder_package_present(self):
        """
        The attribute '__torch_package__' should be populated on imported modules.
        """
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        self.assertTrue(hasattr(mod, "__torch_package__"))

    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())

    def test_os_path_edge_case(self):
        """
        Both 'os' and 'os.path' should be able to be imported into a package.
        """
        import package_a.os_imports

        buffer = BytesIO()
        mod = package_a.os_imports.Module()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", mod)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        loaded_mod = pi.load_pickle("obj", "obj.pkl")
        # Invoke the unpickled object to make sure it runs from inside the
        # package (i.e. its os / os.path imports resolved).
        loaded_mod()
| |
| |
if __name__ == "__main__":
    # Entry point for running this file directly (the ImportError fallback
    # for `common` near the top of the file exists for this case).
    # run_tests comes from torch.testing._internal.common_utils.
    run_tests()