[package] implement `get_resource_reader` API (#51674)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51674

See
https://docs.python.org/3/library/importlib.html#importlib.abc.ResourceReader

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D26237034

Pulled By: suo

fbshipit-source-id: 4c19f6172d16b710737528d3de48372873b9368d
diff --git a/test/test_package.py b/test/test_package.py
index 685ea45..9a191fa 100644
--- a/test/test_package.py
+++ b/test/test_package.py
@@ -1,6 +1,7 @@
 from torch.package.importer import ObjMismatchError
 from unittest import skipIf
 import inspect
+from textwrap import dedent
 from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
 from tempfile import NamedTemporaryFile
 from torch.package import (
@@ -722,6 +723,81 @@
         self.assertTrue(packaged_dependency is not package_a.subpackage)
 
 
+class TestPackageResources(TestCase):
+    def test_resource_reader(self):
+        """Test compliance with the get_resource_reader importlib API."""
+        buffer = BytesIO()
+        with PackageExporter(buffer, verbose=False) as pe:
+            # Layout looks like:
+            #    package
+            #    ├── one/
+            #    │   ├── a.txt
+            #    │   ├── b.txt
+            #    │   ├── c.txt
+            #    │   └── three/
+            #    │       ├── d.txt
+            #    │       └── e.txt
+            #    └── two/
+            #       ├── f.txt
+            #       └── g.txt
+            pe.save_text('one', 'a.txt', 'hello, a!')
+            pe.save_text('one', 'b.txt', 'hello, b!')
+            pe.save_text('one', 'c.txt', 'hello, c!')
+
+            pe.save_text('one.three', 'd.txt', 'hello, d!')
+            pe.save_text('one.three', 'e.txt', 'hello, e!')
+
+            pe.save_text('two', 'f.txt', 'hello, f!')
+            pe.save_text('two', 'g.txt', 'hello, g!')
+
+        buffer.seek(0)
+        importer = PackageImporter(buffer)
+
+        reader_one = importer.get_resource_reader('one')
+        with self.assertRaises(FileNotFoundError):
+            reader_one.resource_path('a.txt')
+
+        self.assertTrue(reader_one.is_resource('a.txt'))
+        self.assertEqual(reader_one.open_resource('a.txt').getbuffer(), b'hello, a!')
+        self.assertFalse(reader_one.is_resource('three'))
+        reader_one_contents = list(reader_one.contents())
+        self.assertSequenceEqual(reader_one_contents, ['a.txt', 'b.txt', 'c.txt', 'three'])
+
+        reader_two = importer.get_resource_reader('two')
+        self.assertTrue(reader_two.is_resource('f.txt'))
+        self.assertEqual(reader_two.open_resource('f.txt').getbuffer(), b'hello, f!')
+        reader_two_contents = list(reader_two.contents())
+        self.assertSequenceEqual(reader_two_contents, ['f.txt', 'g.txt'])
+
+        reader_one_three = importer.get_resource_reader('one.three')
+        self.assertTrue(reader_one_three.is_resource('d.txt'))
+        self.assertEqual(reader_one_three.open_resource('d.txt').getbuffer(), b'hello, d!')
+        reader_one_three_contenst = list(reader_one_three.contents())
+        self.assertSequenceEqual(reader_one_three_contenst, ['d.txt', 'e.txt'])
+
+        self.assertIsNone(importer.get_resource_reader('nonexistent_package'))
+
+    def test_package_resource_access(self):
+        """Packaged modules should be able to use the importlib.resources API to access
+        resources saved in the package.
+        """
+        mod_src = """\
+        import importlib.resources
+        import my_cool_resources
+
+        def secret_message():
+            return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
+        """
+        buffer = BytesIO()
+        with PackageExporter(buffer, verbose=False) as pe:
+            pe.save_source_string("foo.bar", dedent(mod_src))
+            pe.save_text('my_cool_resources', 'sekrit.txt', 'my sekrit plays')
+
+        buffer.seek(0)
+        importer = PackageImporter(buffer)
+        self.assertEqual(importer.import_module('foo.bar').secret_message(), 'my sekrit plays')
+
+
 class ManglingTest(TestCase):
     def test_unique_manglers(self):
         """
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index acaae68..069b72d 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -992,6 +992,11 @@
             return py::bytes(reinterpret_cast<const char*>(data.get()), size);
           })
       .def(
+          "has_record",
+          [](PyTorchStreamReader& self, const std::string& key) {
+            return self.hasRecord(key);
+          })
+      .def(
           "get_storage_from_record",
           [](PyTorchStreamReader& self,
              const std::string& key,
diff --git a/torch/package/importer.py b/torch/package/importer.py
index 424775a..238388b 100644
--- a/torch/package/importer.py
+++ b/torch/package/importer.py
@@ -184,14 +184,3 @@
             raise last_err
         else:
             raise ModuleNotFoundError(module_name)
-
-    def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
-        last_err = None
-        for importer in self._importers:
-            try:
-                return importer.get_name(obj, name)
-            except ObjNotFoundError as err:
-                last_err = err
-
-        assert last_err is not None
-        raise last_err
diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py
index 9c8269c..b30419b 100644
--- a/torch/package/package_importer.py
+++ b/torch/package/package_importer.py
@@ -270,6 +270,17 @@
         module = self.import_module(demangle(module_name))
         return self.zip_reader.get_record(demangle(module.__file__)).decode('utf-8')
 
+    # note: named `get_resource_reader` so that importlib.resources can find it.
+    # This is otherwise considered an internal method.
+    def get_resource_reader(self, fullname):
+        try:
+            package = self._get_package(fullname)
+        except ImportError:
+            return None
+        if package.__loader__ is not self:
+            return None
+        return _PackageResourceReader(self, fullname)
+
     def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
         if not parent:
             return
@@ -413,12 +424,15 @@
             else:
                 return module
 
-    def _zipfile_path(self, package, resource):
+    def _zipfile_path(self, package, resource=None):
         package = self._get_package(package)
-        resource = _normalize_path(resource)
         assert package.__loader__ is self
         name = demangle(package.__name__)
-        return f"{name.replace('.', '/')}/{resource}"
+        if resource is not None:
+            resource = _normalize_path(resource)
+            return f"{name.replace('.', '/')}/{resource}"
+        else:
+            return f"{name.replace('.', '/')}"
 
     def _get_or_create_package(self, atoms: List[str]) -> 'Union[_PackageNode, _ExternNode]':
         cur = self.root
@@ -488,3 +502,50 @@
             return _package_imported_modules[object.__module__].__file__
     return _orig_getfile(object)
 inspect.getfile = patched_getfile
+
+
+class _PackageResourceReader:
+    """Private class used to support PackageImporter.get_resource_reader().
+
+    Confirms to the importlib.abc.ResourceReader interface. Allowed to access
+    the innards of PackageImporter.
+    """
+    def __init__(self, importer, fullname):
+        self.importer = importer
+        self.fullname = fullname
+
+    def open_resource(self, resource):
+        from io import BytesIO
+        return BytesIO(self.importer.load_binary(self.fullname, resource))
+
+    def resource_path(self, resource):
+        # The contract for resource_path is that it either returns a concrete
+        # file system path or raises FileNotFoundError.
+        raise FileNotFoundError
+
+    def is_resource(self, name):
+        path = self.importer._zipfile_path(self.fullname, name)
+        return self.importer.zip_reader.has_record(path)
+
+    def contents(self):
+        from pathlib import Path
+        filename = self.fullname.replace('.', '/')
+
+        fullname_path = Path(self.importer._zipfile_path(self.fullname))
+        files = self.importer.zip_reader.get_all_records()
+        subdirs_seen = set()
+        for filename in files:
+            try:
+                relative = Path(filename).relative_to(fullname_path)
+            except ValueError:
+                continue
+            # If the path of the file (which is relative to the top of the zip
+            # namespace), relative to the package given when the resource
+            # reader was created, has a parent, then it's a name in a
+            # subdirectory and thus we skip it.
+            parent_name = relative.parent.name
+            if len(parent_name) == 0:
+                yield relative.name
+            elif parent_name not in subdirs_seen:
+                subdirs_seen.add(parent_name)
+                yield parent_name