Build mechanism for custom operators (#10226)
Summary:
This is the last step in the custom operator implementation: providing a way to build from C++ and Python. For this I:
1. Created a `FindTorch.cmake` taken largely from ebetica with a CMake function to easily create simple custom op libraries
2. Created a ` torch/op.h` header for easy inclusion of necessary headers,
3. Created a test directory `pytorch/test/custom_operator` which includes the basic setup for a custom op.
1. It defines an op in `op.{h,cpp}`
2. Registers it with the JIT using `RegisterOperators`
3. Builds it into a shared library via a `CMakeLists.txt`
4. Binds it into Python using a `setup.py`. This step makes use of our C++ extension setup that we already have. No work, yey!
The pure C++ and the Python builds are separate and not coupled in any way.
zdevito soumith dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10226
Differential Revision: D9296839
Pulled By: goldsborough
fbshipit-source-id: 32f74cafb6e3d86cada8dfca8136d0dfb1f197a0
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 577f158..23a5080 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -230,7 +230,7 @@
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Qunused-arguments")
endif()
if ((APPLE AND (NOT ("${CLANG_VERSION_STRING}" VERSION_LESS "9.0")))
- OR (CMAKE_COMPILER_IS_GNUCXX
+ OR (CMAKE_COMPILER_IS_GNUCXX
AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0 AND NOT APPLE)))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
endif()
diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in
new file mode 100644
index 0000000..a14b2e1
--- /dev/null
+++ b/cmake/TorchConfig.cmake.in
@@ -0,0 +1,55 @@
+# FindTorch
+# -------
+#
+# Finds the Torch library
+#
+# This will define the following variables:
+#
+# TORCH_FOUND -- True if the system has the Torch library
+# TORCH_INCLUDE_DIRS -- The include directories for torch
+# TORCH_LIBRARIES -- Libraries to link to
+#
+# and the following imported targets:
+#
+# Torch
+#
+# and the following functions:
+#
+# torch_add_custom_op_library(<name> <source_files>)
+
+SET(TORCH_ROOT "${CMAKE_CURRENT_LIST_DIR}/../")
+
+set(TORCH_INCLUDE_DIRS
+ "${TORCH_ROOT}"
+ "${TORCH_ROOT}/aten/src"
+ "${CMAKE_CURRENT_LIST_DIR}/aten/src"
+ "${CMAKE_CURRENT_LIST_DIR}/caffe2/aten/src"
+ "${CMAKE_CURRENT_LIST_DIR}/caffe2/aten/src/TH"
+)
+
+find_library(TORCH_LIBRARY torch PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
+find_library(CAFFE2_LIBRARY caffe2 PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
+
+if (@USE_CUDA@)
+ find_package(CUDA REQUIRED)
+ find_library(CAFFE2_CUDA_LIBRARY caffe2_gpu PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
+ set(TORCH_CUDA_LIBRARIES -L${CUDA_TOOLKIT_ROOT_DIR}/lib64 cuda nvrtc cudart nvToolsExt)
+ list(APPEND TORCH_INCLUDE_DIRS ${CUDA_TOOLKIT_INCLUDE})
+endif()
+
+set(TORCH_LIBRARIES
+ ${TORCH_LIBRARY}
+ ${CAFFE2_LIBRARY}
+ ${CAFFE2_CUDA_LIBRARY}
+ ${TORCH_CUDA_LIBRARIES})
+
+# Creates a shared library <name> with the correct include directories
+# and linker flags set to include Torch header files and link with Torch
+# libraries. Also sets the C++ standard version to C++11. All options
+# can be override by specifying further options on the `<name>` CMake target.
+function(torch_add_custom_op_library name source_files)
+ add_library(${name} SHARED ${source_files})
+ target_include_directories(${name} PUBLIC "${TORCH_INCLUDE_DIRS}")
+ target_link_libraries(${name} "${TORCH_LIBRARIES}")
+ target_compile_options(${name} PUBLIC -std=c++11)
+endfunction(torch_add_custom_op_library)
diff --git a/cmake/TorchConfigVersion.cmake.in b/cmake/TorchConfigVersion.cmake.in
new file mode 100644
index 0000000..d9966d6
--- /dev/null
+++ b/cmake/TorchConfigVersion.cmake.in
@@ -0,0 +1,11 @@
+set(PACKAGE_VERSION "@TORCH_VERSION@")
+
+# Check whether the requested PACKAGE_FIND_VERSION is compatible
+if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}")
+ set(PACKAGE_VERSION_COMPATIBLE FALSE)
+else()
+ set(PACKAGE_VERSION_COMPATIBLE TRUE)
+ if ("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}")
+ set(PACKAGE_VERSION_EXACT TRUE)
+ endif()
+endif()
diff --git a/setup.py b/setup.py
index 1a9b0ee..3abd279 100644
--- a/setup.py
+++ b/setup.py
@@ -415,6 +415,7 @@
self.copy_tree('third_party/pybind11/include/pybind11/',
'torch/lib/include/pybind11')
self.copy_file('torch/csrc/torch.h', 'torch/lib/include/torch/torch.h')
+ self.copy_file('torch/op.h', 'torch/lib/include/torch/op.h')
build_dep_cmds = {}
diff --git a/test/custom_operator/CMakeLists.txt b/test/custom_operator/CMakeLists.txt
new file mode 100644
index 0000000..15338cc
--- /dev/null
+++ b/test/custom_operator/CMakeLists.txt
@@ -0,0 +1,10 @@
+# Basic CMake setup
+cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
+project(custom_op)
+
+find_package(Torch REQUIRED)
+
+torch_add_custom_op_library(custom_op op.cpp)
+
+add_executable(custom_op_test test.cpp)
+target_link_libraries(custom_op_test custom_op)
diff --git a/test/custom_operator/op.cpp b/test/custom_operator/op.cpp
new file mode 100644
index 0000000..ec24967
--- /dev/null
+++ b/test/custom_operator/op.cpp
@@ -0,0 +1,18 @@
+#include <torch/op.h>
+
+#include <cstddef>
+#include <vector>
+
+std::vector<at::Tensor> custom_op(
+ at::Tensor tensor,
+ double scalar,
+ int64_t repeat) {
+ std::vector<at::Tensor> output;
+ output.reserve(repeat);
+ for (int64_t i = 0; i < repeat; ++i) {
+ output.push_back(tensor * scalar);
+ }
+ return output;
+}
+
+static torch::RegisterOperators registry("custom::op", &custom_op);
diff --git a/test/custom_operator/op.h b/test/custom_operator/op.h
new file mode 100644
index 0000000..d45123d
--- /dev/null
+++ b/test/custom_operator/op.h
@@ -0,0 +1,9 @@
+#include <torch/op.h>
+
+#include <cstddef>
+#include <vector>
+
+std::vector<at::Tensor> custom_op(
+ at::Tensor tensor,
+ double scalar,
+ int64_t repeat);
diff --git a/test/custom_operator/test.cpp b/test/custom_operator/test.cpp
new file mode 100644
index 0000000..57ad66d
--- /dev/null
+++ b/test/custom_operator/test.cpp
@@ -0,0 +1,25 @@
+#include "op.h"
+
+#include <cassert>
+#include <vector>
+
+int main() {
+ auto& ops = torch::jit::getAllOperatorsFor(
+ torch::jit::Symbol::fromQualString("custom::op"));
+ assert(ops.size() == 1);
+
+ auto& op = ops.front();
+ assert(op->schema().name == "custom::op");
+
+ torch::jit::Stack stack;
+ torch::jit::push(stack, torch::ones(5), 2.0, 3);
+ op->getOperation()(stack);
+ std::vector<at::Tensor> output;
+ torch::jit::pop(stack, output);
+
+ assert(output.size() == 3);
+ for (const auto& tensor : output) {
+ assert(tensor.allclose(torch::ones(5) * 2));
+ }
+ std::cout << "success" << std::endl;
+}
diff --git a/test/custom_operator/test.py b/test/custom_operator/test.py
new file mode 100644
index 0000000..2a04231
--- /dev/null
+++ b/test/custom_operator/test.py
@@ -0,0 +1,12 @@
+import os
+import torch
+
+library_path = os.path.abspath('build/libcustom_op.so')
+torch.ops.load_library(library_path)
+assert library_path in torch.ops.loaded_libraries
+
+output = torch.ops.custom.op(torch.ones(5), 2.0, 3)
+assert type(output) == list
+assert len(output) == 3
+assert all(tensor.allclose(torch.ones(5) * 2) for tensor in output)
+print('success')
diff --git a/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect b/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect
new file mode 100644
index 0000000..04af766
--- /dev/null
+++ b/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect
@@ -0,0 +1,4 @@
+graph(%x : Dynamic) {
+ %1 : Dynamic = ^aten::relu()(%x)
+ return (%1);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index cf7b0f7..4f1c3f3 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -6266,7 +6266,7 @@
pass
-class TestCustomOperators(TestCase):
+class TestCustomOperators(JitTestCase):
def test_dynamic_op_registry(self):
from torch._ops import _OpNamespace
@@ -6337,19 +6337,30 @@
"Unknown keyword argument 'foo' for operator 'aten::leaky_relu'"
):
torch.ops.aten.leaky_relu(torch.ones(5), foo=torch.ones(5))
- #
- # def test_passing_and_returning_lists(self):
- # a, b = torch.ones(5), torch.zeros(5)
- # output = torch.ops.aten.stack([a, b])
- # self.assertEqual(output, torch.ones(10))
- #
- # def test_throws_for_tuples(self):
- # with self.assertRaisesRegex(
- # RuntimeError,
- # "Unknown keyword argument 'foo' for operator 'aten::leaky_relu'"
- # ):
- # torch.ops.aten.leaky_relu(torch.ones(5), foo=torch.ones(5))
+ def test_passing_and_returning_lists(self):
+ # Replace with actual test once we support lists.
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "Lists and tuples are not supported yet"
+ ):
+ a, b = torch.ones(5), torch.zeros(5)
+ output = torch.ops.aten.stack([a, b])
+ self.assertEqual(output, torch.ones(10))
+
+ def test_passing_and_returning_tuples(self):
+ # Replace with actual test once we support tuples.
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "Lists and tuples are not supported yet"
+ ):
+ torch.ops.aten.max_pool2d(torch.ones(5, 5), [2, 2])
+
+ def test_script_graph_contains_custom_op(self):
+ @torch.jit.script
+ def func(x):
+ return torch.ops.aten.relu(x)
+ self.assertExpected(canonical(func.graph))
# UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
# and we have to disable the failing tests here instead.
diff --git a/tools/jit/templates/register_aten_ops.cpp b/tools/jit/templates/register_aten_ops.cpp
index 3dc9734..4ac499b 100644
--- a/tools/jit/templates/register_aten_ops.cpp
+++ b/tools/jit/templates/register_aten_ops.cpp
@@ -1,4 +1,5 @@
#include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/jit/interned_strings.h"
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index deb9c45..1bd6453 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -11,7 +11,14 @@
option(BUILD_TORCH_TEST "Build torch test binaries" ON)
+# TODO: Unify with version from setup.py
+set(TORCH_VERSION_MAJOR 0)
+set(TORCH_VERSION_MINOR 4)
+set(TORCH_VERSION_PATCH 1)
+set(TORCH_VERSION "${TORCH_VERSION_MAJOR}.${TORCH_VERSION_MINOR}.${TORCH_VERSION_PATCH}")
+
set(TORCH_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
+set(TORCH_ROOT "${TORCH_SRC_DIR}/..")
add_subdirectory(../third_party/nanopb protobuf-nanopb)
@@ -55,9 +62,9 @@
endif()
# Generate files
-set(TOOLS_PATH "${TORCH_SRC_DIR}/../tools")
+set(TOOLS_PATH "${TORCH_ROOT}/tools")
-configure_file("${TORCH_SRC_DIR}/../aten/src/ATen/common_with_cwrap.py"
+configure_file("${TORCH_ROOT}/aten/src/ATen/common_with_cwrap.py"
"${TOOLS_PATH}/shared/cwrap_common.py"
COPYONLY)
@@ -113,7 +120,7 @@
"${TOOLS_PATH}/jit/gen_jit_dispatch.py"
"${TOOLS_PATH}/jit/templates/register_aten_ops.cpp"
"${TOOLS_PATH}/jit/templates/aten_interned_strings.h"
- WORKING_DIRECTORY "${TORCH_SRC_DIR}/..")
+ WORKING_DIRECTORY "${TORCH_ROOT}")
set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/autograd/anomaly_mode.cpp
@@ -211,7 +218,7 @@
else()
set (MSVC_RUNTIME_LIBRARY_FLAG "/MD")
endif()
-
+
target_compile_options(torch PRIVATE
${MSVC_RUNTIME_LIBRARY_OPTION}
/Z7
@@ -339,9 +346,9 @@
set(TH_CPU_INCLUDE
# dense
- ${TORCH_SRC_DIR}/../aten/src/TH
+ ${TORCH_ROOT}/aten/src/TH
${CMAKE_CURRENT_BINARY_DIR}/../aten/src/TH
- ${TORCH_SRC_DIR}/../aten/src
+ ${TORCH_ROOT}/aten/src
${CMAKE_CURRENT_BINARY_DIR}/../aten/src
${CMAKE_BINARY_DIR}/aten/src)
target_include_directories(torch PRIVATE ${TH_CPU_INCLUDE})
@@ -349,13 +356,13 @@
if(USE_CUDA OR USE_ROCM)
set(TH_CUDA_INCLUDE
# dense
- ${TORCH_SRC_DIR}/../aten/src/THC
+ ${TORCH_ROOT}/aten/src/THC
${CMAKE_CURRENT_BINARY_DIR}/../aten/src/THC)
target_include_directories(torch PRIVATE ${TH_CUDA_INCLUDE})
endif()
set(ATen_CPU_INCLUDE
- ${TORCH_SRC_DIR}/../aten/src
+ ${TORCH_ROOT}/aten/src
${CMAKE_CURRENT_BINARY_DIR}/../aten/src
${CMAKE_CURRENT_BINARY_DIR}/../aten/src/ATen
${CMAKE_BINARY_DIR}/aten/src)
@@ -366,8 +373,8 @@
# SYSTEM headers are included with -isystem and thus do not trigger warnings.
target_include_directories(torch SYSTEM PUBLIC
- "${TORCH_SRC_DIR}/../third_party/cereal/include" # For cereal/
- "${TORCH_SRC_DIR}/../third_party/nanopb")
+ "${TORCH_ROOT}/third_party/cereal/include" # For cereal/
+ "${TORCH_ROOT}/third_party/nanopb")
set_target_properties(torch PROPERTIES VERSION 1 SOVERSION 1)
@@ -390,7 +397,7 @@
target_link_libraries(test_jit torch ${TORCH_CUDA_LIBRARIES})
target_compile_definitions(test_jit PUBLIC USE_CATCH _FORCE_INLINES)
target_include_directories(test_jit PUBLIC
- "${TORCH_SRC_DIR}/../third_party/catch/single_include"
+ "${TORCH_ROOT}/third_party/catch/single_include"
${ATen_CPU_INCLUDE})
if (USE_CUDA)
@@ -399,7 +406,7 @@
endif()
if (BUILD_TORCH_TEST AND NOT NO_API AND NOT USE_ROCM)
- set(TORCH_API_TEST_DIR "${TORCH_SRC_DIR}/../test/cpp/api")
+ set(TORCH_API_TEST_DIR "${TORCH_ROOT}/test/cpp/api")
add_executable(test_api
${TORCH_API_TEST_DIR}/any.cpp
@@ -424,7 +431,7 @@
target_include_directories(test_api
PUBLIC
- "${TORCH_SRC_DIR}/../third_party/catch/single_include"
+ "${TORCH_ROOT}/third_party/catch/single_include"
${ATen_CPU_INCLUDE})
target_link_libraries(test_api torch ${TORCH_CUDA_LIBRARIES})
@@ -445,3 +452,13 @@
endif()
endif()
endif()
+
+# CMake config for external projects.
+configure_file(
+ ${PROJECT_SOURCE_DIR}/cmake/TorchConfigVersion.cmake.in
+ ${PROJECT_BINARY_DIR}/TorchConfigVersion.cmake
+ @ONLY)
+configure_file(
+ ${TORCH_ROOT}/cmake/TorchConfig.cmake.in
+ ${PROJECT_BINARY_DIR}/TorchConfig.cmake
+ @ONLY)
diff --git a/torch/_ops.py b/torch/_ops.py
index e5900fb..a16e34a 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -1,5 +1,27 @@
import torch._C
+import contextlib
+import ctypes
+import sys
+
+
+# Query `hasattr` only once.
+_SET_GLOBAL_FLAGS = hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')
+
+
+@contextlib.contextmanager
+def dl_open_guard():
+ """
+ Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
+ shared library to load custom operators.
+ """
+ if _SET_GLOBAL_FLAGS:
+ old_flags = sys.getdlopenflags()
+ sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
+ yield
+ if _SET_GLOBAL_FLAGS:
+ sys.setdlopenflags(old_flags)
+
class _OpNamespace(object):
"""
@@ -33,12 +55,40 @@
class _Ops(object):
+ def __init__(self):
+ self.loaded_libraries = set()
+
def __getattr__(self, name):
# Here we are creating `torch.ops.my_namespace`
namespace = _OpNamespace(name)
setattr(self, name, namespace)
return namespace
+ def load_library(self, path):
+ """
+ Loads a shared library from the given path into the current process.
+
+ The library being loaded may run global initialization code to register
+ custom operators with the PyTorch JIT runtime. This allows dynamically
+ loading custom operators. For this, you should compile your operator
+ and the static registration code into a shared library object, and then
+ call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
+ shared object.
+
+ After the library is loaded, it is added to the
+ ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
+ for the paths of all libraries loaded using this function.
+
+ Arguments:
+ path (str): A path to a shared library to load.
+ """
+ with dl_open_guard():
+ # Import the shared library into the process, thus running its
+ # static (global) initialization code in order to register custom
+ # operators with the JIT.
+ ctypes.CDLL(path)
+ self.loaded_libraries.add(path)
+
# The ops "namespace"
ops = _Ops()
diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp
index eab076a..cc1e390 100644
--- a/torch/csrc/jit/constants.cpp
+++ b/torch/csrc/jit/constants.cpp
@@ -1,5 +1,6 @@
#include "torch/csrc/jit/constants.h"
#include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/autograd/variable.h"
namespace torch { namespace jit {
diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h
index f0cc29d..b501c25 100644
--- a/torch/csrc/jit/custom_operator.h
+++ b/torch/csrc/jit/custom_operator.h
@@ -1,7 +1,6 @@
#pragma once
#include <torch/csrc/jit/function_schema.h>
-#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/stack.h>
#include <torch/csrc/jit/tracer.h>
@@ -79,13 +78,10 @@
/// Does two things for an operator implementation and a tuple of arguments:
/// 1. Pops all necessary arguments off the stack into the tuple's elements,
/// 2. Unpacks the tuple and calls the operator implementation.
-/// The result of the implementation call is returned.
-template <
- typename ReturnType,
- typename Implementation,
- typename... Types,
- size_t... Is>
-ReturnType callOperatorWithTuple(
+/// If tracing is currently enabled, this function will also take care of
+/// tracing the operator call.
+template <typename Implementation, typename... Types, size_t... Is>
+void callOperatorWithTuple(
const FunctionSchema& schema,
Implementation&& implementation,
Stack& stack,
@@ -104,10 +100,10 @@
jit::tracer::postRecordTrace(node, result);
}
- return result;
+ push(stack, IValue(std::move(result)));
}
-void checkArgumentVector(
+inline void checkArgumentVector(
const char* what,
const std::vector<Argument>& inferred,
const std::vector<Argument>& provided,
@@ -204,21 +200,54 @@
c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
using ArgumentTuple =
typename c10::guts::typelist::to_tuple<ArgumentTypes>::type;
- using ReturnType = decay_t<typename Traits::return_type>;
auto schema = torch::jit::detail::inferAndCheckSchema<Traits>(schemaOrName);
return Operator(schema, [implementation, schema](Stack& stack) {
ArgumentTuple tuple;
- auto result = torch::jit::detail::callOperatorWithTuple<ReturnType>(
+ torch::jit::detail::callOperatorWithTuple(
schema,
std::move(implementation),
stack,
tuple,
typename MakeIndices<std::tuple_size<ArgumentTuple>::value>::indices{});
- pack(stack, std::move(result));
return 0;
});
}
+
+/// Registration class for new operators. Effectively calls
+/// `torch::jit::registerOperator` for every supplied operator, but allows doing
+/// so in the global scope when a `RegisterOperators` object is assigned to a
+/// static variable. Also handles registration of user-defined, "custom"
+/// operators.
+struct TORCH_API RegisterOperators {
+ RegisterOperators() = default;
+
+ /// Registers a vector of already created `Operator`s.
+ RegisterOperators(std::vector<Operator> operators) {
+ for (Operator& o : operators) {
+ registerOperator(std::move(o));
+ }
+ }
+
+ /// Calls `op(...)` with the given operator name and implementation.
+ template <typename Implementation>
+ RegisterOperators(const std::string& name, Implementation&& implementation) {
+ op(name, std::forward<Implementation>(implementation));
+ }
+
+ /// Creates a new operator from a name and implementation function (function
+ /// pointer or function object/lambda) using `torch::jit::createOperator`, and
+ /// then registers the operator.
+ template <typename Implementation>
+ RegisterOperators& op(
+ const std::string& name,
+ Implementation&& implementation) {
+ registerOperator(
+ createOperator(name, std::forward<Implementation>(implementation)));
+ return *this;
+ }
+};
+
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 933a3e9..4873fd7 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -33,6 +33,13 @@
#include <pybind11/functional.h>
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <tuple>
+#include <utility>
+
namespace torch { namespace jit {
namespace {
@@ -232,10 +239,14 @@
"Found ", operations.size(), " overloads for operator ",
qualified_name, "! Overloads are not supported from Python.");
std::shared_ptr<Operator> op = operations[0];
+ AT_ASSERT(op != nullptr);
+ std::ostringstream docstring;
+ docstring << "Automatically bound operator '" << qualified_name
+ << "' with schema: " << op->schema();
return py::cpp_function([op](py::args args, py::kwargs kwargs) {
return invokeOperatorFromPython(
*op, std::move(args), std::move(kwargs));
- });
+ }, py::name(qualified_name.c_str()), py::doc(docstring.str().c_str()));
} catch (const at::Error& error) {
throw std::runtime_error(error.what_without_backtrace());
}
diff --git a/torch/csrc/jit/operator.h b/torch/csrc/jit/operator.h
index 7d5e97f..afa4f9c 100644
--- a/torch/csrc/jit/operator.h
+++ b/torch/csrc/jit/operator.h
@@ -94,14 +94,6 @@
// XXX: this function is meant to be used with string literals only!
Operator& sig(const char *signature_literal);
-struct TORCH_API RegisterOperators {
- RegisterOperators(std::vector<Operator> operators) {
- for(Operator& o : operators) {
- registerOperator(std::move(o));
- }
- }
-};
-
struct OperatorSet {
OperatorSet(std::initializer_list<const char *> sig_literals);
// XXX: Returns a nullptr if no Operator in the set matches n
diff --git a/torch/csrc/jit/python_interpreter.cpp b/torch/csrc/jit/python_interpreter.cpp
index d8baa18..86bd4df 100644
--- a/torch/csrc/jit/python_interpreter.cpp
+++ b/torch/csrc/jit/python_interpreter.cpp
@@ -7,6 +7,7 @@
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/fusion_compiler.h"
#include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/jit/ir.h"
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 749b412..400f9da 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -7,6 +7,7 @@
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/variable_tensor_functions.h"
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index 6585c72..940d47a 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -1059,6 +1059,34 @@
REQUIRE(output[1] == 2.0);
}
{
+ RegisterOperators reg(
+ "foo::lists2(Tensor[] tensors) -> Tensor[]",
+ [](std::vector<at::Tensor> tensors) { return tensors; });
+
+ auto& ops =
+ getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
+ REQUIRE(ops.size() == 1);
+
+ auto& op = ops.front();
+ REQUIRE(op->schema().name == "foo::lists2");
+
+ REQUIRE(op->schema().arguments.size() == 1);
+ REQUIRE(op->schema().arguments[0].name == "tensors");
+ REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofTensors()));
+
+ REQUIRE(op->schema().returns.size() == 1);
+ REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofTensors()));
+
+ Stack stack;
+ push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))});
+ op->getOperation()(stack);
+ std::vector<at::Tensor> output;
+ pop(stack, output);
+
+ REQUIRE(output.size() == 1);
+ REQUIRE(output[0].allclose(autograd::make_variable(at::ones(5))));
+ }
+ {
#ifdef USE_CATCH
REQUIRE_THROWS_WITH(
createOperator(
diff --git a/torch/op.h b/torch/op.h
new file mode 100644
index 0000000..88e9dc6
--- /dev/null
+++ b/torch/op.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <torch/csrc/autograd/generated/variable_factories.h>
+#include <torch/csrc/jit/custom_operator.h>
+
+#include <ATen/ATen.h>
+
+namespace torch {
+using jit::createOperator;
+using jit::RegisterOperators;
+} // namespace torch