[JIT] Fix toIValue handling of AttributeError when casting ClassType (#49188)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49188
Test Plan: Imported from OSS
Reviewed By: pbelevich
Differential Revision: D25476573
Pulled By: jamesr66a
fbshipit-source-id: cec296fae71cc0cdf36bde60417d7d3b1aa84198
diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py
index b4075db..a80670f 100644
--- a/test/jit/test_class_type.py
+++ b/test/jit/test_class_type.py
@@ -959,6 +959,26 @@
# Make sure class constant is accessible from module
self.assertEqual(m.w, m_loaded.w)
+ def test_py_class_to_ivalue_missing_attribute(self):
+ global Foo # see [local resolution in python]
+
+ class Foo(object):
+ i : int
+ f : float
+
+ def __init__(self, i : int, f : float):
+ self.i = i
+ self.f = f
+
+ @torch.jit.script
+ def test_fn(x : Foo) -> float:
+ return x.i + x.f
+
+ test_fn(Foo(3, 4.0))
+
+ with self.assertRaisesRegex(RuntimeError, 'missing attribute i'):
+ test_fn(torch.rand(3, 4))
+
def test_unused_method(self):
"""
Test unused methods on scripted classes.
diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py
index af7897e..31eec81 100644
--- a/test/jit/test_torchbind.py
+++ b/test/jit/test_torchbind.py
@@ -240,6 +240,10 @@
traced = torch.jit.trace(TryTracing(), ())
self.assertEqual(torch.zeros(4, 4), traced())
+ def test_torchbind_pass_wrong_type(self):
+ with self.assertRaisesRegex(RuntimeError, 'missing attribute capsule'):
+ torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4))
+
def test_torchbind_tracing_nested(self):
class TryTracingNest(torch.nn.Module):
def __init__(self):
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index dc3b3b1..34ca758 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -713,6 +713,15 @@
const auto& attrType = classType->getAttribute(slot);
const auto& attrName = classType->getAttributeName(slot);
+ if (!py::hasattr(obj, attrName.c_str())) {
+ throw py::cast_error(c10::str(
+ "Tried to cast object to type ",
+ type->repr_str(),
+ " but object",
+ " was missing attribute ",
+ attrName));
+ }
+
const auto& contained = py::getattr(obj, attrName.c_str());
userObj->setSlot(slot, toIValue(contained, attrType));
}