[dynamo][guards-cpp-refactor] GetAttrGuardAccessor (#119833)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119833
Approved by: https://github.com/jansel
ghstack dependencies: #119822, #119827
diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py
index fc9dbe2..8a9094e 100644
--- a/test/dynamo/test_guard_manager.py
+++ b/test/dynamo/test_guard_manager.py
@@ -7,6 +7,7 @@
from torch._C._dynamo import guards
RootGuardManager = guards.RootGuardManager
+GetAttrGuardAccessor = guards.GetAttrGuardAccessor
def id_type(x):
@@ -137,6 +138,44 @@
self.assertFalse(guard_manager.check(4))
self.assertFalse(guard_manager.check("foo"))
+ def test_attr_guard_manager(self):
+ class Foo:
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ foo = Foo(1, 2)
+ guard_manager = RootGuardManager()
+ guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
+ guard_manager.getattr_manager("x", 1).add_lambda_guard(
+ functools.partial(equals_match, expected=foo.x),
+ equals_match_verbose_code_parts(foo.x),
+ )
+ guard_manager.getattr_manager("y", 2).add_lambda_guard(
+ functools.partial(equals_match, expected=foo.y),
+ equals_match_verbose_code_parts(foo.y),
+ )
+ self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
+ # 2 child managers, one for x and one for y
+ self.assertEqual(len(guard_manager.get_accessors()), 2)
+ self.assertTrue(
+ isinstance(guard_manager.get_accessors()[0], GetAttrGuardAccessor)
+ )
+ self.assertTrue(
+ isinstance(guard_manager.get_accessors()[1], GetAttrGuardAccessor)
+ )
+ # Check leaf guards on child managers
+ self.assertEqual(
+ len(guard_manager.getattr_manager("x", None).get_leaf_guards()), 1
+ )
+ self.assertEqual(
+ len(guard_manager.getattr_manager("y", None).get_leaf_guards()), 1
+ )
+
+ self.assertTrue(guard_manager.check(foo))
+ self.assertFalse(guard_manager.check(Foo(3, 4)))
+ self.assertFalse(guard_manager.check("foo"))
+
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index 5322801..46bb571 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -1353,6 +1353,60 @@
return std::make_unique<GuardManager>(root);
}
+/**
+ * Represents __getattr__ acccessor.
+ */
+class GetAttrGuardAccessor : public GuardAccessor {
+ public:
+ GetAttrGuardAccessor(
+ RootGuardManager* root,
+ py::str name,
+ py::handle example_value)
+ : GuardAccessor(root, name, example_value), _attr_name(name.ptr()) {}
+
+ // NB: Intentional duplication between check_nopybind and
+ // check_verbose_nopybind.
+ bool check_nopybind(PyObject* obj) override { // borrowed ref
+ PyObject* x = PyObject_GetAttr(obj, _attr_name); // new ref
+ if (x == nullptr) {
+ // Attribute absent, clear the exception and return false.
+ PyErr_Clear();
+ return false;
+ }
+ bool result = _guard_manager->check_nopybind(x);
+ Py_DECREF(x);
+ return result;
+ }
+
+ GuardDebugInfo check_verbose_nopybind(
+ PyObject* obj) override { // borrowed ref
+ PyObject* x = PyObject_GetAttr(obj, _attr_name); // new ref
+ if (x == nullptr) {
+ // Attribute absent, clear the exception and return false.
+ PyErr_Clear();
+ return GuardDebugInfo(
+ false,
+ std::string("get attr failed for attr name ") +
+ py::str(_attr_name).cast<std::string>(),
+ 0);
+ }
+ GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
+ Py_DECREF(x);
+ return result;
+ }
+
+ std::string repr() const override {
+ // Helpful when priting GuardManager tree structure.
+ return "GetAttrGuardAccessor(" + py::str(_attr_name).cast<std::string>() +
+ ")";
+ }
+
+ private:
+ // no need of py::object here because the attr_name is already passed on to
+ // the base class as accessor_key which is a py::object.
+ PyObject* _attr_name;
+};
+
} // namespace
static void* _torchinductor_pyobject_tensor_data_ptr(PyObject* obj) {
@@ -1458,6 +1512,10 @@
py::class_<GuardAccessor, std::unique_ptr<GuardAccessor>>(
py_m, "GuardAccessor")
.def("repr", &GuardAccessor::repr);
+ py::class_<
+ GetAttrGuardAccessor,
+ GuardAccessor,
+ std::unique_ptr<GetAttrGuardAccessor>>(py_m, "GetAttrGuardAccessor");
// Guard Manager - No constructor in python, python should use
// RootGuardManager.
@@ -1510,7 +1568,13 @@
py::object verbose_code_parts) -> void {
self.add_leaf_guard(
std::make_shared<EQUALS_MATCH>(value, verbose_code_parts));
- });
+ })
+ // return by reference because C++ GuardManager has the ownership of
+ // accessors and guard managers
+ .def(
+ "getattr_manager",
+ &GuardManager::get_child_manager<GetAttrGuardAccessor>,
+ py::return_value_policy::reference);
// Root Guard Manager
py::class_<RootGuardManager, GuardManager, std::unique_ptr<RootGuardManager>>(