[JIT] Improve source attribution for NamedTuple type inference (#95761)
Most errors thrown during torchscript scripting or execution have a SourceRange attached that can be used to identify where the error is coming from. NamedTuple type inference previously didn't have SourceRanges attached; this PR adds them.
Differential Revision: [D43685662](https://our.internmc.facebook.com/intern/diff/D43685662)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95761
Approved by: https://github.com/eellison
diff --git a/test/jit/test_typing.py b/test/jit/test_typing.py
index fd0187a..a461199 100644
--- a/test/jit/test_typing.py
+++ b/test/jit/test_typing.py
@@ -4,10 +4,10 @@
import sys
import torch
-from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.jit_utils import JitTestCase, make_global
from torch.testing._internal.common_utils import IS_WINDOWS
from collections import namedtuple
-from typing import List, Tuple, Optional, Dict
+from typing import List, Tuple, Optional, Dict, NamedTuple
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -592,3 +592,18 @@
with self.assertRaisesRegex(RuntimeError,
r'aka NamedTuple\(logits, aux_logits2, aux_logits1\)'):
out = foo(_GoogLeNetOutputs(logits="3", aux_logits2="4", aux_logits1="5"))
+
+ def test_namedtuple_error_source_attribution(self):
+ class _NamedTupleBadMemberType(NamedTuple):
+ f1: torch.Tensor
+ f2: "ABadForwardRefType"
+
+ make_global(_NamedTupleBadMemberType) # see [local resolution in python]
+
+ def fn(x: _NamedTupleBadMemberType) -> torch.Tensor:
+ return x.f1.relu()
+
+ # assert that this has a location associated with the error.
+ # note the " +" is regex (i.e. "at least one space")
+ with self.assertRaisesRegex(ValueError, "at +File"):
+ torch.jit.script(fn)
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 830b740..868f3d9 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -1198,7 +1198,12 @@
return boolean_dispatched.get(fn)
-def _get_named_tuple_properties(obj):
+def _get_named_tuple_properties(
+ obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None
+):
+ if loc is None:
+ loc = fake_range()
+
assert issubclass(obj, tuple) and hasattr(obj, "_fields")
if hasattr(obj, "_field_defaults"):
defaults = [
@@ -1220,9 +1225,7 @@
annotations = []
for field in obj._fields:
if field in obj_annotations:
- the_type = torch.jit.annotations.ann_to_type(
- obj_annotations[field], fake_range()
- )
+ the_type = torch.jit.annotations.ann_to_type(obj_annotations[field], loc)
annotations.append(the_type)
else:
annotations.append(torch._C.TensorType.getInferred())
diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp
index 8372190..33dbf57 100644
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -1012,7 +1012,7 @@
py::module::import("torch._jit_internal").attr("_qualified_name")(obj)));
py::object props = py::module::import("torch._jit_internal")
- .attr("_get_named_tuple_properties")(obj);
+ .attr("_get_named_tuple_properties")(obj, loc);
std::string unqualName;
std::vector<std::string> field_names;