[WIP][JIT] Add JIT backend registration API (#35833)

Summary:
**Summary**
This commit adds `torch::jit::backend`, an API that allows
external backends to be registered for the execution of JIT subgraphs
outside the JIT interpreter. To register an external backend,
one extends the provided abstract class `PyTorchBackendInterface` and implements
its three methods: `preprocess`, which prepares a `ScriptModule` so that it can
run on the backend; `compile`, which compiles the preprocessed module and returns
an opaque handle for each of its methods; and `execute`, which runs a compiled
method on the backend. For each registered backend, a `to_<backend>` function is
then generated that wraps a given `ScriptModule` in a `LoweredModule`, a generated
`ScriptModule` that compiles and executes the original module's methods through
the registered interface.
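
For illustration, here is a minimal sketch of what registering a backend might look like, modeled on the test backend added by this commit (`MinimalBackend` and `minimal_backend` are hypothetical names, and the pass-through bodies are placeholders, not a real lowering):

```
#include <torch/csrc/jit/backends/backend.h>

namespace torch {
namespace jit {

// Hypothetical backend that does the least work possible.
class MinimalBackend : public PyTorchBackendInterface {
 public:
  // Return the module unchanged; a real backend would lower it here.
  c10::IValue preprocess(
      c10::IValue mod,
      c10::impl::GenericDict method_compile_spec) override {
    return mod;
  }

  // Map each method name in method_compile_spec to an opaque handle
  // (here, just the method name itself).
  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    auto spec =
        c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
    auto handles = c10::Dict<std::string, std::string>();
    for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
      handles.insert(it->key(), it->key());
    }
    return c10::impl::toGenericDict(handles);
  }

  // Run the method named by handle; here, echo the inputs back.
  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    return inputs;
  }
};

// Static registration; exposes the class as
// __torch__.torch.classes.__backends__.minimal_backend.
static auto cls = torch::jit::backend<MinimalBackend>("minimal_backend");

} // namespace jit
} // namespace torch
```

A `to_<backend>` entry point would then be exposed by binding the result of `generateToBackendFn()` into a Python module, as `initTestBackendBindings` does for the test backend below.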

**Testing**
This commit adds a unit test that uses a minimal test backend
to make sure that the registration endpoint and generated
`ScriptModule` work.

```
$ python test/test_jit.py TestBackends
Fail to import hypothesis in common_utils, tests are not derandomized
.
----------------------------------------------------------------------
Ran 1 test in 0.183s

OK

```
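
For reference, lowering and running a module with the test backend looks roughly like this in Python (a condensed version of the new test; `_jit_to_test_backend`, `_c`, and `_get_method` are the bindings and accessors used by the test in this commit):

```
import torch

class MyModule(torch.nn.Module):
    def forward(self, x, h):
        return x + h, x - h

scripted = torch.jit.script(MyModule())

# Lower the module to the test backend. The compile spec maps each
# method name to a backend-specific specification (trivial here).
lowered = torch._C._jit_to_test_backend(scripted._c, {"forward": {"": ""}})

# Run the lowered method through the backend and compare with eager mode.
x, h = torch.randn(5), torch.randn(5)
backend_out = lowered._get_method("forward")(x, h)
assert all(torch.allclose(a, b) for a, b in zip(backend_out, MyModule()(x, h)))
```
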
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35833

Differential Revision: D21231955

Pulled By: SplitInfinity

fbshipit-source-id: 452db1123d0e5d83f97fe5da8a00fdfdb50dbef9
diff --git a/BUILD.bazel b/BUILD.bazel
index 6d9fd3c..eaf6112 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1912,6 +1912,7 @@
             "torch/csrc/generic/*.cpp",
             "torch/csrc/jit/*.h",
             "torch/csrc/jit/api/*.h",
+            "torch/csrc/jit/backends/backend_interface.h",
             "torch/csrc/jit/codegen/cuda/*.h",
             "torch/csrc/jit/codegen/fuser/*.h",
             "torch/csrc/jit/codegen/fuser/cpu/*.h",
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 3d1c38b..6c91af1 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -7,7 +7,7 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
 namespace torch {
-class CustomClassHolder : public c10::intrusive_ptr_target {};
+class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
 namespace jit {
 using ::torch::CustomClassHolder;
 struct Function;
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 7c003e8..09c32a9 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -1716,7 +1716,8 @@
   size_t addAttribute(
       const std::string& name,
       const TypePtr& type,
-      bool is_parameter = false);
+      bool is_parameter = false,
+      bool allow_any = false);
 
   // [Internal Only] Remove attribute from the ClassType,
   // caller is responsible to make sure the modification is safe:
@@ -1731,10 +1732,11 @@
   size_t addOrCheckAttribute(
       const std::string& name,
       TypePtr ty,
-      bool is_parameter = false) {
+      bool is_parameter = false,
+      bool allow_any = false) {
     auto slot_idx = findAttributeSlot(name);
     if (!slot_idx) {
-      return addAttribute(name, ty, is_parameter);
+      return addAttribute(name, ty, is_parameter, allow_any);
     }
 
     TORCH_CHECK(
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index 2f894b5..30270cc 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -1143,10 +1143,14 @@
 size_t ClassType::addAttribute(
     const std::string& name,
     const TypePtr& type,
-    bool is_parameter) {
+    bool is_parameter,
+    bool allow_any) {
   const char* what = is_parameter ? "parameter" : "attribute";
   checkNotExist(name, what);
-  checkNoAny(*this, what, name, type);
+
+  if (!allow_any) {
+    checkNoAny(*this, what, name, type);
+  }
 
   size_t slot = attributeNames_.size();
   attributeNames_.push_back(name);
diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py
new file mode 100644
index 0000000..7826758
--- /dev/null
+++ b/test/jit/test_backends.py
@@ -0,0 +1,72 @@
+from torch.testing._internal.jit_utils import JitTestCase
+import os
+import sys
+
+import torch
+import torch._C
+
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+
+if __name__ == '__main__':
+    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
+                       "\tpython test/test_jit.py TESTNAME\n\n"
+                       "instead.")
+
+
+def to_test_backend(module, method_compile_spec):
+    return torch._C._jit_to_test_backend(module, {"forward": method_compile_spec})
+
+
+def to_test_backend_multi(module, method_compile_spec):
+    return torch._C._jit_to_test_backend(module, method_compile_spec)
+
+
+class MyModule(torch.nn.Module):
+    def __init__(self):
+        super(MyModule, self).__init__()
+
+    def forward(self, x, h):
+        return self.accum(x, h), self.sub_accum(x, h)
+
+    def accum(self, x, h):
+        return x + h
+
+    def sub_accum(self, x, h):
+        return x - h
+
+
+class TestBackends(JitTestCase):
+    def test_simple(self):
+        module = MyModule()
+        scripted_module = torch.jit.script(MyModule())
+
+        # Test compile.
+        lowered_module = to_test_backend_multi(
+            scripted_module._c, {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}})
+
+        # Test execution with backend against Python and JIT.
+        input = torch.randn(5)
+
+        def compare_py_jit_backend(name, input):
+            # Get handles for Python, JIT and backend methods.
+            python_method = module.__getattribute__(name)
+            jit_method = scripted_module.__getattr__(name)
+            backend_method = lowered_module._get_method(name)
+
+            # Run methods.
+            python_output = python_method(input, input)
+            jit_output = jit_method(input, input)
+            backend_output = backend_method(input, input)
+
+            # The answers returned by Python, JIT and to_backend should all match.
+            self.assertEqual(python_output, backend_output)
+            self.assertEqual(jit_output, backend_output)
+
+        # Test all three module methods.
+        compare_py_jit_backend("accum", input)
+        compare_py_jit_backend("sub_accum", input)
+        compare_py_jit_backend("forward", input)
+
+        # TODO: Test save and load.
diff --git a/test/test_jit.py b/test/test_jit.py
index d2091d7..d45dc00 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8,6 +8,7 @@
 from jit.test_recursive_script import TestRecursiveScript  # noqa: F401
 from jit.test_type_sharing import TestTypeSharing  # noqa: F401
 from jit.test_logging import TestLogging  # noqa: F401
+from jit.test_backends import TestBackends  # noqa: F401
 from jit.test_list_dict import TestList, TestDict  # noqa: F401
 from jit.test_async import TestAsync  # noqa: F401
 from jit.test_data_parallel import TestDataParallel  # noqa: F401
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 96b102b..404486e 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -60,6 +60,7 @@
     "torch/csrc/jit/api/function_impl.cpp",
     "torch/csrc/jit/api/module.cpp",
     "torch/csrc/jit/api/object.cpp",
+    "torch/csrc/jit/backends/backend_interface.cpp",
     "torch/csrc/jit/codegen/fuser/codegen.cpp",
     "torch/csrc/jit/codegen/fuser/compiler.cpp",
     "torch/csrc/jit/codegen/fuser/executor.cpp",
@@ -384,6 +385,9 @@
     "torch/csrc/autograd/python_legacy_variable.cpp",
     "torch/csrc/autograd/python_variable.cpp",
     "torch/csrc/autograd/python_variable_indexing.cpp",
+    "torch/csrc/jit/backends/backend_init.cpp",
+    "torch/csrc/jit/backends/backend_resolver.cpp",
+    "torch/csrc/jit/backends/test_backend.cpp",
     "torch/csrc/jit/python/init.cpp",
     "torch/csrc/jit/passes/onnx.cpp",
     "torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h
index e754ff4..d5e83c9 100644
--- a/torch/csrc/jit/api/module.h
+++ b/torch/csrc/jit/api/module.h
@@ -131,8 +131,9 @@
       const std::string& name,
       const TypePtr t,
       IValue v,
-      bool is_param = false) {
-    type()->addOrCheckAttribute(name, t, is_param);
+      bool is_param = false,
+      bool allow_any = false) {
+    type()->addOrCheckAttribute(name, t, is_param, allow_any);
     _ivalue()->setAttr(name, std::move(v));
   }
   void register_module(const std::string& name, const Module& module) {
diff --git a/torch/csrc/jit/backends/backend.h b/torch/csrc/jit/backends/backend.h
new file mode 100644
index 0000000..4b523a3
--- /dev/null
+++ b/torch/csrc/jit/backends/backend.h
@@ -0,0 +1,332 @@
+#pragma once
+
+#include <ATen/core/builtin_function.h>
+#include <ATen/core/stack.h>
+#include <torch/csrc/jit/backends/backend_interface.h>
+#include <torch/csrc/jit/backends/backend_resolver.h>
+#include <torch/csrc/jit/frontend/code_template.h>
+#include <torch/csrc/jit/frontend/resolver.h>
+#include <torch/csrc/jit/frontend/sugared_value.h>
+#include <torch/csrc/jit/python/pybind.h>
+#include <torch/csrc/utils/pybind.h>
+#include <torch/custom_class.h>
+
+namespace torch {
+namespace jit {
+
+inline c10::FunctionSchema getPreprocessSchema() {
+  c10::Argument self("self", c10::AnyType::get());
+  c10::Argument mod("mod", c10::AnyType::get());
+  c10::Argument method_compile_spec(
+      "method_compile_spec",
+      c10::DictType::create(c10::StringType::get(), c10::AnyType::get()));
+
+  c10::FunctionSchema preprocessor_schema(
+      "preprocess",
+      /*overload_name=*/"",
+      /*arguments=*/{self, mod, method_compile_spec},
+      /*returns=*/{mod});
+  return preprocessor_schema;
+}
+
+template <typename TBackendInterface>
+std::function<void(Stack&)> getPreprocessFunc() {
+  return [](Stack& stack) {
+    auto method_compile_spec = pop(stack).toGenericDict();
+    auto mod = pop(stack);
+    auto self = pop(stack).toCustomClass<TBackendInterface>();
+    auto ret = self->preprocess(mod, method_compile_spec);
+    push(stack, ret);
+  };
+}
+
+inline c10::FunctionSchema getCompileSchema() {
+  c10::Argument self("self", c10::AnyType::get());
+  c10::Argument mod("processed", c10::AnyType::get());
+  auto any_dict_ty =
+      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
+  c10::Argument method_compile_spec("method_compile_spec", any_dict_ty);
+  c10::Argument handles("handles", any_dict_ty);
+
+  c10::FunctionSchema compile_schema(
+      "compile",
+      /*overload_name=*/"",
+      /*arguments=*/{self, mod, method_compile_spec},
+      /*returns=*/{handles});
+  return compile_schema;
+}
+
+template <typename TBackendInterface>
+std::function<void(Stack&)> getCompileFunc() {
+  return [](Stack& stack) {
+    auto method_compile_spec = pop(stack).toGenericDict();
+    auto processed = pop(stack);
+    auto self = pop(stack).toCustomClass<TBackendInterface>();
+    auto ret = self->compile(processed, method_compile_spec);
+    push(stack, ret);
+  };
+}
+
+inline c10::FunctionSchema getExecuteSchema() {
+  auto any_list_ty = c10::ListType::create(c10::AnyType::get());
+  c10::Argument self("self", c10::AnyType::get());
+  c10::Argument handle("handle", c10::AnyType::get());
+  c10::Argument input("input", any_list_ty);
+  c10::Argument output("output", any_list_ty);
+  return c10::FunctionSchema(
+      "execute",
+      /*overload_name=*/"",
+      /*arguments=*/{self, handle, input},
+      /*returns=*/{output});
+}
+
+template <typename TBackendInterface>
+std::function<void(Stack&)> getExecuteFunc() {
+  return [](Stack& stack) {
+    auto args = pop(stack);
+    auto handle = pop(stack);
+    auto self = pop(stack);
+    auto backend = self.toCustomClass<TBackendInterface>();
+    auto res = backend->execute(handle, args.toList());
+    push(stack, res);
+  };
+}
+
+// Static registration API for backends.
+template <class TBackendInterface>
+class backend {
+  static_assert(
+      std::is_base_of<PyTorchBackendInterface, TBackendInterface>::value,
+      "torch::jit::backend_<T> requires T to inherit from PyTorchBackendInterface");
+  constexpr static auto kBackendsNamespace = "__backends__";
+  std::string backend_name_;
+
+ public:
+  explicit backend(const std::string& name) : backend_name_(name) {
+    static auto cls = torch::class_<TBackendInterface>(kBackendsNamespace, name)
+                          .def(torch::init<>())
+                          ._def_unboxed(
+                              "preprocess",
+                              getPreprocessFunc<TBackendInterface>(),
+                              getPreprocessSchema())
+                          ._def_unboxed(
+                              "compile",
+                              getCompileFunc<TBackendInterface>(),
+                              getCompileSchema())
+                          ._def_unboxed(
+                              "execute",
+                              getExecuteFunc<TBackendInterface>(),
+                              getExecuteSchema());
+  }
+
+  // Generates and returns a function that takes a Module and a lowering
+  // specification in the form of a dictionary. The caller is responsible for
+  // binding this into a CPython module.
+  std::function<Module(Module, py::dict)> generateToBackendFn() {
+    const c10::QualifiedName qual_backend_name(
+        {"__torch__", "torch", "classes", kBackendsNamespace, backend_name_});
+    const std::string backend_name = qual_backend_name.name();
+
+    return [=](Module orig_module, py::dict method_compile_spec) {
+      // TODO: Validate method_compile_spec.
+
+      // Clone orig_module to make sure backend transformation is
+      // functional.
+      auto cloned_module = orig_module.clone();
+
+      // Represents the type Dict[str, Any].
+      auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
+
+      // Generate LoweredModule.
+      Module loweredModule("torch.jit." + backend_name + "LoweredModule");
+
+      // Generate attributes.
+      // This is the original cloned and preprocessed module.
+      loweredModule.register_attribute(
+          "__processed_module",
+          AnyType::get(),
+          cloned_module._ivalue(),
+          /*is_param=*/false,
+          /*allow_any=*/true);
+
+      // This is for the method_compile_spec passed in to to_<backend> or
+      // loaded from an exported model.
+      loweredModule.register_attribute(
+          "__method_compile_spec",
+          any_dict_ty,
+          toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
+          /*is_param=*/false,
+          /*allow_any=*/true);
+
+      // This is a pointer to a backend instance that is used to access
+      // compile and execute functions.
+      auto cls = getCustomClass(qual_backend_name.qualifiedName());
+      TORCH_INTERNAL_ASSERT(cls);
+      c10::intrusive_ptr<torch::CustomClassHolder> backend;
+      loweredModule.register_attribute(
+          "__backend", cls, IValue::make_capsule(backend));
+
+      // This is the list of opaque backend handles returned by
+      // backend.compile.
+      loweredModule.register_attribute(
+          "__handles",
+          any_dict_ty,
+          c10::impl::GenericDict(
+              any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
+          /*is_param=*/false,
+          /*allow_any=*/true);
+
+      // Methods.
+
+      // This is a helper function for creating a new instance of the
+      // backend class.
+      static const auto create_backend_ct = CodeTemplate(R"(
+            def __create_backend(self):
+                self.__backend = $name()
+            )");
+      TemplateEnv create_backend_te;
+      create_backend_te.s("name", qual_backend_name.qualifiedName());
+      loweredModule.define(
+          create_backend_ct.format(create_backend_te), loweredModuleResolver());
+
+      // getstate and setstate are for serialization/deserialization of the
+      // LoweredModule.
+      loweredModule.define(
+          R"(
+            def __getstate__(self):
+                return self.__method_compile_spec, self.__processed_module
+            )",
+          loweredModuleResolver());
+
+      loweredModule.define(
+          R"(
+            def __setstate__(self, state):
+                self.__method_compile_spec = state[0]
+                self.__processed_module = state[1]
+                self.__create_backend()
+                self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
+            )",
+          loweredModuleResolver());
+
+      // This is defined in TorchScript (and run once at lowering time,
+      // below) because we don't have access to an instance of the backend
+      // as a C++ object with which to call preprocess directly. It is
+      // never called during backend compilation or execution.
+      loweredModule.define(
+          R"(
+            def __preprocess(self, mod: Any, method_compile_spec: Dict[str, Any]):
+                self.__create_backend()
+                self.__processed_module = self.__backend.preprocess(mod, method_compile_spec)
+          )",
+          loweredModuleResolver());
+
+      // This loop generates one method on the LoweredModule for every key
+      // in method_compile_spec.
+      for (auto& e : method_compile_spec) {
+        std::string method_name = py::cast<std::string>(e.first);
+        static const auto method_ct = CodeTemplate(R"(
+            def $method(self${,def_inputs}):
+                typed_inputs: List[Any] = [${fwd_inputs,}]
+                $ret, = self.__backend.execute(self.__handles["$method"], typed_inputs)
+                ${refine,}
+                return $ret
+            )");
+
+        TemplateEnv method_te;
+        method_te.s("method", method_name);
+        auto method = orig_module.get_method(method_name);
+        auto& function = method.function();
+        auto schema = function.getSchema();
+
+        // Generate the inputs for the function signature (def_inputs) and
+        // for passing to backend.execute (fwd_inputs).
+        std::vector<std::string> def_inputs, fwd_inputs;
+        for (const auto& arg : schema.arguments()) {
+          auto name = arg.name();
+
+          // Skip self since it is already included explicitly in the
+          // generated method signature.
+          if (name == "self") {
+            continue;
+          }
+
+          auto default_value = arg.default_value();
+
+          if (arg.kwarg_only()) {
+            // If this is a kwarg, it needs to be emitted as keyword=value
+            // in the definition and keyword=keyword in the call to
+            // __backend.execute.
+            TORCH_INTERNAL_ASSERT(default_value.has_value());
+            std::stringstream def_ss, fwd_ss;
+            def_ss << name << "=";
+            fwd_ss << name << "=" << name;
+            default_value->repr(
+                def_ss,
+                [](std::ostream&, const IValue&) -> bool { return false; });
+            def_inputs.emplace_back(def_ss.str());
+            fwd_inputs.emplace_back(fwd_ss.str());
+          } else {
+            // If this is not a kwarg, it should be emitted as is in the
+            // signature and in the call to __backend.execute.
+            def_inputs.emplace_back(name);
+            fwd_inputs.emplace_back(name);
+          }
+        }
+
+        // Generate a comma-delimited list of identifiers to unpack outputs, as
+        // well as a list of isinstance checks to make sure the backend returned
+        // the types it was supposed to.
+        std::stringstream out_ss, type_check_ss;
+        std::vector<std::string> type_checks;
+        TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
+        auto out_ty = schema.returns().at(0).type();
+
+        out_ss << "_0";
+        type_check_ss << "assert isinstance(_0, ";
+
+        if (auto out_tuple_ty = out_ty->cast<TupleType>()) {
+          auto tuple_elements = out_tuple_ty->elements();
+          type_check_ss << tuple_elements[0]->str() << ")";
+          type_checks.emplace_back(type_check_ss.str());
+          for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
+            type_check_ss.str(std::string());
+            type_check_ss.clear();
+            out_ss << ", _" << i;
+            type_check_ss << "assert isinstance(_" << i << ", "
+                          << tuple_elements[i]->str() << ")";
+            type_checks.emplace_back(type_check_ss.str());
+          }
+        } else {
+          type_check_ss << out_ty->str() << ")";
+          type_checks.emplace_back(type_check_ss.str());
+        }
+
+        method_te.v("def_inputs", def_inputs);
+        method_te.v("fwd_inputs", fwd_inputs);
+        method_te.v("refine", type_checks);
+        method_te.s("ret", out_ss.str());
+
+        loweredModule.define(
+            method_ct.format(method_te), loweredModuleResolver());
+      }
+
+      // Run preprocess so that __processed_module is set correctly before
+      // compilation.
+      loweredModule.run_method(
+          "__preprocess",
+          cloned_module._ivalue(),
+          toIValue(method_compile_spec, any_dict_ty).toGenericDict());
+
+      // Call __setstate__ to ensure that the returned Module is ready to
+      // run.
+      auto state = at::ivalue::Tuple::create(
+          toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
+          loweredModule.attr("__processed_module"));
+      loweredModule.run_method("__setstate__", state);
+      return loweredModule;
+    };
+  }
+};
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/backends/backend_init.cpp b/torch/csrc/jit/backends/backend_init.cpp
new file mode 100644
index 0000000..f73a70e
--- /dev/null
+++ b/torch/csrc/jit/backends/backend_init.cpp
@@ -0,0 +1,11 @@
+#include <torch/csrc/jit/backends/backend_init.h>
+#include <torch/csrc/jit/backends/test_backend.h>
+
+namespace torch {
+namespace jit {
+
+void initJitBackendBindings(PyObject* module) {
+  initTestBackendBindings(module);
+}
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/backends/backend_init.h b/torch/csrc/jit/backends/backend_init.h
new file mode 100644
index 0000000..e7be08c
--- /dev/null
+++ b/torch/csrc/jit/backends/backend_init.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <torch/csrc/jit/python/pybind.h>
+#include <torch/csrc/utils/pybind.h>
+
+namespace torch {
+namespace jit {
+// Initialize Python bindings for JIT to_<backend> functions.
+void initJitBackendBindings(PyObject* module);
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/backends/backend_interface.cpp b/torch/csrc/jit/backends/backend_interface.cpp
new file mode 100644
index 0000000..cf0879f
--- /dev/null
+++ b/torch/csrc/jit/backends/backend_interface.cpp
@@ -0,0 +1,10 @@
+#include <torch/csrc/jit/backends/backend_interface.h>
+
+namespace torch {
+namespace jit {
+
+PyTorchBackendInterface::PyTorchBackendInterface() = default;
+PyTorchBackendInterface::~PyTorchBackendInterface() = default;
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/backends/backend_interface.h b/torch/csrc/jit/backends/backend_interface.h
new file mode 100644
index 0000000..3a54343
--- /dev/null
+++ b/torch/csrc/jit/backends/backend_interface.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <torch/custom_class.h>
+
+namespace torch {
+namespace jit {
+
+// Interface for a JIT backend.
+class TORCH_API PyTorchBackendInterface : public torch::CustomClassHolder {
+ public:
+  PyTorchBackendInterface();
+  virtual ~PyTorchBackendInterface();
+
+  // Preprocess \p mod as per \p method_compile_spec to prepare it for
+  // compilation.
+  virtual c10::IValue preprocess(
+      c10::IValue mod,
+      c10::impl::GenericDict method_compile_spec) = 0;
+
+  // Compile the module contained in \p processed using the details provided in
+  // \p method_compile_spec for each module method that should be compiled for
+  // the backend. \p method_compile_spec should be of type Dict<string, Any>.
+  // \returns a dictionary of type Dict<string, Any> that contains a backend
+  // handle for each method that can run on the backend (i.e. each key in
+  // \p method_compile_spec).
+  virtual c10::impl::GenericDict compile(
+      c10::IValue processed,
+      c10::impl::GenericDict method_compile_spec) = 0;
+
+  // Execute the method specified by \p handle using \p inputs. \returns the
+  // outputs as a list.
+  virtual c10::impl::GenericList execute(
+      c10::IValue handle,
+      c10::impl::GenericList inputs) = 0;
+};
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/backends/backend_resolver.cpp b/torch/csrc/jit/backends/backend_resolver.cpp
new file mode 100644
index 0000000..a3a8c63
--- /dev/null
+++ b/torch/csrc/jit/backends/backend_resolver.cpp
@@ -0,0 +1,69 @@
+#include <torch/csrc/jit/backends/backend_resolver.h>
+#include <torch/csrc/jit/frontend/sugared_value.h>
+#include <torch/custom_class.h>
+
+namespace torch {
+namespace jit {
+namespace {
+// Essentially ClassNamespaceValue from import_source.cpp without the
+// SourceImporterImpl reference. This helps resolve the
+// __torch__.torch.classes.__backends__.{backend_name} symbols in the generated code
+// for the LoweredModule.
+struct ClassNamespaceValue : public SugaredValue {
+  explicit ClassNamespaceValue(c10::QualifiedName name)
+      : basename_(std::move(name)) {}
+
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      Function& m,
+      const std::string& name) override {
+    auto fullName = c10::QualifiedName(basename_, name);
+
+    // Check to see if it is a custom class.
+    if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
+      return std::make_shared<ClassValue>(custom_class);
+    }
+
+    // If it's not a custom class, assume it's another namespace
+    return std::make_shared<ClassNamespaceValue>(std::move(fullName));
+  }
+
+  std::string kind() const override {
+    return "Class Namespace";
+  }
+
+ private:
+  c10::QualifiedName basename_;
+};
+
+// A resolver just for resolving custom backend class lookups in the
+// LoweredModule classes generated by torch::jit::backend.
+struct LoweredModuleResolver : public Resolver {
+  std::shared_ptr<SugaredValue> resolveValue(
+      const std::string& name,
+      Function& m,
+      const SourceRange& loc) override {
+    if (name == "torch") {
+      return std::make_shared<BuiltinModule>("aten");
+    } else if (name == "__torch__") {
+      return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
+    }
+
+    return nullptr;
+  }
+
+  TypePtr resolveType(const std::string& name, const SourceRange& loc)
+      override {
+    return nullptr;
+  }
+};
+} // namespace
+
+std::shared_ptr<Resolver> loweredModuleResolver() {
+  std::shared_ptr<Resolver> resolver =
+      std::make_shared<LoweredModuleResolver>();
+  return resolver;
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/backends/backend_resolver.h b/torch/csrc/jit/backends/backend_resolver.h
new file mode 100644
index 0000000..b0d5727
--- /dev/null
+++ b/torch/csrc/jit/backends/backend_resolver.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <torch/csrc/jit/frontend/resolver.h>
+
+namespace torch {
+namespace jit {
+// Create a Resolver for use in generating LoweredModules for specific backends.
+TORCH_API std::shared_ptr<Resolver> loweredModuleResolver();
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/backends/test_backend.cpp b/torch/csrc/jit/backends/test_backend.cpp
new file mode 100644
index 0000000..158842f
--- /dev/null
+++ b/torch/csrc/jit/backends/test_backend.cpp
@@ -0,0 +1,82 @@
+#include <torch/csrc/jit/backends/test_backend.h>
+#include <torch/csrc/jit/api/module.h>
+#include <torch/csrc/jit/backends/backend.h>
+
+namespace torch {
+namespace jit {
+// This test JIT backend is intended to do the minimal amount of work
+// necessary to test that the JIT backend registration endpoints and
+// code generation are working correctly. It is not intended to
+// produce numerically correct results.
+class TestBackend : public PyTorchBackendInterface {
+ public:
+  // Constructor.
+  explicit TestBackend() {}
+  virtual ~TestBackend() = default;
+
+  c10::IValue preprocess(
+      c10::IValue mod,
+      c10::impl::GenericDict method_compile_spec) override {
+    return mod;
+  }
+
+  c10::impl::GenericDict compile(
+      c10::IValue processed,
+      c10::impl::GenericDict method_compile_spec) override {
+    auto spec =
+        c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
+
+    // Return the same string as a value for every key in method_compile_spec.
+    auto handles = c10::Dict<std::string, std::string>();
+    for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
+      handles.insert(it->key(), it->key());
+    }
+    return c10::impl::toGenericDict(handles);
+  }
+  c10::impl::GenericList execute(
+      c10::IValue handle,
+      c10::impl::GenericList inputs) override {
+    TORCH_INTERNAL_ASSERT(handle.isString());
+    TORCH_INTERNAL_ASSERT(inputs.size() > 0);
+
+    c10::List<at::Tensor> output_list;
+
+    // Implement simple accumulator (add) and subtractive accumulator (sub)
+    // ops. Return one or both of their results depending on the handle to
+    // make sure multiple outputs are handled.
+    c10::IValue value = inputs[0];
+    at::Tensor accum = value.toTensor();
+    accum = accum.clone();
+    at::Tensor sub_accum = value.toTensor();
+    sub_accum = sub_accum.clone();
+
+    for (size_t i = 1, e = inputs.size(); i < e; ++i) {
+      value = inputs[i];
+      accum.add_(value.toTensor(), 1.0);
+      sub_accum.sub_(value.toTensor(), 1.0);
+    }
+
+    if (handle.toStringRef() == "accum") {
+      output_list.emplace_back(accum);
+    } else if (handle.toStringRef() == "sub_accum") {
+      output_list.emplace_back(sub_accum);
+    } else if (handle.toStringRef() == "forward") {
+      output_list.emplace_back(accum);
+      output_list.emplace_back(sub_accum);
+    }
+
+    return c10::impl::toList(output_list);
+  }
+};
+
+torch::jit::backend<TestBackend>& testBackend() {
+  static auto cls = torch::jit::backend<TestBackend>("test_backend");
+  return cls;
+}
+
+void initTestBackendBindings(PyObject* module) {
+  auto m = py::handle(module).cast<py::module>();
+  m.def("_jit_to_test_backend", testBackend().generateToBackendFn());
+}
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/backends/test_backend.h b/torch/csrc/jit/backends/test_backend.h
new file mode 100644
index 0000000..96501dd
--- /dev/null
+++ b/torch/csrc/jit/backends/test_backend.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <torch/csrc/jit/python/pybind.h>
+#include <torch/csrc/utils/pybind.h>
+
+namespace torch {
+namespace jit {
+void initTestBackendBindings(PyObject* module);
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 13ae1ca..9436e93 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/utils/pybind.h>
 
 #include <torch/csrc/jit/api/module.h>
+#include <torch/csrc/jit/backends/backend_init.h>
 #include <torch/csrc/jit/codegen/fuser/interface.h>
 #include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
 #include <torch/csrc/jit/frontend/ir_emitter.h>
@@ -905,6 +906,7 @@
   tracer::initPythonTracerBindings(module);
   initTreeViewBindings(module);
   initJitScriptBindings(module);
+  initJitBackendBindings(module);
 
   setPrintHandler([](const std::string& str) {
     py::gil_scoped_acquire acquire;
diff --git a/torch/custom_class.h b/torch/custom_class.h
index 56de83a..eb72307 100644
--- a/torch/custom_class.h
+++ b/torch/custom_class.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <ATen/core/stack.h>
 #include <ATen/core/builtin_function.h>
 #include <ATen/core/function_schema.h>
 #include <ATen/core/ivalue.h>
@@ -117,6 +118,17 @@
     return *this;
   }
 
+  /// This is an unsafe method registration API, added to support custom JIT
+  /// backends via custom C++ classes. It is not intended for general use.
+  class_& _def_unboxed(std::string name, std::function<void(jit::Stack&)> func, c10::FunctionSchema schema) {
+    auto qualMethodName = qualClassName + "." + name;
+    auto method = std::make_unique<jit::BuiltinOpFunction>(
+        qualMethodName, std::move(schema), std::move(func));
+    classTypePtr->addMethod(method.get());
+    registerCustomClassMethod(std::move(method));
+    return *this;
+  }
+
   /// def_pickle() is used to define exactly what state gets serialized
   /// or deserialized for a given instance of a custom C++ class in
   /// Python or TorchScript. This protocol is equivalent to the Pickle