Allow structseq to be input of operators where tuple is expected (#17208)

Summary:
Currently the following code gives an error on python 2 because `ret` is a structseq which is not a tuple
```python
ret = a.max(dim=0)
ret1 = torch.max(a, dim=0, out=ret)
```

This PR modify tuple check in python arg parser to allow structseq to be input of operators where tuple is expected, which would make the above code work.

Depend on: https://github.com/pytorch/pytorch/pull/17136
Partially fixes: https://github.com/pytorch/pytorch/issues/16813
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17208

Differential Revision: D14280198

Pulled By: VitalyFedyunin

fbshipit-source-id: beffebfd3951c4f5c7c8fe99a5847616a89491f3
diff --git a/test/test_torch.py b/test/test_torch.py
index b42dd9e..ee2aa26 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -7157,7 +7157,7 @@
             ret = getattr(a, f)(dim=0)
             self.assertEqual(ret.values, ret[0])
             self.assertEqual(ret.indices, ret[1])
-            ret1 = getattr(torch, f)(a, dim=0, out=tuple(ret))
+            ret1 = getattr(torch, f)(a, dim=0, out=ret)
             self.assertEqual(ret1.values, ret1[0])
             self.assertEqual(ret1.indices, ret1[1])
             self.assertEqual(ret1.values, ret[0])
@@ -7167,7 +7167,7 @@
         ret = a.kthvalue(1, dim=0)
         self.assertEqual(ret.values, ret[0])
         self.assertEqual(ret.indices, ret[1])
-        ret1 = torch.kthvalue(a, 1, dim=0, out=tuple(ret))
+        ret1 = torch.kthvalue(a, 1, dim=0, out=ret)
         self.assertEqual(ret1.values, ret1[0])
         self.assertEqual(ret1.indices, ret1[1])
         self.assertEqual(ret1.values, ret[0])
@@ -7178,7 +7178,7 @@
         self.assertEqual(ret.U, ret[0])
         self.assertEqual(ret.S, ret[1])
         self.assertEqual(ret.V, ret[2])
-        ret1 = torch.svd(a, out=tuple(ret))
+        ret1 = torch.svd(a, out=ret)
         self.assertEqual(ret1.U, ret1[0])
         self.assertEqual(ret1.S, ret1[1])
         self.assertEqual(ret1.V, ret1[2])
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index fe96e09..63e4625 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -131,7 +131,7 @@
       }
       return false;
     }
-    case ParameterType::TENSOR_LIST: return PyTuple_Check(obj) || PyList_Check(obj);
+    case ParameterType::TENSOR_LIST: return six::isTuple(obj) || PyList_Check(obj);
     case ParameterType::INT_LIST: {
       if (PyTuple_Check(obj) || PyList_Check(obj)) {
         return true;
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 121c63d..a7b79be 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -55,6 +55,7 @@
 #include <torch/csrc/utils/object_ptr.h>
 #include <torch/csrc/utils/python_numbers.h>
 #include <torch/csrc/utils/python_strings.h>
+#include <torch/csrc/utils/six.h>
 #include <torch/csrc/autograd/variable.h>
 
 #include <ATen/ATen.h>
@@ -249,12 +250,12 @@
 
 inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) {
   if (!args[i]) return std::vector<at::Tensor>();
-  PyObject* arg = args[i];
-  auto tuple = PyTuple_Check(arg);
-  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
+  auto tuple = six::isTuple(args[i]);
+  THPObjectPtr arg = six::maybeAsTuple(args[i]);
+  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
   std::vector<at::Tensor> res(size);
   for (int idx = 0; idx < size; idx++) {
-    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
+    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx);
     if (!THPVariable_Check(obj)) {
       throw TypeError("expected Tensor as element %d in argument %d, but got %s",
                  idx, i, Py_TYPE(obj)->tp_name);
@@ -267,15 +268,15 @@
 template<int N>
 inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
   auto res = std::array<at::Tensor, N>();
-  PyObject* arg = args[i];
-  if (!arg) return res;
-  auto tuple = PyTuple_Check(arg);
-  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
+  if (!args[i]) return res;
+  auto tuple = six::isTuple(args[i]);
+  THPObjectPtr arg = six::maybeAsTuple(args[i]);
+  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
   if (size != N) {
     throw TypeError("expected tuple of %d elements but got %d", N, (int)size);
   }
   for (int idx = 0; idx < size; idx++) {
-    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
+    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx);
     if (!THPVariable_Check(obj)) {
       throw TypeError("expected Tensor as element %d in argument %d, but got %s",
                  idx, i, Py_TYPE(obj)->tp_name);
diff --git a/torch/csrc/utils/six.h b/torch/csrc/utils/six.h
index 957cc20..932f0bf 100644
--- a/torch/csrc/utils/six.h
+++ b/torch/csrc/utils/six.h
@@ -1,7 +1,8 @@
 #pragma once
 
 #include <pybind11/pybind11.h>
-#include "torch/csrc/utils/structseq.h"
+#include <torch/csrc/utils/structseq.h>
+#include <torch/csrc/utils/object_ptr.h>
 
 namespace six {
 
@@ -10,12 +11,20 @@
 // the name of the type to determine if it is a namedtupled returned
 // by a pytorch operator.
 
+inline bool isStructSeq(pybind11::handle input) {
+  return pybind11::cast<std::string>(input.get_type().attr("__module__")) == "torch.return_types";
+}
+
+inline bool isStructSeq(PyObject* obj) {
+  return isStructSeq(pybind11::handle(obj));
+}
+
 inline bool isTuple(pybind11::handle input) {
   if (PyTuple_Check(input.ptr())) {
     return true;
   }
 #if PY_MAJOR_VERSION == 2
-  return pybind11::cast<std::string>(input.get_type().attr("__module__")) == "torch.return_types";
+  return isStructSeq(input);
 #else
   return false;
 #endif
@@ -25,15 +34,25 @@
   return isTuple(pybind11::handle(obj));
 }
 
-inline PyObject *toTuple(PyStructSequence *obj) {
-  // create a new tuple object on python 2, or increase
-  // the ref count of the current object on python 3.
+// maybeAsTuple: if the input is a structseq, then convert it to a tuple
+//
+// On Python 3, structseq is a subtype of tuple, so these APIs could be used directly.
+// But on Python 2, structseq is not a subtype of tuple, so we need to manually create a
+// new tuple object from structseq.
+inline THPObjectPtr maybeAsTuple(PyStructSequence *obj) {
 #if PY_MAJOR_VERSION == 2
-  return torch::utils::structseq_slice(obj, 0, Py_SIZE(obj));
+  return THPObjectPtr(torch::utils::structseq_slice(obj, 0, Py_SIZE(obj)));
 #else
   Py_INCREF(obj);
-  return (PyObject *)obj;
+  return THPObjectPtr((PyObject *)obj);
 #endif
 }
 
+inline THPObjectPtr maybeAsTuple(PyObject *obj) {
+  if (isStructSeq(obj))
+    return maybeAsTuple((PyStructSequence *)obj);
+  Py_INCREF(obj);
+  return THPObjectPtr(obj);
+}
+
 }  // namespace six
diff --git a/torch/csrc/utils/structseq.cpp b/torch/csrc/utils/structseq.cpp
index 0bf4adb..65fede2 100644
--- a/torch/csrc/utils/structseq.cpp
+++ b/torch/csrc/utils/structseq.cpp
@@ -52,7 +52,7 @@
 
 PyObject *returned_structseq_repr(PyStructSequence *obj) {
     PyTypeObject *typ = Py_TYPE(obj);
-    PyObject *tup = six::toTuple(obj);
+    THPObjectPtr tup = six::maybeAsTuple(obj);
     if (tup == nullptr) {
         return nullptr;
     }
@@ -69,26 +69,22 @@
         if (cname == nullptr) {
             PyErr_Format(PyExc_SystemError, "In structseq_repr(), member %d name is nullptr"
                          " for type %.500s", i, typ->tp_name);
-            Py_DECREF(tup);
             return nullptr;
         }
 
-        val = PyTuple_GetItem(tup, i);
+        val = PyTuple_GetItem(tup.get(), i);
         if (val == nullptr) {
-            Py_DECREF(tup);
             return nullptr;
         }
 
         repr = PyObject_Repr(val);
         if (repr == nullptr) {
-            Py_DECREF(tup);
             return nullptr;
         }
 
         crepr = PyUnicode_AsUTF8(repr);
         Py_DECREF(repr);
         if (crepr == nullptr) {
-            Py_DECREF(tup);
             return nullptr;
         }
 
@@ -99,7 +95,6 @@
     }
     ss << ")";
 
-    Py_DECREF(tup);
     return PyUnicode_FromString(ss.str().c_str());
 }