Build mechanism for custom operators (#10226)

Summary:
This is the last step in the custom operator implementation: providing a way to build custom operators from both C++ and Python. For this I:

1. Created a `TorchConfig.cmake` (adapted largely from ebetica's `FindTorch.cmake`) that provides a CMake function to easily create simple custom op libraries,
2. Created a `torch/op.h` header for easy inclusion of all necessary headers,
3. Created a test directory `pytorch/test/custom_operator` with the basic setup for a custom op:
    1. Defines an op in `op.{h,cpp}`,
    2. Registers it with the JIT using `RegisterOperators`,
    3. Builds it into a shared library via a `CMakeLists.txt`,
    4. Binds it into Python using a `setup.py`, reusing our existing C++ extension machinery. No extra work, yay!

The pure C++ build and the Python build are separate and not coupled in any way.
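For reference, here is the essence of the C++ side, condensed from `test/custom_operator/op.cpp` below:

```cpp
#include <torch/op.h>

#include <cstdint>
#include <vector>

// A plain function; the operator schema (argument and return types)
// is inferred from the signature.
std::vector<at::Tensor> custom_op(
    at::Tensor tensor,
    double scalar,
    int64_t repeat) {
  std::vector<at::Tensor> output;
  for (int64_t i = 0; i < repeat; ++i) {
    output.push_back(tensor * scalar);
  }
  return output;
}

// Static registration: runs when the shared library is loaded.
static torch::RegisterOperators registry("custom::op", &custom_op);
```

Building this into `libcustom_op.so` via `torch_add_custom_op_library(custom_op op.cpp)` and loading it with `torch.ops.load_library('build/libcustom_op.so')` makes the op callable as `torch.ops.custom.op(torch.ones(5), 2.0, 3)` from Python, as exercised in `test/custom_operator/test.py`.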

zdevito soumith dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10226

Differential Revision: D9296839

Pulled By: goldsborough

fbshipit-source-id: 32f74cafb6e3d86cada8dfca8136d0dfb1f197a0
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 577f158..23a5080 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -230,7 +230,7 @@
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Qunused-arguments")
   endif()
   if ((APPLE AND (NOT ("${CLANG_VERSION_STRING}" VERSION_LESS "9.0")))
-    OR (CMAKE_COMPILER_IS_GNUCXX 
+    OR (CMAKE_COMPILER_IS_GNUCXX
     AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0 AND NOT APPLE)))
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
   endif()
diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in
new file mode 100644
index 0000000..a14b2e1
--- /dev/null
+++ b/cmake/TorchConfig.cmake.in
@@ -0,0 +1,60 @@
+# FindTorch
+# -------
+#
+# Finds the Torch library
+#
+# This will define the following variables:
+#
+#   TORCH_FOUND        -- True if the system has the Torch library
+#   TORCH_INCLUDE_DIRS -- The include directories for torch
+#   TORCH_LIBRARIES    -- Libraries to link to
+#
+# and the following imported targets:
+#
+#   Torch
+#
+# and the following functions:
+#
+#   torch_add_custom_op_library(<name> <source_files>)
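+#
+# Typical usage (my_op and my_op.cpp are placeholders):
+#
+#   find_package(Torch REQUIRED)
+#   torch_add_custom_op_library(my_op my_op.cpp)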
+
+set(TORCH_ROOT "${CMAKE_CURRENT_LIST_DIR}/../")
+
+set(TORCH_INCLUDE_DIRS
+  "${TORCH_ROOT}"
+  "${TORCH_ROOT}/aten/src"
+  "${CMAKE_CURRENT_LIST_DIR}/aten/src"
+  "${CMAKE_CURRENT_LIST_DIR}/caffe2/aten/src"
+  "${CMAKE_CURRENT_LIST_DIR}/caffe2/aten/src/TH"
+)
+
+find_library(TORCH_LIBRARY torch PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
+find_library(CAFFE2_LIBRARY caffe2 PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
+
+if (@USE_CUDA@)
+  find_package(CUDA REQUIRED)
+  find_library(CAFFE2_CUDA_LIBRARY caffe2_gpu PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
+  set(TORCH_CUDA_LIBRARIES -L${CUDA_TOOLKIT_ROOT_DIR}/lib64 cuda nvrtc cudart nvToolsExt)
+  list(APPEND TORCH_INCLUDE_DIRS ${CUDA_TOOLKIT_INCLUDE})
+endif()
+
+set(TORCH_LIBRARIES
+  ${TORCH_LIBRARY}
+  ${CAFFE2_LIBRARY}
+  ${CAFFE2_CUDA_LIBRARY}
+  ${TORCH_CUDA_LIBRARIES})
+
+# Creates a shared library <name> with the correct include directories
+# and linker flags set to include Torch header files and link with Torch
+# libraries. Also sets the C++ standard version to C++11. All options
+# can be overridden by specifying further options on the `<name>` CMake target.
+function(torch_add_custom_op_library name source_files)
+  add_library(${name} SHARED ${source_files})
+  target_include_directories(${name} PUBLIC "${TORCH_INCLUDE_DIRS}")
+  target_link_libraries(${name} "${TORCH_LIBRARIES}")
+  target_compile_options(${name} PUBLIC -std=c++11)
+endfunction(torch_add_custom_op_library)
diff --git a/cmake/TorchConfigVersion.cmake.in b/cmake/TorchConfigVersion.cmake.in
new file mode 100644
index 0000000..d9966d6
--- /dev/null
+++ b/cmake/TorchConfigVersion.cmake.in
@@ -0,0 +1,11 @@
+set(PACKAGE_VERSION "@TORCH_VERSION@")
+
+# Check whether the requested PACKAGE_FIND_VERSION is compatible
+if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}")
+  set(PACKAGE_VERSION_COMPATIBLE FALSE)
+else()
+  set(PACKAGE_VERSION_COMPATIBLE TRUE)
+  if ("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}")
+    set(PACKAGE_VERSION_EXACT TRUE)
+  endif()
+endif()
diff --git a/setup.py b/setup.py
index 1a9b0ee..3abd279 100644
--- a/setup.py
+++ b/setup.py
@@ -415,6 +415,7 @@
         self.copy_tree('third_party/pybind11/include/pybind11/',
                        'torch/lib/include/pybind11')
         self.copy_file('torch/csrc/torch.h', 'torch/lib/include/torch/torch.h')
+        self.copy_file('torch/op.h', 'torch/lib/include/torch/op.h')
 
 
 build_dep_cmds = {}
diff --git a/test/custom_operator/CMakeLists.txt b/test/custom_operator/CMakeLists.txt
new file mode 100644
index 0000000..15338cc
--- /dev/null
+++ b/test/custom_operator/CMakeLists.txt
@@ -0,0 +1,10 @@
+# Basic CMake setup
+cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
+project(custom_op)
+
+find_package(Torch REQUIRED)
+
+torch_add_custom_op_library(custom_op op.cpp)
+
+add_executable(custom_op_test test.cpp)
+target_link_libraries(custom_op_test custom_op)
diff --git a/test/custom_operator/op.cpp b/test/custom_operator/op.cpp
new file mode 100644
index 0000000..ec24967
--- /dev/null
+++ b/test/custom_operator/op.cpp
@@ -0,0 +1,20 @@
+#include <torch/op.h>
+
+#include <cstdint>
+#include <vector>
+
+std::vector<at::Tensor> custom_op(
+    at::Tensor tensor,
+    double scalar,
+    int64_t repeat) {
+  std::vector<at::Tensor> output;
+  output.reserve(repeat);
+  for (int64_t i = 0; i < repeat; ++i) {
+    output.push_back(tensor * scalar);
+  }
+  return output;
+}
+
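+// Static registration: the constructor of this global object registers
+// custom::op with the JIT when this shared library is loaded.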
+static torch::RegisterOperators registry("custom::op", &custom_op);
diff --git a/test/custom_operator/op.h b/test/custom_operator/op.h
new file mode 100644
index 0000000..d45123d
--- /dev/null
+++ b/test/custom_operator/op.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <torch/op.h>
+
+#include <cstdint>
+#include <vector>
+
+std::vector<at::Tensor> custom_op(
+    at::Tensor tensor,
+    double scalar,
+    int64_t repeat);
diff --git a/test/custom_operator/test.cpp b/test/custom_operator/test.cpp
new file mode 100644
index 0000000..57ad66d
--- /dev/null
+++ b/test/custom_operator/test.cpp
@@ -0,0 +1,28 @@
+#include "op.h"
+
+#include <cassert>
+#include <iostream>
+#include <vector>
+
+int main() {
+  auto& ops = torch::jit::getAllOperatorsFor(
+      torch::jit::Symbol::fromQualString("custom::op"));
+  assert(ops.size() == 1);
+
+  auto& op = ops.front();
+  assert(op->schema().name == "custom::op");
+
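+  // JIT operators use a stack calling convention: push the arguments,
+  // invoke the operation, then pop the results off the stack.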
+  torch::jit::Stack stack;
+  torch::jit::push(stack, torch::ones(5), 2.0, 3);
+  op->getOperation()(stack);
+  std::vector<at::Tensor> output;
+  torch::jit::pop(stack, output);
+
+  assert(output.size() == 3);
+  for (const auto& tensor : output) {
+    assert(tensor.allclose(torch::ones(5) * 2));
+  }
+  std::cout << "success" << std::endl;
+}
diff --git a/test/custom_operator/test.py b/test/custom_operator/test.py
new file mode 100644
index 0000000..2a04231
--- /dev/null
+++ b/test/custom_operator/test.py
@@ -0,0 +1,13 @@
+import os
+import torch
+
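+# Assumes the library was built into build/ first (see CMakeLists.txt here).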
+library_path = os.path.abspath('build/libcustom_op.so')
+torch.ops.load_library(library_path)
+assert library_path in torch.ops.loaded_libraries
+
+output = torch.ops.custom.op(torch.ones(5), 2.0, 3)
+assert type(output) == list
+assert len(output) == 3
+assert all(tensor.allclose(torch.ones(5) * 2) for tensor in output)
+print('success')
diff --git a/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect b/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect
new file mode 100644
index 0000000..04af766
--- /dev/null
+++ b/test/expect/TestCustomOperators.test_script_graph_contains_custom_op.expect
@@ -0,0 +1,4 @@
+graph(%x : Dynamic) {
+  %1 : Dynamic = ^aten::relu()(%x)
+  return (%1);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index cf7b0f7..4f1c3f3 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -6266,7 +6266,7 @@
     pass
 
 
-class TestCustomOperators(TestCase):
+class TestCustomOperators(JitTestCase):
 
     def test_dynamic_op_registry(self):
         from torch._ops import _OpNamespace
@@ -6337,19 +6337,30 @@
             "Unknown keyword argument 'foo' for operator 'aten::leaky_relu'"
         ):
             torch.ops.aten.leaky_relu(torch.ones(5), foo=torch.ones(5))
-    #
-    # def test_passing_and_returning_lists(self):
-    #     a, b = torch.ones(5), torch.zeros(5)
-    #     output = torch.ops.aten.stack([a, b])
-    #     self.assertEqual(output, torch.ones(10))
-    #
-    # def test_throws_for_tuples(self):
-    #     with self.assertRaisesRegex(
-    #         RuntimeError,
-    #         "Unknown keyword argument 'foo' for operator 'aten::leaky_relu'"
-    #     ):
-    #         torch.ops.aten.leaky_relu(torch.ones(5), foo=torch.ones(5))
 
+    def test_passing_and_returning_lists(self):
+        # Replace with actual test once we support lists.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Lists and tuples are not supported yet"
+        ):
+            a, b = torch.ones(5), torch.zeros(5)
+            output = torch.ops.aten.stack([a, b])
+            self.assertEqual(output, torch.ones(10))
+
+    def test_passing_and_returning_tuples(self):
+        # Replace with actual test once we support tuples.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Lists and tuples are not supported yet"
+        ):
+            torch.ops.aten.max_pool2d(torch.ones(5, 5), [2, 2])
+
+    def test_script_graph_contains_custom_op(self):
+        @torch.jit.script
+        def func(x):
+            return torch.ops.aten.relu(x)
+        self.assertExpected(canonical(func.graph))
 
 # UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
 # and we have to disable the failing tests here instead.
diff --git a/tools/jit/templates/register_aten_ops.cpp b/tools/jit/templates/register_aten_ops.cpp
index 3dc9734..4ac499b 100644
--- a/tools/jit/templates/register_aten_ops.cpp
+++ b/tools/jit/templates/register_aten_ops.cpp
@@ -1,4 +1,5 @@
 #include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
 
 #include "torch/csrc/autograd/profiler.h"
 #include "torch/csrc/jit/interned_strings.h"
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index deb9c45..1bd6453 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -11,7 +11,14 @@
 
 option(BUILD_TORCH_TEST "Build torch test binaries" ON)
 
+# TODO: Unify with version from setup.py
+set(TORCH_VERSION_MAJOR 0)
+set(TORCH_VERSION_MINOR 4)
+set(TORCH_VERSION_PATCH 1)
+set(TORCH_VERSION "${TORCH_VERSION_MAJOR}.${TORCH_VERSION_MINOR}.${TORCH_VERSION_PATCH}")
+
 set(TORCH_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
+set(TORCH_ROOT "${TORCH_SRC_DIR}/..")
 
 add_subdirectory(../third_party/nanopb protobuf-nanopb)
 
@@ -55,9 +62,9 @@
 endif()
 
 # Generate files
-set(TOOLS_PATH "${TORCH_SRC_DIR}/../tools")
+set(TOOLS_PATH "${TORCH_ROOT}/tools")
 
-configure_file("${TORCH_SRC_DIR}/../aten/src/ATen/common_with_cwrap.py"
+configure_file("${TORCH_ROOT}/aten/src/ATen/common_with_cwrap.py"
                "${TOOLS_PATH}/shared/cwrap_common.py"
                COPYONLY)
 
@@ -113,7 +120,7 @@
   "${TOOLS_PATH}/jit/gen_jit_dispatch.py"
   "${TOOLS_PATH}/jit/templates/register_aten_ops.cpp"
   "${TOOLS_PATH}/jit/templates/aten_interned_strings.h"
-  WORKING_DIRECTORY "${TORCH_SRC_DIR}/..")
+  WORKING_DIRECTORY "${TORCH_ROOT}")
 
 set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/autograd/anomaly_mode.cpp
@@ -211,7 +218,7 @@
   else()
     set (MSVC_RUNTIME_LIBRARY_FLAG "/MD")
   endif()
-  
+
   target_compile_options(torch PRIVATE
     ${MSVC_RUNTIME_LIBRARY_OPTION}
     /Z7
@@ -339,9 +346,9 @@
 
 set(TH_CPU_INCLUDE
   # dense
-  ${TORCH_SRC_DIR}/../aten/src/TH
+  ${TORCH_ROOT}/aten/src/TH
   ${CMAKE_CURRENT_BINARY_DIR}/../aten/src/TH
-  ${TORCH_SRC_DIR}/../aten/src
+  ${TORCH_ROOT}/aten/src
   ${CMAKE_CURRENT_BINARY_DIR}/../aten/src
   ${CMAKE_BINARY_DIR}/aten/src)
 target_include_directories(torch PRIVATE ${TH_CPU_INCLUDE})
@@ -349,13 +356,13 @@
 if(USE_CUDA OR USE_ROCM)
   set(TH_CUDA_INCLUDE
     # dense
-    ${TORCH_SRC_DIR}/../aten/src/THC
+    ${TORCH_ROOT}/aten/src/THC
     ${CMAKE_CURRENT_BINARY_DIR}/../aten/src/THC)
   target_include_directories(torch PRIVATE ${TH_CUDA_INCLUDE})
 endif()
 
 set(ATen_CPU_INCLUDE
-  ${TORCH_SRC_DIR}/../aten/src
+  ${TORCH_ROOT}/aten/src
   ${CMAKE_CURRENT_BINARY_DIR}/../aten/src
   ${CMAKE_CURRENT_BINARY_DIR}/../aten/src/ATen
   ${CMAKE_BINARY_DIR}/aten/src)
@@ -366,8 +373,8 @@
 
 # SYSTEM headers are included with -isystem and thus do not trigger warnings.
 target_include_directories(torch SYSTEM PUBLIC
-  "${TORCH_SRC_DIR}/../third_party/cereal/include" # For cereal/
-  "${TORCH_SRC_DIR}/../third_party/nanopb")
+  "${TORCH_ROOT}/third_party/cereal/include" # For cereal/
+  "${TORCH_ROOT}/third_party/nanopb")
 
 set_target_properties(torch PROPERTIES VERSION 1 SOVERSION 1)
 
@@ -390,7 +397,7 @@
   target_link_libraries(test_jit torch ${TORCH_CUDA_LIBRARIES})
   target_compile_definitions(test_jit PUBLIC USE_CATCH _FORCE_INLINES)
   target_include_directories(test_jit PUBLIC
-    "${TORCH_SRC_DIR}/../third_party/catch/single_include"
+    "${TORCH_ROOT}/third_party/catch/single_include"
     ${ATen_CPU_INCLUDE})
 
   if (USE_CUDA)
@@ -399,7 +406,7 @@
 endif()
 
 if (BUILD_TORCH_TEST AND NOT NO_API AND NOT USE_ROCM)
-  set(TORCH_API_TEST_DIR "${TORCH_SRC_DIR}/../test/cpp/api")
+  set(TORCH_API_TEST_DIR "${TORCH_ROOT}/test/cpp/api")
 
   add_executable(test_api
     ${TORCH_API_TEST_DIR}/any.cpp
@@ -424,7 +431,7 @@
 
   target_include_directories(test_api
     PUBLIC
-    "${TORCH_SRC_DIR}/../third_party/catch/single_include"
+    "${TORCH_ROOT}/third_party/catch/single_include"
     ${ATen_CPU_INCLUDE})
 
   target_link_libraries(test_api torch ${TORCH_CUDA_LIBRARIES})
@@ -445,3 +452,13 @@
     endif()
   endif()
 endif()
+
+# CMake config for external projects.
+configure_file(
+    ${PROJECT_SOURCE_DIR}/cmake/TorchConfigVersion.cmake.in
+    ${PROJECT_BINARY_DIR}/TorchConfigVersion.cmake
+    @ONLY)
+configure_file(
+    ${TORCH_ROOT}/cmake/TorchConfig.cmake.in
+    ${PROJECT_BINARY_DIR}/TorchConfig.cmake
+    @ONLY)
diff --git a/torch/_ops.py b/torch/_ops.py
index e5900fb..a16e34a 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -1,5 +1,29 @@
 import torch._C
 
+import contextlib
+import ctypes
+import sys
+
+
+# Query `hasattr` only once.
+_SET_GLOBAL_FLAGS = hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')
+
+
+@contextlib.contextmanager
+def dl_open_guard():
+    """
+    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
+    shared library to load custom operators.
+    """
+    if _SET_GLOBAL_FLAGS:
+        old_flags = sys.getdlopenflags()
+        sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
+    try:
+        yield
+    finally:
+        if _SET_GLOBAL_FLAGS:
+            sys.setdlopenflags(old_flags)
+
 
 class _OpNamespace(object):
     """
@@ -33,12 +55,40 @@
 
 
 class _Ops(object):
+    def __init__(self):
+        self.loaded_libraries = set()
+
     def __getattr__(self, name):
         # Here we are creating `torch.ops.my_namespace`
         namespace = _OpNamespace(name)
         setattr(self, name, namespace)
         return namespace
 
+    def load_library(self, path):
+        """
+        Loads a shared library from the given path into the current process.
+
+        The library being loaded may run global initialization code to register
+        custom operators with the PyTorch JIT runtime. This allows dynamically
+        loading custom operators. For this, you should compile your operator
+        and the static registration code into a shared library object, and then
+        call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
+        shared object.
+
+        After the library is loaded, it is added to the
+        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
+        for the paths of all libraries loaded using this function.
+
+        Arguments:
+            path (str): A path to a shared library to load.
+        """
+        with dl_open_guard():
+            # Import the shared library into the process, thus running its
+            # static (global) initialization code in order to register custom
+            # operators with the JIT.
+            ctypes.CDLL(path)
+        self.loaded_libraries.add(path)
+
 
 # The ops "namespace"
 ops = _Ops()
diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp
index eab076a..cc1e390 100644
--- a/torch/csrc/jit/constants.cpp
+++ b/torch/csrc/jit/constants.cpp
@@ -1,5 +1,6 @@
 #include "torch/csrc/jit/constants.h"
 #include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
 #include "torch/csrc/autograd/variable.h"
 
 namespace torch { namespace jit {
diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h
index f0cc29d..b501c25 100644
--- a/torch/csrc/jit/custom_operator.h
+++ b/torch/csrc/jit/custom_operator.h
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <torch/csrc/jit/function_schema.h>
-#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/stack.h>
 #include <torch/csrc/jit/tracer.h>
@@ -79,13 +78,10 @@
 /// Does two things for an operator implementation and a tuple of arguments:
 /// 1. Pops all necessary arguments off the stack into the tuple's elements,
 /// 2. Unpacks the tuple and calls the operator implementation.
-/// The result of the implementation call is returned.
-template <
-    typename ReturnType,
-    typename Implementation,
-    typename... Types,
-    size_t... Is>
-ReturnType callOperatorWithTuple(
+/// If tracing is currently enabled, this function will also take care of
+/// tracing the operator call.
+template <typename Implementation, typename... Types, size_t... Is>
+void callOperatorWithTuple(
     const FunctionSchema& schema,
     Implementation&& implementation,
     Stack& stack,
@@ -104,10 +100,10 @@
     jit::tracer::postRecordTrace(node, result);
   }
 
-  return result;
+  push(stack, IValue(std::move(result)));
 }
 
-void checkArgumentVector(
+inline void checkArgumentVector(
     const char* what,
     const std::vector<Argument>& inferred,
     const std::vector<Argument>& provided,
@@ -204,21 +200,58 @@
       c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
   using ArgumentTuple =
       typename c10::guts::typelist::to_tuple<ArgumentTypes>::type;
-  using ReturnType = decay_t<typename Traits::return_type>;
 
   auto schema = torch::jit::detail::inferAndCheckSchema<Traits>(schemaOrName);
 
   return Operator(schema, [implementation, schema](Stack& stack) {
     ArgumentTuple tuple;
-    auto result = torch::jit::detail::callOperatorWithTuple<ReturnType>(
+    torch::jit::detail::callOperatorWithTuple(
         schema,
         std::move(implementation),
         stack,
         tuple,
         typename MakeIndices<std::tuple_size<ArgumentTuple>::value>::indices{});
-    pack(stack, std::move(result));
     return 0;
   });
 }
+
+/// Registration class for new operators. Effectively calls
+/// `torch::jit::registerOperator` for every supplied operator, but allows doing
+/// so in the global scope when a `RegisterOperators` object is assigned to a
+/// static variable. Also handles registration of user-defined, "custom"
+/// operators.
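+///
+/// Example (operator name and lambda are illustrative):
+///   static RegisterOperators reg(
+///       "example::identity", [](at::Tensor t) { return t; });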
+struct TORCH_API RegisterOperators {
+  RegisterOperators() = default;
+
+  /// Registers a vector of already created `Operator`s.
+  RegisterOperators(std::vector<Operator> operators) {
+    for (Operator& o : operators) {
+      registerOperator(std::move(o));
+    }
+  }
+
+  /// Calls `op(...)` with the given operator name and implementation.
+  template <typename Implementation>
+  RegisterOperators(const std::string& name, Implementation&& implementation) {
+    op(name, std::forward<Implementation>(implementation));
+  }
+
+  /// Creates a new operator from a name and implementation function (function
+  /// pointer or function object/lambda) using `torch::jit::createOperator`, and
+  /// then registers the operator.
+  template <typename Implementation>
+  RegisterOperators& op(
+      const std::string& name,
+      Implementation&& implementation) {
+    registerOperator(
+        createOperator(name, std::forward<Implementation>(implementation)));
+    return *this;
+  }
+};
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 933a3e9..4873fd7 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -33,6 +33,13 @@
 
 #include <pybind11/functional.h>
 
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <tuple>
+#include <utility>
+
 namespace torch  { namespace jit {
 
 namespace {
@@ -232,10 +239,14 @@
           "Found ", operations.size(), " overloads for operator ",
           qualified_name, "! Overloads are not supported from Python.");
       std::shared_ptr<Operator> op = operations[0];
+      AT_ASSERT(op != nullptr);
+      std::ostringstream docstring;
+      docstring << "Automatically bound operator '" << qualified_name
+                << "' with schema: " << op->schema();
       return py::cpp_function([op](py::args args, py::kwargs kwargs) {
         return invokeOperatorFromPython(
             *op, std::move(args), std::move(kwargs));
-      });
+      }, py::name(qualified_name.c_str()), py::doc(docstring.str().c_str()));
     } catch (const at::Error& error) {
       throw std::runtime_error(error.what_without_backtrace());
     }
diff --git a/torch/csrc/jit/operator.h b/torch/csrc/jit/operator.h
index 7d5e97f..afa4f9c 100644
--- a/torch/csrc/jit/operator.h
+++ b/torch/csrc/jit/operator.h
@@ -94,14 +94,6 @@
 // XXX: this function is meant to be used with string literals only!
 Operator& sig(const char *signature_literal);
 
-struct TORCH_API RegisterOperators {
-  RegisterOperators(std::vector<Operator> operators) {
-    for(Operator& o : operators) {
-      registerOperator(std::move(o));
-    }
-  }
-};
-
 struct OperatorSet {
   OperatorSet(std::initializer_list<const char *> sig_literals);
   // XXX: Returns a nullptr if no Operator in the set matches n
diff --git a/torch/csrc/jit/python_interpreter.cpp b/torch/csrc/jit/python_interpreter.cpp
index d8baa18..86bd4df 100644
--- a/torch/csrc/jit/python_interpreter.cpp
+++ b/torch/csrc/jit/python_interpreter.cpp
@@ -7,6 +7,7 @@
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/jit/fusion_compiler.h"
 #include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
 #include "torch/csrc/jit/graph_executor.h"
 #include "torch/csrc/jit/ir.h"
 
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 749b412..400f9da 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -7,6 +7,7 @@
 #include "torch/csrc/jit/graph_executor.h"
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
 
 #include "torch/csrc/variable_tensor_functions.h"
 
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index 6585c72..940d47a 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -1059,6 +1059,34 @@
     REQUIRE(output[1] == 2.0);
   }
   {
+    RegisterOperators reg(
+        "foo::lists2(Tensor[] tensors) -> Tensor[]",
+        [](std::vector<at::Tensor> tensors) { return tensors; });
+
+    auto& ops =
+        getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
+    REQUIRE(ops.size() == 1);
+
+    auto& op = ops.front();
+    REQUIRE(op->schema().name == "foo::lists2");
+
+    REQUIRE(op->schema().arguments.size() == 1);
+    REQUIRE(op->schema().arguments[0].name == "tensors");
+    REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofTensors()));
+
+    REQUIRE(op->schema().returns.size() == 1);
+    REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofTensors()));
+
+    Stack stack;
+    push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))});
+    op->getOperation()(stack);
+    std::vector<at::Tensor> output;
+    pop(stack, output);
+
+    REQUIRE(output.size() == 1);
+    REQUIRE(output[0].allclose(autograd::make_variable(at::ones(5))));
+  }
+  {
 #ifdef USE_CATCH
     REQUIRE_THROWS_WITH(
         createOperator(
diff --git a/torch/op.h b/torch/op.h
new file mode 100644
index 0000000..88e9dc6
--- /dev/null
+++ b/torch/op.h
@@ -0,0 +1,14 @@
+#pragma once
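+
+// Umbrella header for custom operators: bundles the operator registration
+// API (custom_operator.h) with ATen and the variable factory functions.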
+
+#include <torch/csrc/autograd/generated/variable_factories.h>
+#include <torch/csrc/jit/custom_operator.h>
+
+#include <ATen/ATen.h>
+
+namespace torch {
+using jit::createOperator;
+using jit::RegisterOperators;
+} // namespace torch