Use TypeError in PythonArgParser (#4966)

Uses TypeError from torch/csrc/Exceptions.h in python_arg_parser.cpp so
that the exception is interpreted as a Python TypeError instead of
RuntimeError.
diff --git a/test/test_nn.py b/test/test_nn.py
index 6ac7d11..7ac837e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2149,7 +2149,7 @@
 
     def test_Conv2d_missing_argument(self):
         c = nn.Conv2d(3, 3, 3)
-        self.assertRaises(RuntimeError, lambda: c(None))
+        self.assertRaises(TypeError, lambda: c(None))
 
     def test_Conv2d_backward_twice(self):
         input = Variable(torch.randn(2, 3, 5, 5))
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index 0cfd9ba..b9f93f6 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -223,27 +223,14 @@
 }
 
 [[noreturn]]
-void type_error(const char *format, ...) {
-  static const size_t ERROR_BUF_SIZE = 1024;
-  char error_buf[ERROR_BUF_SIZE];
-
-  va_list fmt_args;
-  va_start(fmt_args, format);
-  vsnprintf(error_buf, ERROR_BUF_SIZE, format, fmt_args);
-  va_end(fmt_args);
-
-  throw type_exception(error_buf);
-}
-
-[[noreturn]]
 static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
   auto max_pos_args = signature.max_pos_args;
   auto min_args = signature.min_args;
   if (min_args != max_pos_args) {
-    type_error("%s() takes from %d to %d positional arguments but %d were given",
+    throw TypeError("%s() takes from %d to %d positional arguments but %d were given",
         signature.name.c_str(), min_args, max_pos_args, nargs);
   }
-  type_error("%s() takes %d positional argument%s but %d %s given",
+  throw TypeError("%s() takes %d positional argument%s but %d %s given",
       signature.name.c_str(),
       max_pos_args, max_pos_args == 1 ? "" : "s",
       nargs, nargs == 1 ? "was" : "were");
@@ -265,7 +252,7 @@
     }
   }
 
-  type_error("%s() missing %d required positional argument%s: %s",
+  throw TypeError("%s() missing %d required positional argument%s: %s",
       signature.name.c_str(),
       num_missing,
       num_missing == 1 ? "s" : "",
@@ -293,23 +280,23 @@
 
   while (PyDict_Next(kwargs, &pos, &key, &value)) {
     if (!THPUtils_checkString(key)) {
-      type_error("keywords must be strings");
+      throw TypeError("keywords must be strings");
     }
 
     auto param_idx = find_param(signature, key);
     if (param_idx < 0) {
-      type_error("%s() got an unexpected keyword argument '%s'",
+      throw TypeError("%s() got an unexpected keyword argument '%s'",
           signature.name.c_str(), THPUtils_unpackString(key).c_str());
     }
 
     if (param_idx < num_pos_args) {
-      type_error("%s() got multiple values for argument '%s'",
+      throw TypeError("%s() got multiple values for argument '%s'",
           signature.name.c_str(), THPUtils_unpackString(key).c_str());
     }
   }
 
   // this should never be hit
-  type_error("invalid keyword arguments");
+  throw TypeError("invalid keyword arguments");
 }
 
 bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[],
@@ -364,12 +351,12 @@
     } else if (raise_exception) {
       if (is_kwd) {
         // foo(): argument 'other' must be str, not int
-        type_error("%s(): argument '%s' must be %s, not %s",
+        throw TypeError("%s(): argument '%s' must be %s, not %s",
             name.c_str(), param.name.c_str(), param.type_name().c_str(),
             Py_TYPE(obj)->tp_name);
       } else {
         // foo(): argument 'other' (position 2) must be str, not int
-        type_error("%s(): argument '%s' (position %d) must be %s, not %s",
+        throw TypeError("%s(): argument '%s' (position %d) must be %s, not %s",
             name.c_str(), param.name.c_str(), arg_pos + 1,
             param.type_name().c_str(), Py_TYPE(obj)->tp_name);
       }
@@ -453,7 +440,7 @@
   }
 
   auto msg = torch::format_invalid_args(args, kwargs, function_name + "()", options);
-  type_error("%s", msg.c_str());
+  throw TypeError("%s", msg.c_str());
 }
 
 
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 090770c..f2b3c8c 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -28,6 +28,7 @@
 
 #include "torch/csrc/THP.h"
 #include "torch/csrc/utils/object_ptr.h"
+#include "torch/csrc/Exceptions.h"
 #include "torch/csrc/autograd/python_variable.h"
 #include "torch/csrc/utils/python_numbers.h"
 #include "torch/csrc/DynamicTypes.h"
@@ -43,13 +44,6 @@
 struct FunctionSignature;
 struct PythonArgs;
 
-struct type_exception : public std::runtime_error {
-  using std::runtime_error::runtime_error;
-};
-
-[[noreturn]]
-void type_error(const char *format, ...);
-
 struct PythonArgParser {
   explicit PythonArgParser(std::vector<std::string> fmts);
 
@@ -142,7 +136,7 @@
     // a test for Py_None here; instead, you need to mark the argument
     // as *allowing none*; you can do this by writing 'Tensor?' instead
     // of 'Tensor' in the ATen metadata.
-    type_error("expected Variable as argument %d, but got %s", i, THPUtils_typename(args[i]));
+    throw TypeError("expected Variable as argument %d, but got %s", i, THPUtils_typename(args[i]));
   }
   return reinterpret_cast<THPVariable*>(args[i])->cdata;
 }
@@ -168,7 +162,7 @@
   for (int idx = 0; idx < size; idx++) {
     PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
     if (!THPVariable_Check(obj)) {
-      type_error("expected Variable as element %d in argument %d, but got %s",
+      throw TypeError("expected Variable as element %d in argument %d, but got %s",
                  idx, i, THPUtils_typename(args[i]));
     }
     res[idx] = reinterpret_cast<THPVariable*>(obj)->cdata;
@@ -184,12 +178,12 @@
   auto tuple = PyTuple_Check(arg);
   auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
   if (size != N) {
-    type_error("expected tuple of %d elements but got %d", N, (int)size);
+    throw TypeError("expected tuple of %d elements but got %d", N, (int)size);
   }
   for (int idx = 0; idx < size; idx++) {
     PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
     if (!THPVariable_Check(obj)) {
-      type_error("expected Variable as element %d in argument %d, but got %s",
+      throw TypeError("expected Variable as element %d in argument %d, but got %s",
                  idx, i, THPUtils_typename(args[i]));
     }
     res[idx] = reinterpret_cast<THPVariable*>(obj)->cdata;
@@ -216,7 +210,7 @@
     try {
       res[idx] = THPUtils_unpackLong(obj);
     } catch (std::runtime_error &e) {
-      type_error("%s(): argument '%s' must be %s, but found element of type %s at pos %d",
+      throw TypeError("%s(): argument '%s' must be %s, but found element of type %s at pos %d",
           signature.name.c_str(), signature.params[i].name.c_str(),
           signature.params[i].type_name().c_str(), Py_TYPE(obj)->tp_name, idx + 1);
     }
@@ -258,7 +252,7 @@
 inline at::Generator* PythonArgs::generator(int i) {
   if (!args[i]) return nullptr;
   if (!THPGenerator_Check(args[i])) {
-    type_error("expected Generator as argument %d, but got %s", i, THPUtils_typename(args[i]));
+    throw TypeError("expected Generator as argument %d, but got %s", i, THPUtils_typename(args[i]));
   }
   return reinterpret_cast<THPGenerator*>(args[i])->cdata;
 }