Improve and test error messages for signature mismatches (#18547)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18547
- Argument indices in the error messages are 1-indexed not 0-indexed.
- Add test cases that a mismatching signature actually shows the correct error messages
Reviewed By: dzhulgakov
Differential Revision: D14656695
fbshipit-source-id: 55e45634baa3117e18b8687ea6b2a2f83715bdf6
diff --git a/aten/src/ATen/core/op_registration/infer_schema.cpp b/aten/src/ATen/core/op_registration/infer_schema.cpp
index bcfeb87..8122cbd 100644
--- a/aten/src/ATen/core/op_registration/infer_schema.cpp
+++ b/aten/src/ATen/core/op_registration/infer_schema.cpp
@@ -21,7 +21,7 @@
if (inferred.returns().size() != specified.returns().size()) {
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
- "The number of returns is different.Specified ", specified.returns().size(),
+ "The number of returns is different. Specified ", specified.returns().size(),
" but inferred ", inferred.returns().size());
}
@@ -29,7 +29,7 @@
if (*inferred.arguments()[i].type() != *specified.arguments()[i].type()) {
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
- "Type mismatch in argument ", i, ": specified ", specified.arguments()[i].type()->str(),
+ "Type mismatch in argument ", (i+1) , ": specified ", specified.arguments()[i].type()->str(),
" but inferred ", inferred.arguments()[i].type()->str());
}
}
@@ -38,7 +38,7 @@
if (*inferred.returns()[i].type() != *specified.returns()[i].type()) {
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
- "Type mismatch in return ", i, ": specified ", specified.returns()[i].type()->str(),
+ "Type mismatch in return ", (i+1), ": specified ", specified.returns()[i].type()->str(),
" but inferred ", inferred.returns()[i].type()->str());
}
}
diff --git a/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp
index b122aea..206c9bd 100644
--- a/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp
+++ b/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp
@@ -547,15 +547,15 @@
), &kernel_func<int64_t, Tensor>::func);
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), &kernel_func<int64_t, Tensor>::func),
- c10::Error
+ ), &kernel_func<int64_t, Tensor>::func);
+ }, "The number of arguments is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
@@ -568,37 +568,37 @@
), &kernel_func<void, Tensor, Tensor>::func);
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{}),
(std::vector<Argument>{})
- ), &kernel_func<void, Tensor, Tensor>::func),
- c10::Error
+ ), &kernel_func<void, Tensor, Tensor>::func);
+ }, "The number of arguments is different. Specified 0 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), &kernel_func<void, Tensor, Tensor>::func),
- c10::Error
+ ), &kernel_func<void, Tensor, Tensor>::func);
+ }, "The number of arguments is different. Specified 1 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
(std::vector<Argument>{})
- ), &kernel_func<void, Tensor, Tensor>::func),
- c10::Error
+ ), &kernel_func<void, Tensor, Tensor>::func);
+ }, "The number of arguments is different. Specified 3 but inferred 2"
);
}
@@ -613,26 +613,26 @@
), &kernel_func<int64_t, Tensor, int64_t>::func);
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), &kernel_func<int64_t, Tensor, int64_t>::func),
- c10::Error
+ ), &kernel_func<int64_t, Tensor, int64_t>::func);
+ }, "Type mismatch in argument 2: specified float but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), &kernel_func<int64_t, Tensor, int64_t>::func),
- c10::Error
+ ), &kernel_func<int64_t, Tensor, int64_t>::func);
+ }, "Type mismatch in argument 1: specified int but inferred Tensor"
);
}
@@ -647,18 +647,18 @@
), &kernel_func<int64_t, Tensor>::func);
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), &kernel_func<int64_t, Tensor>::func),
- c10::Error
+ ), &kernel_func<int64_t, Tensor>::func);
+ }, "The number of returns is different. Specified 0 but inferred 1"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
@@ -666,8 +666,8 @@
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()),
Argument("ret2", IntType::get())})
- ), &kernel_func<int64_t, Tensor>::func),
- c10::Error
+ ), &kernel_func<int64_t, Tensor>::func);
+ }, "The number of returns is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
@@ -680,26 +680,26 @@
), &kernel_func<void, Tensor>::func);
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
- ), &kernel_func<void, Tensor>::func),
- c10::Error
+ ), &kernel_func<void, Tensor>::func);
+ }, "The number of returns is different. Specified 1 but inferred 0"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret"), Argument("ret2")})
- ), &kernel_func<void, Tensor>::func),
- c10::Error
+ ), &kernel_func<void, Tensor>::func);
+ }, "The number of returns is different. Specified 2 but inferred 0"
);
// assert this does not fail because it matches
@@ -712,37 +712,37 @@
), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func),
- c10::Error
+ ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
+ }, "The number of returns is different. Specified 0 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1")})
- ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func),
- c10::Error
+ ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
+ }, "The number of returns is different. Specified 1 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
- ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func),
- c10::Error
+ ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
+ }, "The number of returns is different. Specified 3 but inferred 2"
);
}
@@ -757,26 +757,26 @@
), &kernel_func<int64_t, Tensor>::func);
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
- ), &kernel_func<int64_t, Tensor>::func),
- c10::Error
+ ), &kernel_func<int64_t, Tensor>::func);
+ }, "Type mismatch in return 1: specified Tensor but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
- ), &kernel_func<int64_t, Tensor>::func),
- c10::Error
+ ), &kernel_func<int64_t, Tensor>::func);
+ }, "Type mismatch in return 1: specified float but inferred int"
);
// assert this does not fail because it matches
@@ -789,15 +789,15 @@
), &kernel_func<Tensor, Tensor>::func);
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
- ), &kernel_func<Tensor, Tensor>::func),
- c10::Error
+ ), &kernel_func<Tensor, Tensor>::func);
+ }, "Type mismatch in return 1: specified float but inferred Tensor"
);
// assert this does not fail because it matches
@@ -810,26 +810,26 @@
), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
- ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func),
- c10::Error
+ ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
+ }, "Type mismatch in return 2: specified float but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
- ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func),
- c10::Error
+ ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
+ }, "Type mismatch in return 1: specified int but inferred Tensor"
);
}
diff --git a/aten/src/ATen/core/op_registration/kernel_function_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_test.cpp
index 4d86d28..2c8a590 100644
--- a/aten/src/ATen/core/op_registration/kernel_function_test.cpp
+++ b/aten/src/ATen/core/op_registration/kernel_function_test.cpp
@@ -557,15 +557,15 @@
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
@@ -578,37 +578,37 @@
), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{}),
(std::vector<Argument>{})
- ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 0 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 1 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
(std::vector<Argument>{})
- ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 3 but inferred 2"
);
}
@@ -623,26 +623,26 @@
), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in argument 2: specified float but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in argument 1: specified int but inferred Tensor"
);
}
@@ -657,18 +657,18 @@
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 0 but inferred 1"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
@@ -676,8 +676,8 @@
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()),
Argument("ret2", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
@@ -690,26 +690,26 @@
), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
- ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 1 but inferred 0"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret"), Argument("ret2")})
- ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 2 but inferred 0"
);
// assert this does not fail because it matches
@@ -722,37 +722,37 @@
), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 0 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1")})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 1 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 3 but inferred 2"
);
}
@@ -767,26 +767,26 @@
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified Tensor but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified float but inferred int"
);
// assert this does not fail because it matches
@@ -799,15 +799,15 @@
), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified float but inferred Tensor"
);
// assert this does not fail because it matches
@@ -820,26 +820,26 @@
), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 2: specified float but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified int but inferred Tensor"
);
}
diff --git a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp
index 49e39bb..90f18ed 100644
--- a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp
+++ b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp
@@ -710,15 +710,15 @@
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
@@ -731,37 +731,37 @@
), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{}),
(std::vector<Argument>{})
- ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 0 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 1 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
(std::vector<Argument>{})
- ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 3 but inferred 2"
);
}
@@ -776,26 +776,26 @@
), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in argument 2: specified float but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in argument 1: specified int but inferred Tensor"
);
}
@@ -810,18 +810,18 @@
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 0 but inferred 1"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
@@ -829,8 +829,8 @@
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()),
Argument("ret2", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
@@ -843,26 +843,26 @@
), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
- ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 1 but inferred 0"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret"), Argument("ret2")})
- ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 2 but inferred 0"
);
// assert this does not fail because it matches
@@ -875,37 +875,37 @@
), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 0 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1")})
- ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 1 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
- ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 3 but inferred 2"
);
}
@@ -920,26 +920,26 @@
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified Tensor but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified float but inferred int"
);
// assert this does not fail because it matches
@@ -952,15 +952,15 @@
), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified float but inferred Tensor"
);
// assert this does not fail because it matches
@@ -973,26 +973,26 @@
), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
- ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 2: specified float but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
- ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified int but inferred Tensor"
);
}
diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp
index 8f45d34..884cd2e 100644
--- a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp
+++ b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp
@@ -497,15 +497,15 @@
), [] (Tensor) -> int64_t {return 0;});
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), [] (Tensor) -> int64_t {return 0;}),
- c10::Error
+ ), [] (Tensor) -> int64_t {return 0;});
+ }, "The number of arguments is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
@@ -518,37 +518,37 @@
), [] (Tensor, Tensor) -> void {});
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{}),
(std::vector<Argument>{})
- ), [] (Tensor, Tensor) -> void {}),
- c10::Error
+ ), [] (Tensor, Tensor) -> void {});
+ }, "The number of arguments is different. Specified 0 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), [] (Tensor, Tensor) -> void {}),
- c10::Error
+ ), [] (Tensor, Tensor) -> void {});
+ }, "The number of arguments is different. Specified 1 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
(std::vector<Argument>{})
- ), [] (Tensor, Tensor) -> void {}),
- c10::Error
+ ), [] (Tensor, Tensor) -> void {});
+ }, "The number of arguments is different. Specified 3 but inferred 2"
);
}
@@ -563,26 +563,26 @@
), [] (Tensor, int64_t) -> int64_t {return 0;});
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), [] (Tensor, int64_t) -> int64_t {return 0;}),
- c10::Error
+ ), [] (Tensor, int64_t) -> int64_t {return 0;});
+ }, "Type mismatch in argument 2: specified float but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), [] (Tensor, int64_t) -> int64_t {return 0;}),
- c10::Error
+ ), [] (Tensor, int64_t) -> int64_t {return 0;});
+ }, "Type mismatch in argument 1: specified int but inferred Tensor"
);
}
@@ -597,18 +597,18 @@
), [] (Tensor) -> int64_t {return 0;});
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), [] (Tensor) -> int64_t {return 0;}),
- c10::Error
+ ), [] (Tensor) -> int64_t {return 0;});
+ }, "The number of returns is different. Specified 0 but inferred 1"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
@@ -616,8 +616,8 @@
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()),
Argument("ret2", IntType::get())})
- ), [] (Tensor) -> int64_t {return 0;}),
- c10::Error
+ ), [] (Tensor) -> int64_t {return 0;});
+ }, "The number of returns is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
@@ -630,26 +630,26 @@
), [] (Tensor) -> void {});
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
- ), [] (Tensor) -> void {}),
- c10::Error
+ ), [] (Tensor) -> void {});
+ }, "The number of returns is different. Specified 1 but inferred 0"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret"), Argument("ret2")})
- ), [] (Tensor) -> void {}),
- c10::Error
+ ), [] (Tensor) -> void {});
+ }, "The number of returns is different. Specified 2 but inferred 0"
);
// assert this does not fail because it matches
@@ -662,37 +662,37 @@
), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}),
- c10::Error
+ ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
+ }, "The number of returns is different. Specified 0 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1")})
- ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}),
- c10::Error
+ ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
+ }, "The number of returns is different. Specified 1 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
- ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}),
- c10::Error
+ ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
+ }, "The number of returns is different. Specified 3 but inferred 2"
);
}
@@ -707,26 +707,26 @@
), [] (Tensor) -> int64_t {return 0;});
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
- ), [] (Tensor) -> int64_t {return 0;}),
- c10::Error
+ ), [] (Tensor) -> int64_t {return 0;});
+ }, "Type mismatch in return 1: specified Tensor but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
- ), [] (Tensor) -> int64_t {return 0;}),
- c10::Error
+ ), [] (Tensor) -> int64_t {return 0;});
+ }, "Type mismatch in return 1: specified float but inferred int"
);
// assert this does not fail because it matches
@@ -739,15 +739,15 @@
), [] (Tensor) -> Tensor {return {};});
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
- ), [] (Tensor) -> Tensor {return {};}),
- c10::Error
+ ), [] (Tensor) -> Tensor {return {};});
+ }, "Type mismatch in return 1: specified float but inferred Tensor"
);
// assert this does not fail because it matches
@@ -760,26 +760,26 @@
), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
- ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}),
- c10::Error
+ ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
+ }, "Type mismatch in return 2: specified float but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
- ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}),
- c10::Error
+ ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
+ }, "Type mismatch in return 1: specified int but inferred Tensor"
);
}
diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp
index d9defec..c156e48 100644
--- a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp
+++ b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp
@@ -514,15 +514,15 @@
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
@@ -535,37 +535,37 @@
), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{}),
(std::vector<Argument>{})
- ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 0 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 1 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
(std::vector<Argument>{})
- ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+ }, "The number of arguments is different. Specified 3 but inferred 2"
);
}
@@ -580,26 +580,26 @@
), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ }, "Type mismatch in argument 2: specified float but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ }, "Type mismatch in argument 1: specified int but inferred Tensor"
);
}
@@ -614,18 +614,18 @@
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 0 but inferred 1"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
@@ -633,8 +633,8 @@
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()),
Argument("ret2", IntType::get())})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
@@ -647,26 +647,26 @@
), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
- ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 1 but inferred 0"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret"), Argument("ret2")})
- ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 2 but inferred 0"
);
// assert this does not fail because it matches
@@ -679,37 +679,37 @@
), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
- ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 0 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1")})
- ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 1 but inferred 2"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
- ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+ }, "The number of returns is different. Specified 3 but inferred 2"
);
}
@@ -724,26 +724,26 @@
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified Tensor but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified float but inferred int"
);
// assert this does not fail because it matches
@@ -756,15 +756,15 @@
), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified float but inferred Tensor"
);
// assert this does not fail because it matches
@@ -777,26 +777,26 @@
), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
- ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 2: specified float but inferred int"
);
- EXPECT_THROW(
+ expectThrows<c10::Error>([] {
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
- ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1())),
- c10::Error
+ ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
+ }, "Type mismatch in return 1: specified int but inferred Tensor"
);
}
diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp
index 224bab4..f81d5ee 100644
--- a/aten/src/ATen/core/op_registration/op_registration_test.cpp
+++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -48,10 +48,9 @@
TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutKernel_thenFails) {
// make sure it crashes when kernel is absent
- EXPECT_THROW(
- c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1())),
- c10::Error
- );
+ expectThrows<c10::Error>([&] {
+ c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1()));
+ }, "but didn't specify a kernel");
// but make sure it doesn't crash when kernel is present
c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
@@ -62,10 +61,9 @@
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
- EXPECT_THROW(
- callOp(*op, dummyTensor(TensorType2())),
- c10::Error
- );
+ expectThrows<c10::Error>([&] {
+ callOp(*op, dummyTensor(TensorType2()));
+ }, "Didn't find kernel to dispatch to for operator '_test::dummy'");
}
TEST(OperatorRegistrationTest, givenOpWithFallbackKernelOutOfScope_whenCallingOpWithWrongDispatchKey_thenFails) {
@@ -77,10 +75,9 @@
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
- EXPECT_THROW(
- callOp(*op, dummyTensor(TensorType2())),
- c10::Error
- );
+ expectThrows<c10::Error>([&] {
+ callOp(*op, dummyTensor(TensorType2()));
+ }, "Didn't find kernel to dispatch to for operator '_test::dummy'");
}
TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernel_whenCallingOp_thenCallsFallbackKernel) {
diff --git a/aten/src/ATen/core/op_registration/test_helpers.h b/aten/src/ATen/core/op_registration/test_helpers.h
index 595912c..23b5b2b 100644
--- a/aten/src/ATen/core/op_registration/test_helpers.h
+++ b/aten/src/ATen/core/op_registration/test_helpers.h
@@ -1,6 +1,7 @@
#pragma once
#include <gtest/gtest.h>
+#include <gmock/gmock.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>
@@ -44,3 +45,15 @@
auto op = c10::Dispatcher::singleton().findSchema(op_name, "");
EXPECT_FALSE(op.has_value());
}
+
+template<class Exception, class Functor>
+inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
+ try {
+ std::forward<Functor>(functor)();
+ } catch (const Exception& e) {
+ EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains));
+ return;
+ }
+ ADD_FAILURE() << "Expected to throw exception containing \""
+ << expectMessageContains << "\" but didn't throw";
+}