Follow-up for pytorch/pytorch#37091. (#42806)
Summary:
This is a follow-up PR for https://github.com/pytorch/pytorch/issues/37091, fixing some of the quirks of that PR as that one was landed early to avoid merge conflicts.
This PR addresses the following action items:
- [x] Use error-handling macros instead of a `try`-`catch`.
- [x] Renamed and added comments to clarify the use of `HANDLED_FUNCTIONS_WRAPPERS` in tests. `HANDLED_FUNCTIONS_NAMESPACES` was already removed in the last PR as we had a way to test for methods.
This PR does NOT address the following action item, as it proved to be difficult:
- [ ] Define `__module__` for whole API.
Single-line repro-er for why this is hard:
```python
>>> torch.Tensor.grad.__get__.__module__ = "torch.Tensor.grad"
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'method-wrapper' object has no attribute '__module__'
```
Explanation: Methods defined in C/properties don't always have a `__dict__` attribute or a mutable `__module__` slot for us to modify.
The documentation action items were addressed in the following commit, with the additional future task of adding the rendered RFCs to the documentation: https://github.com/pytorch/rfcs/pull/3/commits/552ba37c0500f70b8738522591b276c82cb7ca2a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42806
Reviewed By: smessmer
Differential Revision: D23031501
Pulled By: ezyang
fbshipit-source-id: b781c97f7840b8838ede50a0017b4327f96bc98a
diff --git a/test/test_overrides.py b/test/test_overrides.py
index 7d058ea..0bf8478 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -285,7 +285,14 @@
# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_TENSOR_LIKE = {}
-HANDLED_FUNCTIONS_WRAPPERS = {}
+
+
+# Note: _triggered wrapper
+# Dict that wraps the implementations from get_testing_overrides into another
+# function with a _triggered slot/flag. The triggered flag is set when the
+# implementation is called.
+WRAPPED_TRIGGERED_IMPLS = {}
+
def triggered_wrapper(f):
@functools.wraps(f)
@@ -324,7 +331,8 @@
# decorate the overrides with implements_tensor_like if it's not a
# torch.Tensor method
wrapped = triggered_wrapper(override)
- HANDLED_FUNCTIONS_WRAPPERS[func] = wrapped
+ # See note: "_triggered wrapper"
+ WRAPPED_TRIGGERED_IMPLS[func] = wrapped
if is_tensor_method_or_property(func):
implements_sub(func)(wrapped)
else:
@@ -549,6 +557,7 @@
t = t[:-1]
if t == 'Tensor':
if arg['name'] == 'self' and is_tensor_method_or_property(func):
+ # See "Note: properties and __get__"
func = func.__get__(instance_gen())
continue
func_args.append(instance_gen())
@@ -590,8 +599,9 @@
# ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
# This is currently the best check but doesn't work for, for example,
# Tensor.__add__ because it redirects to Tensor.add.
+ # See note "_triggered wrapper"
if ret is None:
- self.assertTrue(HANDLED_FUNCTIONS_WRAPPERS[func]._triggered)
+ self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered)
return
self.assertEqual(ret, -1)
@@ -601,6 +611,7 @@
for func, override in get_testing_overrides().items():
test_method = test_generator(func, override)
if func.__name__ == "__get__":
+ # Note: properties and __get__
# __get__ is part of the descriptor protocol.
# https://docs.python.org/3/howto/descriptor.html
# This is used for properties of the form
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index 731e5fd..38cc596 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -325,12 +325,9 @@
static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
- try {
- return handle_torch_function(self, "__bool__");
- }
- catch(const python_error&) {
- return nullptr;
- }
+ HANDLE_TH_ERRORS
+ return handle_torch_function(self, "__bool__");
+ END_HANDLE_TH_ERRORS
}
jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
@@ -1016,12 +1013,9 @@
static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) {
if (check_has_torch_function(self)) {
- try {
- return handle_torch_function(self, "__bool__");
- }
- catch(const python_error&) {
- return nullptr;
- }
+ HANDLE_TH_ERRORS
+ return handle_torch_function(self, "__bool__");
+ END_HANDLE_TH_ERRORS
}
jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW);
return THPVariable_is_nonzero(self, args);
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index e727177..5554922 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -335,12 +335,9 @@
PyObject *THPVariable_get_volatile(THPVariable *self, void *unused)
{
if (check_has_torch_function((PyObject *)self)) {
- try {
- return handle_torch_function_getter(self, "volatile");
- }
- catch (const python_error&) {
- return nullptr;
- }
+ HANDLE_TH_ERRORS
+ return handle_torch_function_getter(self, "volatile");
+ END_HANDLE_TH_ERRORS
}
const char* msg = "volatile was removed (Variable.volatile is always False)";
PyErr_WarnEx(PyExc_UserWarning, msg, 1);
@@ -350,12 +347,9 @@
int THPVariable_set_volatile(THPVariable *self, PyObject *obj, void *unused)
{
if (check_has_torch_function((PyObject *)self)) {
- try {
- return handle_torch_function_setter(self, "volatile", obj);
- }
- catch (const python_error&) {
- return -1;
- }
+ HANDLE_TH_ERRORS
+ return handle_torch_function_setter(self, "volatile", obj);
+ END_HANDLE_TH_ERRORS_RET(-1)
}
return PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1);
}
@@ -469,12 +463,9 @@
PyObject *THPVariable_get_name(THPVariable* self, void *unused)
{
if (check_has_torch_function((PyObject *)self)) {
- try {
- return handle_torch_function_getter(self, "name");
- }
- catch (const python_error&) {
- return nullptr;
- }
+ HANDLE_TH_ERRORS
+ return handle_torch_function_getter(self, "name");
+ END_HANDLE_TH_ERRORS
}
if (self->cdata.name() == "")
Py_RETURN_NONE;