[dynamo][guards-cpp-refactor] PythonLambdaGuardAccessor (#120730)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120730
Approved by: https://github.com/jansel
ghstack dependencies: #120864
diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py
index 3c365b4..d57271d 100644
--- a/test/dynamo/test_guard_manager.py
+++ b/test/dynamo/test_guard_manager.py
@@ -547,6 +547,33 @@
del x
self.assertFalse(guard_manager.check(None))
+ def test_lambda_manager(self):
+ a = (1, 1, 3, 4, 5, 6)
+
+ guard_manager = RootGuardManager()
+
+ # Check that we can use the same accessor
+ foo_mgr = guard_manager.lambda_manager(lambda x: x[2], None)
+ foo_mgr.add_lambda_guard(
+ lambda x: x == 3,
+ "Expected value 3",
+ )
+ self.assertTrue(guard_manager.check(a))
+
+ # test that exception works
+ guard_manager = RootGuardManager()
+
+ def fn(x):
+ raise AssertionError("Test")
+ return x
+
+ foo_mgr = guard_manager.lambda_manager(fn, None)
+
+ self.assertFalse(guard_manager.check(None))
+ debug_info = guard_manager.check_verbose(None)
+ self.assertFalse(debug_info.result)
+ self.assertTrue("Test" in debug_info.verbose_code_parts[0])
+
def test_dict_guard_manager(self):
root = RootGuardManager()
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index c9571c7..59e613c 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -700,6 +700,20 @@
-1,
_methods};
+std::string get_exception_message() {
+ PyObject *ptype, *pvalue, *ptraceback;
+ PyErr_Fetch(&ptype, &pvalue, &ptraceback);
+
+ PyObject* exc_message_pyobj = PyObject_Str(pvalue);
+ const char* exc_message = PyUnicode_AsUTF8(exc_message_pyobj);
+
+ Py_DECREF(exc_message_pyobj);
+ Py_XDECREF(ptype);
+ Py_XDECREF(pvalue);
+ Py_XDECREF(ptraceback);
+ return std::string(exc_message);
+}
+
/**
* Stores relevant guard debug information, e.g., failure str for a LeafGuard
* failure. The data structure is also accessible in Python.
@@ -842,6 +856,22 @@
return result;
}
+ GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
+ PyObject* x = PyObject_CallOneArg(_guard_check_fn.ptr(), value); // new ref
+ if (x == nullptr) {
+ // An exception is caught in the lambda function.
+ std::string exc_message = get_exception_message();
+ PyErr_Clear();
+ return GuardDebugInfo(false, exc_message, 0);
+ }
+ bool result = PyObject_IsTrue(x);
+ Py_DECREF(x);
+ if (result) {
+ return GuardDebugInfo(true, 0);
+ }
+ return GuardDebugInfo(false, verbose_code_parts(), 0);
+ }
+
private:
// The user provided lambda function for check_fn.
py::function _guard_check_fn;
@@ -2422,6 +2452,55 @@
PyObject* _global_name;
};
+/**
+ * Similar to PythonLambdaLeafGuard, this class is a way to allow developers to
+ * supply accessor as a python function. This is useful for from_numpy source.
+ */
+class PythonLambdaGuardAccessor : public GuardAccessor {
+ public:
+ PythonLambdaGuardAccessor(
+ RootGuardManager* root,
+ py::function accessor_fn,
+ py::handle example_value)
+ : GuardAccessor(root, accessor_fn, example_value),
+ _accessor_fn(accessor_fn) {}
+
+ // NB: Intentional duplication between check_nopybind and
+ // check_verbose_nopybind.
+ bool check_nopybind(PyObject* obj) override { // borrowed ref
+ PyObject* x = PyObject_CallOneArg(_accessor_fn.ptr(), obj); // new ref
+ if (x == nullptr) {
+ // The accessor function failed.
+ PyErr_Clear();
+ return false;
+ }
+ bool result = _guard_manager->check_nopybind(x);
+ Py_DECREF(x);
+ return result;
+ }
+
+ GuardDebugInfo check_verbose_nopybind(
+ PyObject* obj) override { // borrowed ref
+ PyObject* x = PyObject_CallOneArg(_accessor_fn.ptr(), obj); // new ref
+ if (x == nullptr) {
+ // The accessor function failed.
+ std::string exc_message = get_exception_message();
+ PyErr_Clear();
+ return GuardDebugInfo(false, exc_message, 0);
+ }
+ GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
+ Py_DECREF(x);
+ return result;
+ }
+
+ std::string repr() const override {
+ return "PythonLambdaGuardAccessor";
+ }
+
+ private:
+ py::object _accessor_fn;
+};
+
void install_tensor_aliasing_guard(
GuardManager* x,
GuardManager* y,
@@ -2844,6 +2923,12 @@
"global_weakref_manager",
&GuardManager::get_child_manager<GlobalWeakRefGuardAccessor>,
py::return_value_policy::reference)
+ // return by reference because GuardManager has the ownership of accessors
+ // and guard managers
+ .def(
+ "lambda_manager",
+ &GuardManager::get_child_manager<PythonLambdaGuardAccessor>,
+ py::return_value_policy::reference)
// return by reference because C++ GuardManager has the ownership of
// accessors and guard managers
.def(