[dynamo][guards-cpp-refactor] PythonLambdaGuardAccessor (#120730)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120730
Approved by: https://github.com/jansel
ghstack dependencies: #120864
diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py
index 3c365b4..d57271d 100644
--- a/test/dynamo/test_guard_manager.py
+++ b/test/dynamo/test_guard_manager.py
@@ -547,6 +547,33 @@
         del x
         self.assertFalse(guard_manager.check(None))
 
+    def test_lambda_manager(self):
+        a = (1, 1, 3, 4, 5, 6)
+
+        guard_manager = RootGuardManager()
+
+        # Check that we can use the same accessor
+        foo_mgr = guard_manager.lambda_manager(lambda x: x[2], None)
+        foo_mgr.add_lambda_guard(
+            lambda x: x == 3,
+            "Expected value 3",
+        )
+        self.assertTrue(guard_manager.check(a))
+
+        # test that exception works
+        guard_manager = RootGuardManager()
+
+        def fn(x):
+            raise AssertionError("Test")
+            return x
+
+        foo_mgr = guard_manager.lambda_manager(fn, None)
+
+        self.assertFalse(guard_manager.check(None))
+        debug_info = guard_manager.check_verbose(None)
+        self.assertFalse(debug_info.result)
+        self.assertTrue("Test" in debug_info.verbose_code_parts[0])
+
     def test_dict_guard_manager(self):
         root = RootGuardManager()
 
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index c9571c7..59e613c 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -700,6 +700,20 @@
     -1,
     _methods};
 
+std::string get_exception_message() {
+  PyObject *ptype, *pvalue, *ptraceback;
+  PyErr_Fetch(&ptype, &pvalue, &ptraceback);
+
+  PyObject* exc_message_pyobj = PyObject_Str(pvalue);
+  const char* exc_message = PyUnicode_AsUTF8(exc_message_pyobj);
+
+  Py_DECREF(exc_message_pyobj);
+  Py_XDECREF(ptype);
+  Py_XDECREF(pvalue);
+  Py_XDECREF(ptraceback);
+  return std::string(exc_message);
+}
+
 /**
  * Stores relevant guard debug information, e.g., failure str for a LeafGuard
  * failure. The data structure is also accessible in Python.
@@ -842,6 +856,22 @@
     return result;
   }
 
+  GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
+    PyObject* x = PyObject_CallOneArg(_guard_check_fn.ptr(), value); // new ref
+    if (x == nullptr) {
+      // An exception is caught in the lambda function.
+      std::string exc_message = get_exception_message();
+      PyErr_Clear();
+      return GuardDebugInfo(false, exc_message, 0);
+    }
+    bool result = PyObject_IsTrue(x);
+    Py_DECREF(x);
+    if (result) {
+      return GuardDebugInfo(true, 0);
+    }
+    return GuardDebugInfo(false, verbose_code_parts(), 0);
+  }
+
  private:
   // The user provided lambda function for check_fn.
   py::function _guard_check_fn;
@@ -2422,6 +2452,55 @@
   PyObject* _global_name;
 };
 
+/**
+ * Similar to PythonLambdaLeafGuard, this class is a way to allow developers to
+ * supply accessor as a python function. This is useful for from_numpy source.
+ */
+class PythonLambdaGuardAccessor : public GuardAccessor {
+ public:
+  PythonLambdaGuardAccessor(
+      RootGuardManager* root,
+      py::function accessor_fn,
+      py::handle example_value)
+      : GuardAccessor(root, accessor_fn, example_value),
+        _accessor_fn(accessor_fn) {}
+
+  // NB: Intentional duplication between check_nopybind and
+  // check_verbose_nopybind.
+  bool check_nopybind(PyObject* obj) override { // borrowed ref
+    PyObject* x = PyObject_CallOneArg(_accessor_fn.ptr(), obj); // new ref
+    if (x == nullptr) {
+      // The accessor function failed.
+      PyErr_Clear();
+      return false;
+    }
+    bool result = _guard_manager->check_nopybind(x);
+    Py_DECREF(x);
+    return result;
+  }
+
+  GuardDebugInfo check_verbose_nopybind(
+      PyObject* obj) override { // borrowed ref
+    PyObject* x = PyObject_CallOneArg(_accessor_fn.ptr(), obj); // new ref
+    if (x == nullptr) {
+      // The accessor function failed.
+      std::string exc_message = get_exception_message();
+      PyErr_Clear();
+      return GuardDebugInfo(false, exc_message, 0);
+    }
+    GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
+    Py_DECREF(x);
+    return result;
+  }
+
+  std::string repr() const override {
+    return "PythonLambdaGuardAccessor";
+  }
+
+ private:
+  py::object _accessor_fn;
+};
+
 void install_tensor_aliasing_guard(
     GuardManager* x,
     GuardManager* y,
@@ -2844,6 +2923,12 @@
           "global_weakref_manager",
           &GuardManager::get_child_manager<GlobalWeakRefGuardAccessor>,
           py::return_value_policy::reference)
+      // return by reference because GuardManager has the ownership of accessors
+      // and guard managers
+      .def(
+          "lambda_manager",
+          &GuardManager::get_child_manager<PythonLambdaGuardAccessor>,
+          py::return_value_policy::reference)
       // return by reference because C++ GuardManager has the ownership of
       // accessors and guard managers
       .def(