Add support for requirement extras (#28)
diff --git a/extract_wheels/__init__.py b/extract_wheels/__init__.py
index ac3d72c..390dcd0 100644
--- a/extract_wheels/__init__.py
+++ b/extract_wheels/__init__.py
@@ -11,7 +11,7 @@
import subprocess
import sys
-from extract_wheels.lib import bazel
+from extract_wheels.lib import bazel, requirements
def configure_reproducible_wheels() -> None:
@@ -70,8 +70,10 @@
[sys.executable, "-m", "pip", "wheel", "-r", args.requirements]
)
+ extras = requirements.parse_extras(args.requirements)
+
targets = [
- '"%s%s"' % (args.repo, bazel.extract_wheel(whl, []))
+ '"%s%s"' % (args.repo, bazel.extract_wheel(whl, extras))
for whl in glob.glob("*.whl")
]
diff --git a/extract_wheels/lib/BUILD b/extract_wheels/lib/BUILD
index 1ff0e45..ef6e0ab 100644
--- a/extract_wheels/lib/BUILD
+++ b/extract_wheels/lib/BUILD
@@ -7,6 +7,7 @@
"bazel.py",
"namespace_pkgs.py",
"purelib.py",
+ "requirements.py",
"wheel.py",
],
deps = [
@@ -26,3 +27,15 @@
":lib",
],
)
+
+py_test(
+ name = "requirements_test",
+ size = "small",
+ srcs = [
+ "requirements_test.py",
+ ],
+ tags = ["unit"],
+ deps = [
+ ":lib",
+ ],
+)
diff --git a/extract_wheels/lib/bazel.py b/extract_wheels/lib/bazel.py
index ed40038..2d4c61e 100644
--- a/extract_wheels/lib/bazel.py
+++ b/extract_wheels/lib/bazel.py
@@ -1,7 +1,7 @@
"""Utility functions to manipulate Bazel files"""
import os
import textwrap
-from typing import Iterable, List
+from typing import Iterable, List, Dict, Set
from extract_wheels.lib import namespace_pkgs, wheel, purelib
@@ -113,7 +113,7 @@
namespace_pkgs.add_pkgutil_style_namespace_pkg_init(ns_pkg_dir)
-def extract_wheel(wheel_file: str, extras: List[str]) -> str:
+def extract_wheel(wheel_file: str, extras: Dict[str, Set[str]]) -> str:
"""Extracts wheel into given directory and creates a py_library target.
Args:
@@ -134,16 +134,17 @@
purelib.spread_purelib_into_root(directory)
setup_namespace_pkg_compatibility(directory)
+ extras_requested = extras[whl.name] if whl.name in extras else set()
+
+ sanitised_dependencies = [
+ '"//%s"' % sanitise_name(d) for d in sorted(whl.dependencies(extras_requested))
+ ]
+
with open(os.path.join(directory, "BUILD"), "w") as build_file:
- build_file.write(
- generate_build_file_contents(
- sanitise_name(whl.name),
- [
- '"//%s"' % sanitise_name(d)
- for d in sorted(whl.dependencies(extras_requested=extras))
- ],
- )
+ contents = generate_build_file_contents(
+ sanitise_name(whl.name), sanitised_dependencies,
)
+ build_file.write(contents)
os.remove(whl.path)
diff --git a/extract_wheels/lib/namespace_pkgs_test.py b/extract_wheels/lib/namespace_pkgs_test.py
index 658625f..d4e2a4c 100644
--- a/extract_wheels/lib/namespace_pkgs_test.py
+++ b/extract_wheels/lib/namespace_pkgs_test.py
@@ -1,7 +1,5 @@
-import os
import pathlib
import shutil
-import sys
import tempfile
import unittest
@@ -133,17 +131,5 @@
self.assertEqual(actual, set())
-def main():
- loader = unittest.TestLoader()
- cur_dir = os.path.dirname(os.path.realpath(__file__))
-
- suite = loader.discover(cur_dir)
-
- runner = unittest.TextTestRunner()
- result = runner.run(suite)
- if result.errors or result.failures:
- sys.exit(1)
-
-
if __name__ == "__main__":
- main()
+ unittest.main()
diff --git a/extract_wheels/lib/requirements.py b/extract_wheels/lib/requirements.py
new file mode 100644
index 0000000..e246379
--- /dev/null
+++ b/extract_wheels/lib/requirements.py
@@ -0,0 +1,45 @@
+import re
+from typing import Dict, Set, Tuple, Optional
+
+
+def parse_extras(requirements_path: str) -> Dict[str, Set[str]]:
+ """Parse over the requirements.txt file to find extras requested.
+
+ Args:
+ requirements_path: The filepath for the requirements.txt file to parse.
+
+ Returns:
+ A dictionary mapping the requirement name to a set of extras requested.
+ """
+
+ extras_requested = {}
+ with open(requirements_path, "r") as requirements:
+ # Merge all backslash line continuations so we parse each requirement as a single line.
+ for line in requirements.read().replace("\\\n", "").split("\n"):
+ requirement, extras = _parse_requirement_for_extra(line)
+ if requirement and extras:
+ extras_requested[requirement] = extras
+
+ return extras_requested
+
+
+def _parse_requirement_for_extra(
+ requirement: str,
+) -> Tuple[Optional[str], Optional[Set[str]]]:
+ """Given a requirement string, returns the requirement name and set of extras, if extras specified.
+ Else, returns (None, None)
+ """
+
+ # https://www.python.org/dev/peps/pep-0508/#grammar
+ extras_pattern = re.compile(
+ r"^\s*([0-9A-Za-z][0-9A-Za-z_.\-]*)\s*\[\s*([0-9A-Za-z][0-9A-Za-z_.\-]*(?:\s*,\s*[0-9A-Za-z][0-9A-Za-z_.\-]*)*)\s*\]"
+ )
+
+ matches = extras_pattern.match(requirement)
+ if matches:
+ return (
+ matches.group(1),
+ {extra.strip() for extra in matches.group(2).split(",")},
+ )
+
+ return None, None
diff --git a/extract_wheels/lib/requirements_test.py b/extract_wheels/lib/requirements_test.py
new file mode 100644
index 0000000..2b96a75
--- /dev/null
+++ b/extract_wheels/lib/requirements_test.py
@@ -0,0 +1,32 @@
+import unittest
+
+from extract_wheels.lib import requirements
+
+
+class TestRequirementExtrasParsing(unittest.TestCase):
+ def test_parses_requirement_for_extra(self) -> None:
+ cases = [
+ ("name[foo]", ("name", frozenset(["foo"]))),
+ ("name[ Foo123 ]", ("name", frozenset(["Foo123"]))),
+ (" name1[ foo ] ", ("name1", frozenset(["foo"]))),
+ (
+ "name [fred,bar] @ http://foo.com ; python_version=='2.7'",
+ ("name", frozenset(["fred", "bar"])),
+ ),
+ (
+ "name[quux, strange];python_version<'2.7' and platform_version=='2'",
+ ("name", frozenset(["quux", "strange"])),
+ ),
+ ("name; (os_name=='a' or os_name=='b') and os_name=='c'", (None, None),),
+ ("name@http://foo.com", (None, None),),
+ ]
+
+ for case, expected in cases:
+ with self.subTest():
+ self.assertTupleEqual(
+ requirements._parse_requirement_for_extra(case), expected
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/extract_wheels/lib/wheel.py b/extract_wheels/lib/wheel.py
index 6737ee8..5fe74a1 100644
--- a/extract_wheels/lib/wheel.py
+++ b/extract_wheels/lib/wheel.py
@@ -2,7 +2,7 @@
import glob
import os
import zipfile
-from typing import Dict, Optional, List, Set
+from typing import Dict, Optional, Set
import pkg_resources
import pkginfo
@@ -26,7 +26,7 @@
def metadata(self) -> pkginfo.Wheel:
return pkginfo.get_metadata(self.path)
- def dependencies(self, extras_requested: Optional[List[str]] = None) -> Set[str]:
+ def dependencies(self, extras_requested: Optional[Set[str]] = None) -> Set[str]:
dependency_set = set()
for wheel_req in self.metadata.requires_dist: