Allow structseq to be input of operators where tuple is expected (#17208)
Summary:
Currently the following code gives an error on python 2 because `ret` is a structseq which is not a tuple
```python
ret = a.max(dim=0)
ret1 = torch.max(a, dim=0, out=ret)
```
This PR modify tuple check in python arg parser to allow structseq to be input of operators where tuple is expected, which would make the above code work.
Depend on: https://github.com/pytorch/pytorch/pull/17136
Partially fixes: https://github.com/pytorch/pytorch/issues/16813
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17208
Differential Revision: D14280198
Pulled By: VitalyFedyunin
fbshipit-source-id: beffebfd3951c4f5c7c8fe99a5847616a89491f3
diff --git a/test/test_torch.py b/test/test_torch.py
index b42dd9e..ee2aa26 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -7157,7 +7157,7 @@
ret = getattr(a, f)(dim=0)
self.assertEqual(ret.values, ret[0])
self.assertEqual(ret.indices, ret[1])
- ret1 = getattr(torch, f)(a, dim=0, out=tuple(ret))
+ ret1 = getattr(torch, f)(a, dim=0, out=ret)
self.assertEqual(ret1.values, ret1[0])
self.assertEqual(ret1.indices, ret1[1])
self.assertEqual(ret1.values, ret[0])
@@ -7167,7 +7167,7 @@
ret = a.kthvalue(1, dim=0)
self.assertEqual(ret.values, ret[0])
self.assertEqual(ret.indices, ret[1])
- ret1 = torch.kthvalue(a, 1, dim=0, out=tuple(ret))
+ ret1 = torch.kthvalue(a, 1, dim=0, out=ret)
self.assertEqual(ret1.values, ret1[0])
self.assertEqual(ret1.indices, ret1[1])
self.assertEqual(ret1.values, ret[0])
@@ -7178,7 +7178,7 @@
self.assertEqual(ret.U, ret[0])
self.assertEqual(ret.S, ret[1])
self.assertEqual(ret.V, ret[2])
- ret1 = torch.svd(a, out=tuple(ret))
+ ret1 = torch.svd(a, out=ret)
self.assertEqual(ret1.U, ret1[0])
self.assertEqual(ret1.S, ret1[1])
self.assertEqual(ret1.V, ret1[2])
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index fe96e09..63e4625 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -131,7 +131,7 @@
}
return false;
}
- case ParameterType::TENSOR_LIST: return PyTuple_Check(obj) || PyList_Check(obj);
+ case ParameterType::TENSOR_LIST: return six::isTuple(obj) || PyList_Check(obj);
case ParameterType::INT_LIST: {
if (PyTuple_Check(obj) || PyList_Check(obj)) {
return true;
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 121c63d..a7b79be 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -55,6 +55,7 @@
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
+#include <torch/csrc/utils/six.h>
#include <torch/csrc/autograd/variable.h>
#include <ATen/ATen.h>
@@ -249,12 +250,12 @@
inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) {
if (!args[i]) return std::vector<at::Tensor>();
- PyObject* arg = args[i];
- auto tuple = PyTuple_Check(arg);
- auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
+ auto tuple = six::isTuple(args[i]);
+ THPObjectPtr arg = six::maybeAsTuple(args[i]);
+ auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
std::vector<at::Tensor> res(size);
for (int idx = 0; idx < size; idx++) {
- PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
+ PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx);
if (!THPVariable_Check(obj)) {
throw TypeError("expected Tensor as element %d in argument %d, but got %s",
idx, i, Py_TYPE(obj)->tp_name);
@@ -267,15 +268,15 @@
template<int N>
inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
auto res = std::array<at::Tensor, N>();
- PyObject* arg = args[i];
- if (!arg) return res;
- auto tuple = PyTuple_Check(arg);
- auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
+ if (!args[i]) return res;
+ auto tuple = six::isTuple(args[i]);
+ THPObjectPtr arg = six::maybeAsTuple(args[i]);
+ auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
if (size != N) {
throw TypeError("expected tuple of %d elements but got %d", N, (int)size);
}
for (int idx = 0; idx < size; idx++) {
- PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
+ PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx);
if (!THPVariable_Check(obj)) {
throw TypeError("expected Tensor as element %d in argument %d, but got %s",
idx, i, Py_TYPE(obj)->tp_name);
diff --git a/torch/csrc/utils/six.h b/torch/csrc/utils/six.h
index 957cc20..932f0bf 100644
--- a/torch/csrc/utils/six.h
+++ b/torch/csrc/utils/six.h
@@ -1,7 +1,8 @@
#pragma once
#include <pybind11/pybind11.h>
-#include "torch/csrc/utils/structseq.h"
+#include <torch/csrc/utils/structseq.h>
+#include <torch/csrc/utils/object_ptr.h>
namespace six {
@@ -10,12 +11,20 @@
// the name of the type to determine if it is a namedtupled returned
// by a pytorch operator.
+inline bool isStructSeq(pybind11::handle input) {
+ return pybind11::cast<std::string>(input.get_type().attr("__module__")) == "torch.return_types";
+}
+
+inline bool isStructSeq(PyObject* obj) {
+ return isStructSeq(pybind11::handle(obj));
+}
+
inline bool isTuple(pybind11::handle input) {
if (PyTuple_Check(input.ptr())) {
return true;
}
#if PY_MAJOR_VERSION == 2
- return pybind11::cast<std::string>(input.get_type().attr("__module__")) == "torch.return_types";
+ return isStructSeq(input);
#else
return false;
#endif
@@ -25,15 +34,25 @@
return isTuple(pybind11::handle(obj));
}
-inline PyObject *toTuple(PyStructSequence *obj) {
- // create a new tuple object on python 2, or increase
- // the ref count of the current object on python 3.
+// maybeAsTuple: if the input is a structseq, then convert it to a tuple
+//
+// On Python 3, structseq is a subtype of tuple, so these APIs could be used directly.
+// But on Python 2, structseq is not a subtype of tuple, so we need to manually create a
+// new tuple object from structseq.
+inline THPObjectPtr maybeAsTuple(PyStructSequence *obj) {
#if PY_MAJOR_VERSION == 2
- return torch::utils::structseq_slice(obj, 0, Py_SIZE(obj));
+ return THPObjectPtr(torch::utils::structseq_slice(obj, 0, Py_SIZE(obj)));
#else
Py_INCREF(obj);
- return (PyObject *)obj;
+ return THPObjectPtr((PyObject *)obj);
#endif
}
+inline THPObjectPtr maybeAsTuple(PyObject *obj) {
+ if (isStructSeq(obj))
+ return maybeAsTuple((PyStructSequence *)obj);
+ Py_INCREF(obj);
+ return THPObjectPtr(obj);
+}
+
} // namespace six
diff --git a/torch/csrc/utils/structseq.cpp b/torch/csrc/utils/structseq.cpp
index 0bf4adb..65fede2 100644
--- a/torch/csrc/utils/structseq.cpp
+++ b/torch/csrc/utils/structseq.cpp
@@ -52,7 +52,7 @@
PyObject *returned_structseq_repr(PyStructSequence *obj) {
PyTypeObject *typ = Py_TYPE(obj);
- PyObject *tup = six::toTuple(obj);
+ THPObjectPtr tup = six::maybeAsTuple(obj);
if (tup == nullptr) {
return nullptr;
}
@@ -69,26 +69,22 @@
if (cname == nullptr) {
PyErr_Format(PyExc_SystemError, "In structseq_repr(), member %d name is nullptr"
" for type %.500s", i, typ->tp_name);
- Py_DECREF(tup);
return nullptr;
}
- val = PyTuple_GetItem(tup, i);
+ val = PyTuple_GetItem(tup.get(), i);
if (val == nullptr) {
- Py_DECREF(tup);
return nullptr;
}
repr = PyObject_Repr(val);
if (repr == nullptr) {
- Py_DECREF(tup);
return nullptr;
}
crepr = PyUnicode_AsUTF8(repr);
Py_DECREF(repr);
if (crepr == nullptr) {
- Py_DECREF(tup);
return nullptr;
}
@@ -99,7 +95,6 @@
}
ss << ")";
- Py_DECREF(tup);
return PyUnicode_FromString(ss.str().c_str());
}