fix external codegen kernel error checking (#85029)
Fixes https://github.com/pytorch/pytorch/issues/84987. I followed the repro steps from the issue (changed `empty_symint` to `empty_symint2`) and confirmed that an error gets raised.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85029
Approved by: https://github.com/ezyang
diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py
index 377db49..8d54b8c 100644
--- a/tools/test/test_gen_backend_stubs.py
+++ b/tools/test/test_gen_backend_stubs.py
@@ -3,6 +3,7 @@
import os
import tempfile
import unittest
+from typing import Optional
import expecttest
from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401
@@ -25,12 +26,20 @@
fp.flush()
run(fp.name, "", True)
- def get_errors_from_gen_backend_stubs(self, yaml_str: str) -> str:
+ def get_errors_from_gen_backend_stubs(
+ self, yaml_str: str, *, kernels_str: Optional[str] = None
+ ) -> str:
with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_str)
fp.flush()
try:
- run(fp.name, "", True)
+ if kernels_str is None:
+ run(fp.name, "", True)
+ else:
+ with tempfile.NamedTemporaryFile(mode="w") as kernel_file:
+ kernel_file.write(kernels_str)
+ kernel_file.flush()
+ run(fp.name, "", True, impl_path=kernel_file.name)
except AssertionError as e:
# Scrub out the temp file name from any error messages to simplify assertions.
return str(e).replace(fp.name, "")
@@ -269,6 +278,34 @@
"""You must provide either True or False for device_guard. Provided: frue""",
) # noqa: B950
+ def test_incorrect_kernel_name(self) -> None:
+ yaml_str = """\
+backend: XLA
+cpp_namespace: torch_xla
+supported:
+- abs
+autograd:
+- add.Tensor"""
+ # Codegen will expect two kernel names (and try to parse them with regex):
+ # XLANativeFunctions::abs(...)
+ # XLANativeFunctions::add(...)
+ kernels_str = """\
+at::Tensor& XLANativeFunctions::absWRONG(at::Tensor& self) {}
+at::Tensor& XLANativeFunctions::add(at::Tensor& self) {}"""
+ output_error = self.get_errors_from_gen_backend_stubs(
+ yaml_str, kernels_str=kernels_str
+ )
+ self.assertExpectedInline(
+ output_error,
+ """\
+
+XLANativeFunctions is missing a kernel definition for abs. We found 0 kernel(s) with that name,
+but expected 1 kernel(s). The expected function schemas for the missing operator are:
+at::Tensor abs(const at::Tensor & self)
+
+""",
+ )
+
if __name__ == "__main__":
unittest.main()
diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py
index 447d9a7..aecc187 100644
--- a/torchgen/gen_backend_stubs.py
+++ b/torchgen/gen_backend_stubs.py
@@ -270,30 +270,49 @@
if full_codegen is None:
full_codegen = []
- expected_backend_op_names: List[OperatorName] = (
- list(backend_indices[backend_key].index.keys()) + []
- if autograd_key is None
- else list(backend_indices[autograd_key].index.keys())
+ indices = [backend_indices[backend_key].index] + (
+ [] if autograd_key is None else [backend_indices[autograd_key].index]
+ )
+ # Quick mapping from each OperatorName used by the external backend
+ # to its backend kernel name
+ expected_backend_op_names: Dict[OperatorName, str] = dict(
+ list(
+ concatMap(
+ lambda index: [
+ (op_name, metadata.kernel) for op_name, metadata in index.items()
+ ],
+ indices,
+ )
+ )
)
expected_backend_native_funcs: List[NativeFunction] = [
f
for f in native_functions
- if f.func.name in expected_backend_op_names and f.func.name not in full_codegen
+ if f.func.name in expected_backend_op_names.keys()
+ and f.func.name not in full_codegen
]
expected_backend_kernel_name_counts: Dict[str, List[NativeFunction]] = defaultdict(
list
)
for native_f in expected_backend_native_funcs:
- expected_backend_kernel_name_counts[dispatcher.name(native_f.func)].append(
- native_f
- )
+ expected_backend_kernel_name_counts[
+ expected_backend_op_names[native_f.func.name]
+ ].append(native_f)
# This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented.
# It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel
# here, then we get a nicer error message. If we miss it, you get a linker error.
- kernel_defn_regex = rf"{class_name}::\s*([\w\d]*)\("
+ kernel_defn_regex = rf"(.*){class_name}::\s*([\w\d]*)\("
actual_backend_kernel_name_counts = Counter(
- re.findall(kernel_defn_regex, backend_defns)
+ # A bit unwieldy (this could probably be moved into regex),
+ # but we don't want to include kernel names that come from function calls,
+ # like "return torch_xla::XLANativeFunctions::empty_strided_symint(...)".
+ # Easy check is to ignore any lines with colons before the class name.
+ [
+ y
+ for (x, y) in re.findall(kernel_defn_regex, backend_defns)
+ if not x.endswith(":")
+ ]
)
missing_kernels_err_msg = ""