Use TypeError in PythonArgParser (#4966)
Uses TypeError from torch/csrc/Exceptions.h in python_arg_parser.cpp so
that the exception is interpreted as a Python TypeError instead of
RuntimeError.
diff --git a/test/test_nn.py b/test/test_nn.py
index 6ac7d11..7ac837e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2149,7 +2149,7 @@
def test_Conv2d_missing_argument(self):
c = nn.Conv2d(3, 3, 3)
- self.assertRaises(RuntimeError, lambda: c(None))
+ self.assertRaises(TypeError, lambda: c(None))
def test_Conv2d_backward_twice(self):
input = Variable(torch.randn(2, 3, 5, 5))
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index 0cfd9ba..b9f93f6 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -223,27 +223,14 @@
}
[[noreturn]]
-void type_error(const char *format, ...) {
- static const size_t ERROR_BUF_SIZE = 1024;
- char error_buf[ERROR_BUF_SIZE];
-
- va_list fmt_args;
- va_start(fmt_args, format);
- vsnprintf(error_buf, ERROR_BUF_SIZE, format, fmt_args);
- va_end(fmt_args);
-
- throw type_exception(error_buf);
-}
-
-[[noreturn]]
static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
auto max_pos_args = signature.max_pos_args;
auto min_args = signature.min_args;
if (min_args != max_pos_args) {
- type_error("%s() takes from %d to %d positional arguments but %d were given",
+ throw TypeError("%s() takes from %d to %d positional arguments but %d were given",
signature.name.c_str(), min_args, max_pos_args, nargs);
}
- type_error("%s() takes %d positional argument%s but %d %s given",
+ throw TypeError("%s() takes %d positional argument%s but %d %s given",
signature.name.c_str(),
max_pos_args, max_pos_args == 1 ? "" : "s",
nargs, nargs == 1 ? "was" : "were");
@@ -265,7 +252,7 @@
}
}
- type_error("%s() missing %d required positional argument%s: %s",
+ throw TypeError("%s() missing %d required positional argument%s: %s",
signature.name.c_str(),
num_missing,
num_missing == 1 ? "s" : "",
@@ -293,23 +280,23 @@
while (PyDict_Next(kwargs, &pos, &key, &value)) {
if (!THPUtils_checkString(key)) {
- type_error("keywords must be strings");
+ throw TypeError("keywords must be strings");
}
auto param_idx = find_param(signature, key);
if (param_idx < 0) {
- type_error("%s() got an unexpected keyword argument '%s'",
+ throw TypeError("%s() got an unexpected keyword argument '%s'",
signature.name.c_str(), THPUtils_unpackString(key).c_str());
}
if (param_idx < num_pos_args) {
- type_error("%s() got multiple values for argument '%s'",
+ throw TypeError("%s() got multiple values for argument '%s'",
signature.name.c_str(), THPUtils_unpackString(key).c_str());
}
}
// this should never be hit
- type_error("invalid keyword arguments");
+ throw TypeError("invalid keyword arguments");
}
bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[],
@@ -364,12 +351,12 @@
} else if (raise_exception) {
if (is_kwd) {
// foo(): argument 'other' must be str, not int
- type_error("%s(): argument '%s' must be %s, not %s",
+ throw TypeError("%s(): argument '%s' must be %s, not %s",
name.c_str(), param.name.c_str(), param.type_name().c_str(),
Py_TYPE(obj)->tp_name);
} else {
// foo(): argument 'other' (position 2) must be str, not int
- type_error("%s(): argument '%s' (position %d) must be %s, not %s",
+ throw TypeError("%s(): argument '%s' (position %d) must be %s, not %s",
name.c_str(), param.name.c_str(), arg_pos + 1,
param.type_name().c_str(), Py_TYPE(obj)->tp_name);
}
@@ -453,7 +440,7 @@
}
auto msg = torch::format_invalid_args(args, kwargs, function_name + "()", options);
- type_error("%s", msg.c_str());
+ throw TypeError("%s", msg.c_str());
}
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 090770c..f2b3c8c 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -28,6 +28,7 @@
#include "torch/csrc/THP.h"
#include "torch/csrc/utils/object_ptr.h"
+#include "torch/csrc/Exceptions.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/utils/python_numbers.h"
#include "torch/csrc/DynamicTypes.h"
@@ -43,13 +44,6 @@
struct FunctionSignature;
struct PythonArgs;
-struct type_exception : public std::runtime_error {
- using std::runtime_error::runtime_error;
-};
-
-[[noreturn]]
-void type_error(const char *format, ...);
-
struct PythonArgParser {
explicit PythonArgParser(std::vector<std::string> fmts);
@@ -142,7 +136,7 @@
// a test for Py_None here; instead, you need to mark the argument
// as *allowing none*; you can do this by writing 'Tensor?' instead
// of 'Tensor' in the ATen metadata.
- type_error("expected Variable as argument %d, but got %s", i, THPUtils_typename(args[i]));
+ throw TypeError("expected Variable as argument %d, but got %s", i, THPUtils_typename(args[i]));
}
return reinterpret_cast<THPVariable*>(args[i])->cdata;
}
@@ -168,7 +162,7 @@
for (int idx = 0; idx < size; idx++) {
PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
if (!THPVariable_Check(obj)) {
- type_error("expected Variable as element %d in argument %d, but got %s",
+ throw TypeError("expected Variable as element %d in argument %d, but got %s",
idx, i, THPUtils_typename(args[i]));
}
res[idx] = reinterpret_cast<THPVariable*>(obj)->cdata;
@@ -184,12 +178,12 @@
auto tuple = PyTuple_Check(arg);
auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
if (size != N) {
- type_error("expected tuple of %d elements but got %d", N, (int)size);
+ throw TypeError("expected tuple of %d elements but got %d", N, (int)size);
}
for (int idx = 0; idx < size; idx++) {
PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
if (!THPVariable_Check(obj)) {
- type_error("expected Variable as element %d in argument %d, but got %s",
+ throw TypeError("expected Variable as element %d in argument %d, but got %s",
idx, i, THPUtils_typename(args[i]));
}
res[idx] = reinterpret_cast<THPVariable*>(obj)->cdata;
@@ -216,7 +210,7 @@
try {
res[idx] = THPUtils_unpackLong(obj);
} catch (std::runtime_error &e) {
- type_error("%s(): argument '%s' must be %s, but found element of type %s at pos %d",
+ throw TypeError("%s(): argument '%s' must be %s, but found element of type %s at pos %d",
signature.name.c_str(), signature.params[i].name.c_str(),
signature.params[i].type_name().c_str(), Py_TYPE(obj)->tp_name, idx + 1);
}
@@ -258,7 +252,7 @@
inline at::Generator* PythonArgs::generator(int i) {
if (!args[i]) return nullptr;
if (!THPGenerator_Check(args[i])) {
- type_error("expected Generator as argument %d, but got %s", i, THPUtils_typename(args[i]));
+ throw TypeError("expected Generator as argument %d, but got %s", i, THPUtils_typename(args[i]));
}
return reinterpret_cast<THPGenerator*>(args[i])->cdata;
}