Add registerOperator overloads that infer the schema (#10048)
Summary:
This PR adds a way to infer the JIT/script schema of a function from its signature, and then create an operator from that schema and an implementation. The implementation function is wrapped in another function that pops the arguments off the stack into a tuple, invokes the implementation with the unpacked tuple elements, and pushes the return value back onto the stack (unpacking it element-wise first if it is a tuple).
Currently the method is called `createOperator`. We may want to think of a nicer way of registering ops in tandem with `RegisterOperators`. It might be very cumbersome to add a template constructor to `Operator`, so maybe we can come up with a chaining method on `RegisterOperators` like `RegisterOperators(schema, func).op(schema, func).op(schema, func)` -- it has to work at startup time (for a static variable) though. We can solve this in another PR.
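For illustration, here is roughly what the two registration forms look like -- a minimal sketch mirroring the tests added in this PR (the operator names are placeholders):
```cpp
#include <torch/csrc/jit/custom_operator.h>

using namespace torch::jit;

RegisterOperators reg({
    // Name only: the schema is inferred from the lambda's signature and the
    // arguments receive placeholder names (_0, _1).
    createOperator(
        "foo::bar", [](double a, at::Tensor b) { return a + b; }),
    // Full schema: a schema is still inferred from the lambda and checked
    // against the provided string, which is kept because it has proper
    // argument names.
    createOperator(
        "foo::bar_with_schema(float a, Tensor b) -> Tensor",
        [](double a, at::Tensor b) { return a + b; })});
```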
zdevito apaszke smessmer dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10048
Differential Revision: D9125975
Pulled By: goldsborough
fbshipit-source-id: de9e59888757573284a43787ae5d94384bfe8f9a
diff --git a/caffe2/utils/Metaprogramming.h b/caffe2/utils/Metaprogramming.h
index f6f9318..a8c9450 100644
--- a/caffe2/utils/Metaprogramming.h
+++ b/caffe2/utils/Metaprogramming.h
@@ -7,7 +7,24 @@
#include "caffe2/utils/Array.h"
namespace c10 { namespace guts {
+namespace detail {
+/**
+ * strip_class: helper to remove the class type from pointers to `operator()`.
+ */
+template <typename T>
+struct strip_class {};
+template <typename Class, typename Result, typename... Args>
+struct strip_class<Result (Class::*)(Args...)> {
+ using type = Result(Args...);
+};
+template <typename Class, typename Result, typename... Args>
+struct strip_class<Result (Class::*)(Args...) const> {
+ using type = Result(Args...);
+};
+template <typename T>
+using strip_class_t = typename strip_class<T>::type;
+} // namespace detail
/**
* Access information about result type or arguments from a function type.
@@ -23,9 +40,27 @@
using func_type = Result (Args...);
using return_type = Result;
using parameter_types = typelist::typelist<Args...>;
+ static constexpr auto number_of_parameters = sizeof...(Args);
};
+/**
+ * infer_function_traits: creates a `function_traits` type for a simple
+ * function (pointer) or functor (lambda/struct). Currently does not support
+ * class methods.
+ */
+template <typename Functor>
+struct infer_function_traits {
+ using type = function_traits<detail::strip_class_t<decltype(&Functor::operator())>>;
+};
+
+template <typename Result, typename... Args>
+struct infer_function_traits<Result (*)(Args...)> {
+ using type = function_traits<Result(Args...)>;
+};
+
+template <typename T>
+using infer_function_traits_t = typename infer_function_traits<T>::type;
/**
* Use extract_arg_by_filtered_index to return the i-th argument whose
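(For orientation, a hedged sketch of how the new `infer_function_traits` is meant to be used; the asserts below are illustrative and not part of this diff.)
```cpp
#include <type_traits>
#include <caffe2/utils/Metaprogramming.h>

auto lambda = [](int64_t x, double y) { return x + y; };
using Traits = c10::guts::infer_function_traits_t<decltype(lambda)>;

// The traits expose the return type and parameter count of the lambda's
// operator(), with the class type stripped off.
static_assert(std::is_same<Traits::return_type, double>::value, "");
static_assert(Traits::number_of_parameters == 2, "");
```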
diff --git a/caffe2/utils/TypeList.h b/caffe2/utils/TypeList.h
index 7c20fa6..79764f9 100644
--- a/caffe2/utils/TypeList.h
+++ b/caffe2/utils/TypeList.h
@@ -177,6 +177,33 @@
};
template<class TypeList> using head_t = typename head<TypeList>::type;
+/**
+ * Returns the N-th element of a type list.
+ * Example:
+ * int == element_t<1, typelist<float, int, char>>
+ */
+
+/// Base template.
+template<size_t Index, class TypeList> struct element final {
+ static_assert(detail::false_t<TypeList>::value, "In typelist::element<T>, the T argument must be typelist<...>.");
+};
+
+/// Successful case, we have reached the zero index and can "return" the head type.
+template<class Head, class... Tail> struct element<0, typelist<Head, Tail...>> { using type = Head; };
+
+/// Error case: we have an index but ran out of types! This specialization is
+/// only selected when `Ts...` is actually empty.
+template <size_t Index, class... Ts>
+struct element<Index, typelist<Ts...>> {
+ static_assert(Index < sizeof...(Ts), "Index is out of bounds in typelist::element");
+};
+
+/// Shave off types until we hit the <0, Head, Tail...> or <Index> case.
+template<size_t Index, class Head, class... Tail> struct element<Index, typelist<Head, Tail...>> : element<Index-1, typelist<Tail...>> { };
+
+/// Convenience alias.
+template<size_t Index, class TypeList>
+using element_t = typename element<Index, TypeList>::type;
/**
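(A quick illustration of `element_t`, mirroring the doc comment above; not part of the diff itself.)
```cpp
#include <type_traits>
#include <caffe2/utils/TypeList.h>

using List = c10::guts::typelist::typelist<float, int, char>;
static_assert(
    std::is_same<c10::guts::typelist::element_t<1, List>, int>::value, "");
// element_t<3, List> would trip the "Index is out of bounds" static_assert.
```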
diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp
index 07ab317..eab076a 100644
--- a/torch/csrc/jit/constants.cpp
+++ b/torch/csrc/jit/constants.cpp
@@ -22,7 +22,7 @@
n->f_(attr::value, val.toDouble());
n->output()->setType(FloatType::get());
} else if(val.isIntList()) {
- n->is_(attr::value, val.toIntList()->elements().vec());
+ n->is_(attr::value, val.toIntList()->elements());
n->output()->setType(ListType::ofInts());
} else if(val.isTensorList()) {
n->ts_(attr::value, fmap(val.toTensorList()->elements(), [](const at::Tensor & t) {
diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h
new file mode 100644
index 0000000..63e9901
--- /dev/null
+++ b/torch/csrc/jit/custom_operator.h
@@ -0,0 +1,190 @@
+#pragma once
+
+#include <torch/csrc/jit/function_schema.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/stack.h>
+#include <torch/csrc/utils/variadic.h>
+
+#include <caffe2/utils/Metaprogramming.h>
+#include <caffe2/utils/TypeList.h>
+
+namespace torch { namespace jit {
+namespace detail {
+template <typename... Ts, size_t... Is>
+std::vector<Argument> createArgumentVectorFromTypes(Indices<Is...> indices) {
+ // Arguments are named "_<index>"
+ return {Argument("_" + std::to_string(Is), getTypePtr<decay_t<Ts>>())...};
+}
+
+template <typename... Ts, size_t... Is>
+std::vector<Argument> createReturns(Indices<Is...> indices) {
+ return createArgumentVectorFromTypes<Ts...>(indices);
+}
+
+/// Unpack a tuple return type into a vector of return types, one per tuple
+/// element.
+template <typename... Ts>
+std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
+ // Create an index pack so that one Argument is created per tuple element.
+ return createReturns<Ts...>(typename MakeIndices<sizeof...(Ts)>::indices{});
+}
+
+/// Create a single-element `vector` for simple (non-tuple) return types.
+template <typename ReturnType>
+std::vector<Argument> createReturns(ReturnType*) {
+ return {Argument("_1", getTypePtr<decay_t<ReturnType>>())};
+}
+
+/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
+/// into the argument list.
+template <typename FunctionTraits, size_t... Is>
+std::vector<Argument> createArgumentVectorFromTraits(Indices<Is...> indices) {
+ using ArgumentTypes = typename FunctionTraits::parameter_types;
+ return createArgumentVectorFromTypes<
+ c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
+}
+
+/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
+/// function.
+template <typename FunctionTraits>
+FunctionSchema createFunctionSchemaFromTraits(const std::string& name) {
+ using ReturnType = typename FunctionTraits::return_type;
+ auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
+ typename MakeIndices<FunctionTraits::number_of_parameters>::indices{});
+ auto returns = createReturns(static_cast<ReturnType*>(nullptr));
+ return {name, arguments, returns};
+}
+
+/// Does two things for an operator implementation and a tuple of arguments:
+/// 1. Pops all necessary arguments off the stack into the tuple's elements,
+/// 2. Unpacks the tuple and calls the operator implementation.
+/// The result of the implementation call is returned.
+template <
+ typename ReturnType,
+ typename Implementation,
+ typename... Types,
+ size_t... Is>
+ReturnType callOperatorWithTuple(
+ Implementation&& implementation,
+ Stack& stack,
+ std::tuple<Types...>& tuple,
+ Indices<Is...>) {
+ pop(stack, std::get<Is>(tuple)...);
+ return std::forward<Implementation>(implementation)(std::get<Is>(tuple)...);
+}
+
+inline void checkArgumentVector(
+ const char* what,
+ const std::vector<Argument>& inferred,
+ const std::vector<Argument>& provided,
+ const FunctionSchema& inferredSchema,
+ const FunctionSchema& providedSchema) {
+ AT_CHECK(
+ inferred.size() == provided.size(),
+ "Inferred ", inferred.size(), " ", what,
+ "(s) for operator implementation, but the provided schema specified ",
+ provided.size(), " ", what, "(s). Inferred schema: ",
+ inferredSchema, " | Provided schema: ", providedSchema);
+ for (size_t i = 0; i < provided.size(); ++i) {
+ AT_CHECK(
+ provided[i].type->isSubtypeOf(inferred[i].type),
+ "Inferred type for ", what, " #", i, " was ",
+ *inferred[i].type, ", but the provided schema specified type ",
+ *provided[i].type, " for the ", what,
+ " in that position. Inferred schema: ",
+ inferredSchema, " | Provided schema: ", providedSchema);
+ }
+}
+
+/// If `schemaOrName` contains a `(`, it is assumed to specify a full schema;
+/// otherwise it is assumed to specify only the operator name. When a full
+/// schema is provided, we still infer a schema from the implementation and
+/// verify that the user made no mistakes. Either way, this function returns
+/// the final schema.
+template <typename Traits>
+FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
+ // If there is no '(' in the schema, we assume this is only the name (e.g.
+ // "foo::bar").
+ const auto bracketIndex = schemaOrName.find('(');
+ if (bracketIndex == std::string::npos) {
+ // Infer the full schema and we're good.
+ return torch::jit::detail::createFunctionSchemaFromTraits<Traits>(
+ /*name=*/schemaOrName);
+ }
+
+ // If the user provided their own schema, we still infer one from the
+ // implementation and check that the two are consistent. We return the
+ // user-provided schema in the end because it has proper argument names.
+
+ auto providedSchema = parseSchema(schemaOrName);
+
+ const auto inferredSchema =
+ torch::jit::detail::createFunctionSchemaFromTraits<Traits>(
+ providedSchema.name);
+ checkArgumentVector(
+ "argument",
+ inferredSchema.arguments,
+ providedSchema.arguments,
+ inferredSchema,
+ providedSchema);
+ checkArgumentVector(
+ "return value",
+ inferredSchema.returns,
+ providedSchema.returns,
+ inferredSchema,
+ providedSchema);
+ return providedSchema;
+}
+} // namespace detail
+
+/// Registers a custom operator with a name or schema, and an implementation
+/// function.
+///
+/// If the first argument specifies only the function name like `foo::bar`, the
+/// schema, including the type of each argument and the return type, is inferred
+/// from the function signature. Otherwise, the string should specify the whole
+/// schema, like `foo::bar(Tensor a, double b) -> Tensor`. In that case, the
+/// schema will still be inferred from the function and checked against this
+/// provided schema.
+///
+/// If the schema is left to be inferred, the arguments will take on sequential
+/// placeholder names like `_0`, `_1`, `_2` and so on. If you want argument
+/// names to be preserved, you should provide the schema yourself.
+///
+/// The implementation function can be a function pointer or a functor
+/// (including a lambda object). The function (or `operator()`) can take any
+/// number of arguments with a type from the subset accepted by the PyTorch
+/// JIT/Script backend, and return a single type or a tuple of types.
+///
+/// Example invocation:
+/// ```
+/// createOperator(
+/// "foo::bar(float a, Tensor b)",
+/// [](float a, at::Tensor b) { return a + b; });
+/// ```
+template <typename Implementation>
+Operator createOperator(
+ const std::string& schemaOrName,
+ Implementation&& implementation) {
+ using Traits = c10::guts::infer_function_traits_t<Implementation>;
+ using ArgumentTypes =
+ c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
+ using ArgumentTuple =
+ typename c10::guts::typelist::to_tuple<ArgumentTypes>::type;
+ using ReturnType = decay_t<typename Traits::return_type>;
+
+ auto schema = torch::jit::detail::inferAndCheckSchema<Traits>(schemaOrName);
+
+ return Operator(schema, [implementation](Stack& stack) {
+ ArgumentTuple tuple;
+ auto result = torch::jit::detail::callOperatorWithTuple<ReturnType>(
+ std::move(implementation),
+ stack,
+ tuple,
+ typename MakeIndices<std::tuple_size<ArgumentTuple>::value>::indices{});
+ pack(stack, std::move(result));
+ return 0;
+ });
+}
+} // namespace jit
+} // namespace torch
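(To make the summary concrete: the `Operation` that `createOperator` produces for the docstring example above behaves roughly like this hand-written equivalent. A sketch only; the real code is the generic lambda at the end of the file.)
```cpp
int wrapped(Stack& stack) {
  // 1. Pop the arguments off the stack into a tuple of decayed argument types.
  std::tuple<double, at::Tensor> arguments;
  pop(stack, std::get<0>(arguments), std::get<1>(arguments));
  // 2. Call the implementation with the unpacked tuple elements.
  auto result = std::get<0>(arguments) + std::get<1>(arguments);
  // 3. Push the result back onto the stack (a tuple result would be unpacked
  //    element-wise by pack()).
  pack(stack, std::move(result));
  return 0;
}
```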
diff --git a/torch/csrc/jit/function_schema.h b/torch/csrc/jit/function_schema.h
index 390f89f..01e8b0b 100644
--- a/torch/csrc/jit/function_schema.h
+++ b/torch/csrc/jit/function_schema.h
@@ -8,7 +8,7 @@
// schema as used in the compiler for resolving function calls and reporting
// errors. These objects should be constructed from C10 schema once those
-// are availiable
+// are available.
struct Argument {
Argument(
std::string name = "",
diff --git a/torch/csrc/jit/ivalue.h b/torch/csrc/jit/ivalue.h
index 0ad0323..1959c72 100644
--- a/torch/csrc/jit/ivalue.h
+++ b/torch/csrc/jit/ivalue.h
@@ -258,7 +258,9 @@
return out;
}
- std::vector<int64_t> copyToIntList() const;
+ const std::vector<int64_t>& toIntListRef() const;
+ const std::vector<double>& toFloatListRef() const;
+ const std::vector<at::Tensor>& toTensorListRef() const;
// ConstantString
IValue(Shared<ConstantString> v);
@@ -426,7 +428,9 @@
DEFINE_TO(Shared<ConstantString>, toString)
DEFINE_TO(at::Scalar, toScalar)
DEFINE_TO(bool, toInt)
-DEFINE_TO(std::vector<int64_t>, copyToIntList)
+DEFINE_TO(std::vector<int64_t>, toIntListRef)
+DEFINE_TO(std::vector<double>, toFloatListRef)
+DEFINE_TO(std::vector<at::Tensor>, toTensorListRef)
#undef DEFINE_TO
@@ -443,10 +447,10 @@
return Shared<ConstantList<Elem>>(
new ConstantList<Elem>(std::move(elements_)), false);
}
- at::ArrayRef<Elem> elements() const {
+ const std::vector<Elem>& elements() const {
return elements_;
}
- operator at::ArrayRef<Elem>() const {
+ operator const std::vector<Elem>&() const {
return elements();
}
};
@@ -485,8 +489,16 @@
inline IValue::IValue(std::vector<at::Tensor> v)
: IValue(TensorList::create(std::move(v))) {}
-inline std::vector<int64_t> IValue::copyToIntList() const {
- return toIntList()->elements().vec();
+inline const std::vector<int64_t>& IValue::toIntListRef() const {
+ return toIntList()->elements();
+}
+
+inline const std::vector<double>& IValue::toFloatListRef() const {
+ return toDoubleList()->elements();
+}
+
+inline const std::vector<at::Tensor>& IValue::toTensorListRef() const {
+ return toTensorList()->elements();
}
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 0429863..d1c1456 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -363,7 +363,7 @@
at::optional<std::vector<int64_t>> getIntListAttribute(at::optional<int32_t> N, Value* input) {
auto list = constant_as<Shared<jit::IntList>>(input);
if(list)
- return list.value()->elements().vec();
+ return list.value()->elements();
// broadcast IntList[3] with value 4 -> {4, 4, 4}
if(!N)
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index c25636e..21a8311 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -275,13 +275,13 @@
return modules.get(name).module;
}
- const detail::OrderedDict<std::string, NamedModule>& get_modules() const {
+ const torch::detail::OrderedDict<std::string, NamedModule>& get_modules() const {
return modules;
}
- const detail::OrderedDict<std::string, NamedParameter>& get_parameters() const {
+ const torch::detail::OrderedDict<std::string, NamedParameter>& get_parameters() const {
return parameters;
}
- const detail::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods() const {
+ const torch::detail::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods() const {
return methods;
}
@@ -304,9 +304,9 @@
// it is only legal to _add_ new modules and parameters.
// removing them will allow member_inputs to point to invalid parameters
// no such restriction exists for methods
- detail::OrderedDict<std::string, NamedModule> modules;
- detail::OrderedDict<std::string, NamedParameter> parameters;
- detail::OrderedDict<std::string, std::unique_ptr<Method>> methods;
+ torch::detail::OrderedDict<std::string, NamedModule> modules;
+ torch::detail::OrderedDict<std::string, NamedParameter> parameters;
+ torch::detail::OrderedDict<std::string, std::unique_ptr<Method>> methods;
bool optimize;
};
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index d5d204f..dd523c8 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -3,6 +3,8 @@
#define CATCH_CONFIG_MAIN
#include "catch.hpp"
+using Catch::StartsWith;
+
#else
#define REQUIRE JIT_ASSERT
@@ -26,6 +28,8 @@
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
+#include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/variable_tensor_functions.h"
#include "torch/csrc/autograd/variable.h"
@@ -926,7 +930,7 @@
JIT_ASSERT(foo2.isDouble());
JIT_ASSERT(foo2.toDouble() == 4.0);
JIT_ASSERT(foo->use_count() == 2);
- JIT_ASSERT(baz.toIntList()->elements().equals({3,4,5}));
+ JIT_ASSERT(ArrayRef<int64_t>(baz.toIntList()->elements()).equals({3,4,5}));
auto move_it = std::move(baz).toIntList();
JIT_ASSERT(foo->use_count() == 2);
@@ -936,10 +940,11 @@
IValue dlist(DoubleList::create({3.5}));
JIT_ASSERT(
dlist.isDoubleList() &&
- std::move(dlist).toDoubleList()->elements().equals({3.5}));
+ ArrayRef<double>(std::move(dlist).toDoubleList()->elements())
+ .equals({3.5}));
JIT_ASSERT(dlist.isNone());
dlist = IValue(DoubleList::create({3.4}));
- JIT_ASSERT(dlist.toDoubleList()->elements().equals({3.4}));
+ JIT_ASSERT(ArrayRef<double>(dlist.toDoubleList()->elements()).equals({3.4}));
IValue the_list(Tuple::create({IValue(3.4), IValue(4), IValue(foo)}));
JIT_ASSERT(foo->use_count() == 3);
JIT_ASSERT(the_list.isTuple());
@@ -960,6 +965,132 @@
proto.set_producer_name("foo");
}
+void testCustomOperators() {
+ {
+ RegisterOperators reg({createOperator(
+ "foo::bar", [](double a, at::Tensor b) { return a + b; })});
+ auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
+ REQUIRE(ops.size() == 1);
+
+ auto& op = ops.front();
+ REQUIRE(op->schema().name == "foo::bar");
+
+ REQUIRE(op->schema().arguments.size() == 2);
+ REQUIRE(op->schema().arguments[0].name == "_0");
+ REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType);
+ REQUIRE(op->schema().arguments[1].name == "_1");
+ REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType);
+
+ REQUIRE(op->schema().returns.size() == 1);
+ REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType);
+
+ Stack stack;
+ push(stack, 2.0f, at::ones(5));
+ op->getOperation()(stack);
+ at::Tensor output;
+ pop(stack, output);
+
+ REQUIRE(output.allclose(at::full(5, 3.0f)));
+ }
+ {
+ RegisterOperators reg({createOperator(
+ "foo::bar_with_schema(float a, Tensor b) -> Tensor",
+ [](double a, at::Tensor b) { return a + b; })});
+
+ auto& ops =
+ getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
+ REQUIRE(ops.size() == 1);
+
+ auto& op = ops.front();
+ REQUIRE(op->schema().name == "foo::bar_with_schema");
+
+ REQUIRE(op->schema().arguments.size() == 2);
+ REQUIRE(op->schema().arguments[0].name == "a");
+ REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType);
+ REQUIRE(op->schema().arguments[1].name == "b");
+ REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType);
+
+ REQUIRE(op->schema().returns.size() == 1);
+ REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType);
+
+ Stack stack;
+ push(stack, 2.0f, at::ones(5));
+ op->getOperation()(stack);
+ at::Tensor output;
+ pop(stack, output);
+
+ REQUIRE(output.allclose(at::full(5, 3.0f)));
+ }
+ {
+ // Check that lists work well.
+ RegisterOperators reg({createOperator(
+ "foo::lists(int[] ints, float[] floats, Tensor[] tensors) -> float[]",
+ [](const std::vector<int64_t>& ints,
+ const std::vector<double>& floats,
+ std::vector<at::Tensor> tensors) { return floats; })});
+
+ auto& ops =
+ getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
+ REQUIRE(ops.size() == 1);
+
+ auto& op = ops.front();
+ REQUIRE(op->schema().name == "foo::lists");
+
+ REQUIRE(op->schema().arguments.size() == 3);
+ REQUIRE(op->schema().arguments[0].name == "ints");
+ REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofInts()));
+ REQUIRE(op->schema().arguments[1].name == "floats");
+ REQUIRE(op->schema().arguments[1].type->isSubtypeOf(ListType::ofFloats()));
+ REQUIRE(op->schema().arguments[2].name == "tensors");
+ REQUIRE(op->schema().arguments[2].type->isSubtypeOf(ListType::ofTensors()));
+
+ REQUIRE(op->schema().returns.size() == 1);
+ REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofFloats()));
+
+ Stack stack;
+ push(stack, std::vector<int64_t>{1, 2});
+ push(stack, std::vector<double>{1.0, 2.0});
+ push(stack, std::vector<at::Tensor>{at::ones(5)});
+ op->getOperation()(stack);
+ std::vector<double> output;
+ pop(stack, output);
+
+ REQUIRE(output.size() == 2);
+ REQUIRE(output[0] == 1.0);
+ REQUIRE(output[1] == 2.0);
+ }
+ {
+#ifdef USE_CATCH
+ REQUIRE_THROWS_WITH(
+ createOperator(
+ "foo::bar_with_bad_schema(Tensor a) -> Tensor",
+ [](double a, at::Tensor b) { return a + b; }),
+ StartsWith("Inferred 2 argument(s) for operator implementation, "
+ "but the provided schema specified 1 argument(s)."));
+ REQUIRE_THROWS_WITH(
+ createOperator(
+ "foo::bar_with_bad_schema(Tensor a) -> Tensor",
+ [](double a) { return a; }),
+ StartsWith("Inferred type for argument #0 was float, "
+ "but the provided schema specified type Dynamic "
+ "for the argument in that position"));
+ REQUIRE_THROWS_WITH(
+ createOperator(
+ "foo::bar_with_bad_schema(float a) -> (float, float)",
+ [](double a) { return a; }),
+ StartsWith("Inferred 1 return value(s) for operator implementation, "
+ "but the provided schema specified 2 return value(s)."));
+ REQUIRE_THROWS_WITH(
+ createOperator(
+ "foo::bar_with_bad_schema(float a) -> Tensor",
+ [](double a) { return a; }),
+ StartsWith("Inferred type for return value #0 was float, "
+ "but the provided schema specified type Dynamic "
+ "for the return value in that position"));
+#endif // USE_CATCH
+ }
+}
+
TORCH_API std::string runJITCPPTests() {
std::stringstream out;
testIValue();
@@ -980,6 +1111,7 @@
argumentSpecTest();
shapeAnalysisTest();
testProto();
+ testCustomOperators();
return out.str();
}
@@ -1006,6 +1138,8 @@
attributesTest();
SECTION( "interned strings" )
internedStringsTests();
+ SECTION( "custom operators" )
+ testCustomOperators();
}
TEST_CASE( "jit test CUDA", "[cuda]" ) {
diff --git a/torch/csrc/jit/type.cpp b/torch/csrc/jit/type.cpp
index 7f246ce..5248f42 100644
--- a/torch/csrc/jit/type.cpp
+++ b/torch/csrc/jit/type.cpp
@@ -80,5 +80,9 @@
static auto value = ListType::create(IntType::get());
return value;
}
+ListTypePtr ListType::ofFloats() {
+ static auto value = ListType::create(FloatType::get());
+ return value;
+}
}} // namespace torch::jit
diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h
index 71b8b95..36c8081 100644
--- a/torch/csrc/jit/type.h
+++ b/torch/csrc/jit/type.h
@@ -240,6 +240,7 @@
// common cast List[Tensor]
static ListTypePtr ofTensors();
static ListTypePtr ofInts();
+ static ListTypePtr ofFloats();
private:
ListType(TypePtr elem)
: Type(TypeKind::ListType), elem(elem) {}
@@ -457,4 +458,25 @@
AT_ERROR("unknown number type", typ->str());
}
+template <typename T>
+TypePtr getTypePtr() {
+#define TYPE_STR(Type) #Type, " ",
+ AT_ERROR(
+ "Type ",
+ at::demangle_type<T>(),
+ " could not be converted to any of the known types { ",
+ TH_FORALL_TYPES(TYPE_STR) "}");
+#undef TYPE_STR
+ return nullptr;
+}
+
+template<> inline TypePtr getTypePtr<at::Tensor>() { return DynamicType::get(); }
+template<> inline TypePtr getTypePtr<double>() { return FloatType::get(); }
+template<> inline TypePtr getTypePtr<int64_t>() { return IntType::get(); }
+template<> inline TypePtr getTypePtr<bool>() { return IntType::get(); }
+template<> inline TypePtr getTypePtr<at::Scalar>() { return NumberType::get(); }
+template<> inline TypePtr getTypePtr<std::vector<at::Tensor>>() { return ListType::ofTensors(); }
+template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::ofFloats(); }
+template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); }
+
}} // namespace torch::jit
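(For reference, how these `getTypePtr` specializations map C++ types to JIT types; illustrative code assuming `namespace torch::jit`, not part of the diff.)
```cpp
TypePtr t1 = getTypePtr<double>();              // FloatType::get()
TypePtr t2 = getTypePtr<std::vector<double>>(); // ListType::ofFloats()
// getTypePtr<float>() hits the unspecialized template and raises AT_ERROR at
// runtime -- which is why operator implementations take double, not float.
```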
diff --git a/torch/csrc/utils/variadic.h b/torch/csrc/utils/variadic.h
index c5d8984..0468a75 100644
--- a/torch/csrc/utils/variadic.h
+++ b/torch/csrc/utils/variadic.h
@@ -172,7 +172,6 @@
template <typename Function, typename... Ts>
void apply(Function function, Ts&&... ts) {
- //
// https://stackoverflow.com/questions/13978916/inserting-a-variadic-argument-list-into-a-vector
// Creates a dummy array, so that each function call is evaluated in order.
// `(function(), 0)` is because `function` should (!) return `void`, so