[package] implement `get_resource_reader` API (#51674)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51674
See
https://docs.python.org/3/library/importlib.html#importlib.abc.ResourceReader
Test Plan: Imported from OSS
Reviewed By: zdevito
Differential Revision: D26237034
Pulled By: suo
fbshipit-source-id: 4c19f6172d16b710737528d3de48372873b9368d
diff --git a/test/test_package.py b/test/test_package.py
index 685ea45..9a191fa 100644
--- a/test/test_package.py
+++ b/test/test_package.py
@@ -1,6 +1,7 @@
from torch.package.importer import ObjMismatchError
from unittest import skipIf
import inspect
+from textwrap import dedent
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
from tempfile import NamedTemporaryFile
from torch.package import (
@@ -722,6 +723,81 @@
self.assertTrue(packaged_dependency is not package_a.subpackage)
+class TestPackageResources(TestCase):
+ def test_resource_reader(self):
+ """Test compliance with the get_resource_reader importlib API."""
+ buffer = BytesIO()
+ with PackageExporter(buffer, verbose=False) as pe:
+ # Layout looks like:
+ # package
+ # ├── one/
+ # │ ├── a.txt
+ # │ ├── b.txt
+ # │ ├── c.txt
+ # │ └── three/
+ # │ ├── d.txt
+ # │ └── e.txt
+ # └── two/
+ # ├── f.txt
+ # └── g.txt
+ pe.save_text('one', 'a.txt', 'hello, a!')
+ pe.save_text('one', 'b.txt', 'hello, b!')
+ pe.save_text('one', 'c.txt', 'hello, c!')
+
+ pe.save_text('one.three', 'd.txt', 'hello, d!')
+ pe.save_text('one.three', 'e.txt', 'hello, e!')
+
+ pe.save_text('two', 'f.txt', 'hello, f!')
+ pe.save_text('two', 'g.txt', 'hello, g!')
+
+ buffer.seek(0)
+ importer = PackageImporter(buffer)
+
+ reader_one = importer.get_resource_reader('one')
+ with self.assertRaises(FileNotFoundError):
+ reader_one.resource_path('a.txt')
+
+ self.assertTrue(reader_one.is_resource('a.txt'))
+ self.assertEqual(reader_one.open_resource('a.txt').getbuffer(), b'hello, a!')
+ self.assertFalse(reader_one.is_resource('three'))
+ reader_one_contents = list(reader_one.contents())
+ self.assertSequenceEqual(reader_one_contents, ['a.txt', 'b.txt', 'c.txt', 'three'])
+
+ reader_two = importer.get_resource_reader('two')
+ self.assertTrue(reader_two.is_resource('f.txt'))
+ self.assertEqual(reader_two.open_resource('f.txt').getbuffer(), b'hello, f!')
+ reader_two_contents = list(reader_two.contents())
+ self.assertSequenceEqual(reader_two_contents, ['f.txt', 'g.txt'])
+
+ reader_one_three = importer.get_resource_reader('one.three')
+ self.assertTrue(reader_one_three.is_resource('d.txt'))
+ self.assertEqual(reader_one_three.open_resource('d.txt').getbuffer(), b'hello, d!')
+ reader_one_three_contenst = list(reader_one_three.contents())
+ self.assertSequenceEqual(reader_one_three_contenst, ['d.txt', 'e.txt'])
+
+ self.assertIsNone(importer.get_resource_reader('nonexistent_package'))
+
+ def test_package_resource_access(self):
+ """Packaged modules should be able to use the importlib.resources API to access
+ resources saved in the package.
+ """
+ mod_src = """\
+ import importlib.resources
+ import my_cool_resources
+
+ def secret_message():
+ return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
+ """
+ buffer = BytesIO()
+ with PackageExporter(buffer, verbose=False) as pe:
+ pe.save_source_string("foo.bar", dedent(mod_src))
+ pe.save_text('my_cool_resources', 'sekrit.txt', 'my sekrit plays')
+
+ buffer.seek(0)
+ importer = PackageImporter(buffer)
+ self.assertEqual(importer.import_module('foo.bar').secret_message(), 'my sekrit plays')
+
+
class ManglingTest(TestCase):
def test_unique_manglers(self):
"""
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index acaae68..069b72d 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -992,6 +992,11 @@
return py::bytes(reinterpret_cast<const char*>(data.get()), size);
})
.def(
+ "has_record",
+ [](PyTorchStreamReader& self, const std::string& key) {
+ return self.hasRecord(key);
+ })
+ .def(
"get_storage_from_record",
[](PyTorchStreamReader& self,
const std::string& key,
diff --git a/torch/package/importer.py b/torch/package/importer.py
index 424775a..238388b 100644
--- a/torch/package/importer.py
+++ b/torch/package/importer.py
@@ -184,14 +184,3 @@
raise last_err
else:
raise ModuleNotFoundError(module_name)
-
- def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
- last_err = None
- for importer in self._importers:
- try:
- return importer.get_name(obj, name)
- except ObjNotFoundError as err:
- last_err = err
-
- assert last_err is not None
- raise last_err
diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py
index 9c8269c..b30419b 100644
--- a/torch/package/package_importer.py
+++ b/torch/package/package_importer.py
@@ -270,6 +270,17 @@
module = self.import_module(demangle(module_name))
return self.zip_reader.get_record(demangle(module.__file__)).decode('utf-8')
+ # note: named `get_resource_reader` so that importlib.resources can find it.
+ # This is otherwise considered an internal method.
+ def get_resource_reader(self, fullname):
+ try:
+ package = self._get_package(fullname)
+ except ImportError:
+ return None
+ if package.__loader__ is not self:
+ return None
+ return _PackageResourceReader(self, fullname)
+
def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
if not parent:
return
@@ -413,12 +424,15 @@
else:
return module
- def _zipfile_path(self, package, resource):
+ def _zipfile_path(self, package, resource=None):
package = self._get_package(package)
- resource = _normalize_path(resource)
assert package.__loader__ is self
name = demangle(package.__name__)
- return f"{name.replace('.', '/')}/{resource}"
+ if resource is not None:
+ resource = _normalize_path(resource)
+ return f"{name.replace('.', '/')}/{resource}"
+ else:
+ return f"{name.replace('.', '/')}"
def _get_or_create_package(self, atoms: List[str]) -> 'Union[_PackageNode, _ExternNode]':
cur = self.root
@@ -488,3 +502,50 @@
return _package_imported_modules[object.__module__].__file__
return _orig_getfile(object)
inspect.getfile = patched_getfile
+
+
+class _PackageResourceReader:
+    """Private class used to support PackageImporter.get_resource_reader().
+
+    Conforms to the importlib.abc.ResourceReader interface. Allowed to access
+    the innards of PackageImporter.
+    """
+    def __init__(self, importer, fullname):
+        self.importer = importer
+        self.fullname = fullname
+
+    def open_resource(self, resource):
+        """Return an in-memory binary stream over the resource's bytes."""
+        from io import BytesIO
+        return BytesIO(self.importer.load_binary(self.fullname, resource))
+
+    def resource_path(self, resource):
+        # The contract for resource_path is that it either returns a concrete
+        # file system path or raises FileNotFoundError.
+        raise FileNotFoundError
+
+    def is_resource(self, name):
+        """Return True if the archive has a record for `name` under this package."""
+        path = self.importer._zipfile_path(self.fullname, name)
+        return self.importer.zip_reader.has_record(path)
+
+    def contents(self):
+        """Yield the name of every file and subdirectory directly in the package."""
+        from pathlib import Path
+
+        fullname_path = Path(self.importer._zipfile_path(self.fullname))
+        subdirs_seen = set()
+        for filename in self.importer.zip_reader.get_all_records():
+            try:
+                relative = Path(filename).relative_to(fullname_path)
+            except ValueError:
+                continue  # not under this package's directory
+            parts = relative.parts
+            if len(parts) == 1:
+                yield parts[0]  # an entry directly inside the package
+            elif parts and parts[0] not in subdirs_seen:
+                # Report each immediate subdirectory once, by its first path
+                # component, however deeply the record is nested below it.
+                subdirs_seen.add(parts[0])
+                yield parts[0]