Reland of "D27708346: generate xla codegen in-tree" (#56601)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56601

This reland updates the original diff to ensure that
RegistrationDeclarations.yaml is completely unchanged.

This reverts commit 90e532f3ef17a9611e9e7a9f1f6189d4168bf084.
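
For context, the new generator (tools/codegen/gen_backend_stubs.py)
consumes a backend yaml listing the operators an external backend
implements. A minimal sketch of the expected format (the op names here
are illustrative, not taken from the real XLA yaml):

    backend: XLA
    cpp_namespace: torch_xla
    supported:
    - abs
    - add.Tensor
    autograd:
    - max_pool2d_with_indices

Ops listed under "supported" are registered to the XLA dispatch key,
ops under "autograd" to AutogradXLA, and ops the backend does not
implement (but that need a backend kernel) get a generated CPU
fallback registered instead.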

Test Plan: Imported from OSS

Reviewed By: ailzhang

Differential Revision: D27915305

Pulled By: bdhirsh

fbshipit-source-id: 491a025c44221690dad849f9a2166934130c0fec
diff --git a/aten/src/ATen/templates/aten_xla_type.h b/aten/src/ATen/templates/aten_xla_type.h
new file mode 100644
index 0000000..4dc34bc
--- /dev/null
+++ b/aten/src/ATen/templates/aten_xla_type.h
@@ -0,0 +1,22 @@
+#pragma once
+// ${generated_comment}
+
+#include <ATen/Tensor.h>
+
+namespace ${cpp_namespace} {
+
+// Base ATEN Type class where the XLA-specific overrides should be defined.
+class AtenXlaType {
+ public:
+  static void InitializeAtenBindings();
+
+  //////////////////////////////////////////////////////////////////////////////
+  // ATEN API overrides in alphabetical order.
+  // Note: The C++ signatures must match the ones listed within the following
+  // pytorch folder file:
+  //   torch/csrc/autograd/generated/RegistrationDeclarations.h
+  /////////////////////////////////////////////////////////////////////////////
+${dispatch_xla_declarations}
+};
+
+}  // namespace torch_xla
diff --git a/aten/src/ATen/templates/aten_xla_type_default.cpp b/aten/src/ATen/templates/aten_xla_type_default.cpp
new file mode 100644
index 0000000..56c29166
--- /dev/null
+++ b/aten/src/ATen/templates/aten_xla_type_default.cpp
@@ -0,0 +1,30 @@
+// ${generated_comment}
+#include <torch_xla/csrc/aten_xla_type_default.h>
+
+#include <ATen/Context.h>
+#include <torch/library.h>
+#include <ATen/CPUGeneratorImpl.h>
+
+#include <tensorflow/compiler/xla/xla_client/debug_macros.h>
+#include <tensorflow/compiler/xla/xla_client/metrics.h>
+#include <tensorflow/compiler/xla/xla_client/tf_logging.h>
+#include <torch_xla/csrc/aten_xla_bridge.h>
+#include <torch_xla/csrc/aten_xla_type.h>
+#include <torch_xla/csrc/function_call_tracker.h>
+
+namespace ${cpp_namespace} {
+
+${dispatch_aten_fallback_definitions}
+
+
+
+TORCH_LIBRARY_IMPL(aten, XLA, m) {
+${dispatch_registrations}
+
+}
+TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {
+${dispatch_autograd_registrations}
+
+}
+
+}  // namespace torch_xla
diff --git a/aten/src/ATen/templates/aten_xla_type_default.h b/aten/src/ATen/templates/aten_xla_type_default.h
new file mode 100644
index 0000000..6d1e84b
--- /dev/null
+++ b/aten/src/ATen/templates/aten_xla_type_default.h
@@ -0,0 +1,19 @@
+// ${generated_comment}
+
+#include <ATen/Tensor.h>
+#include <c10/core/Stream.h>
+
+using c10::Stream;
+
+namespace ${cpp_namespace} {
+
+class AtenXlaTypeDefault {
+ public:
+${dispatch_aten_fallback_declarations}
+
+};
+
+// TODO: maybe kill this, doesn't look like XLA actually calls it anywhere
+void RegisterAtenTypeFunctions();
+
+}  // namespace torch_xla
diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py
index 673f01c..54ae467 100644
--- a/tools/codegen/api/cpp.py
+++ b/tools/codegen/api/cpp.py
@@ -127,6 +127,10 @@
                 else:
                     return MutRefCType(BaseCType(tensorT))
             else:
+                # Note [Tensor Copy Returns]
+                # Currently, we use "Argument.is_write" to determine
+                # whether or not Tensor return types should be copies or references.
+                # If that ever changes, take a look at other locations of this note!
                 return BaseCType(tensorT)
         elif t.name == BaseTy.Scalar:
             return BaseCType(scalarT)
@@ -150,7 +154,7 @@
     else:
         return TupleCType([return_type(r) for r in rs])
 
-def return_names(f: NativeFunction) -> Sequence[str]:
+def return_names(f: NativeFunction, *, fallback_name: str = 'result') -> Sequence[str]:
     returns: List[str] = []
     for i, r in enumerate(f.func.returns):
         # If we have an inplace function, the return argument is
@@ -171,11 +175,11 @@
                 name = f'{r.name}_return'
             else:
                 name = r.name
-        # If there is no explicit name, we just name the output result,
+        # If there is no explicit name, we use the fallback name ('result' by default),
         # unless it's a multi-return, in which case it's result0,
         # result1, etc (zero-indexed)
         else:
-            name = 'result' if len(f.func.returns) == 1 else f'result{i}'
+            name = fallback_name if len(f.func.returns) == 1 else f'{fallback_name}{i}'
         returns.append(name)
     return returns
 
diff --git a/tools/codegen/api/dispatcher.py b/tools/codegen/api/dispatcher.py
index a9ca124..be51c4a 100644
--- a/tools/codegen/api/dispatcher.py
+++ b/tools/codegen/api/dispatcher.py
@@ -4,6 +4,7 @@
 
 from tools.codegen.api.types import ArgName, Binding, NamedCType, CType
 from tools.codegen.api import cpp
+from tools.codegen.utils import concatMap
 
 import itertools
 from typing import Sequence, List, Union
@@ -40,27 +41,25 @@
     # At present, there is no difference. But there could be!
     return cpp.returns_type(rs)
 
-def argument(
-    a: Union[Argument, TensorOptionsArguments, SelfArgument]
-) -> List[Binding]:
-    if isinstance(a, Argument):
-        return [Binding(
-            nctype=argument_type(a, binds=a.name),
-            name=a.name,
-            argument=a,
-        )]
-    elif isinstance(a, SelfArgument):
-        return argument(a.argument)
-    elif isinstance(a, TensorOptionsArguments):
-        return argument(a.dtype) + argument(a.layout) + argument(a.device) + argument(a.pin_memory)
-    else:
-        assert_never(a)
+def jit_arguments(func: FunctionSchema) -> List[Argument]:
+    def to_argument(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Argument]:
+        if isinstance(a, Argument):
+            return [a]
+        elif isinstance(a, SelfArgument):
+            return [a.argument]
+        elif isinstance(a, TensorOptionsArguments):
+            return [a.dtype, a.layout, a.device, a.pin_memory]
+        else:
+            assert_never(a)
+    return list(concatMap(to_argument, itertools.chain(
+        func.arguments.positional,
+        func.arguments.kwarg_only,
+        func.arguments.out)))
 
 def arguments(func: FunctionSchema) -> List[Binding]:
     return [
-        r for a in itertools.chain(
-            func.arguments.positional,
-            func.arguments.kwarg_only,
-            func.arguments.out
-        ) for r in argument(a)
-    ]
+        Binding(
+            nctype=argument_type(a, binds=a.name),
+            name=a.name,
+            argument=a,
+        ) for a in jit_arguments(func)]
diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py
index 93214a3..b2ef83a 100644
--- a/tools/codegen/api/types.py
+++ b/tools/codegen/api/types.py
@@ -266,11 +266,16 @@
             argument=self.argument,
         )
 
-    def decl(self) -> str:
+    def decl(self, *, func_ptr_cast: bool = False) -> str:
         mb_default = ""
         if self.default is not None:
             mb_default = f"={self.default}"
-        return f"{self.type} {self.name}{mb_default}"
+
+        # casting only needs to know the type
+        if func_ptr_cast:
+            return f"{self.type}"
+        else:
+            return f"{self.type} {self.name}{mb_default}"
 
     # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
     # TODO: Kill this when we eventually remove it!
@@ -407,6 +412,12 @@
     def name(self) -> str:
         return dispatcher.name(self.func)
 
+    def decl(self, name: Optional[str] = None) -> str:
+        args_str = ', '.join(a.decl() for a in self.arguments())
+        if name is None:
+            name = self.name()
+        return f"{self.returns_type().cpp_type()} {name}({args_str})"
+
     def defn(self, name: Optional[str] = None) -> str:
         args_str = ', '.join(a.defn() for a in self.arguments())
         if name is None:
@@ -419,6 +430,11 @@
     def returns_type(self) -> CType:
         return dispatcher.returns_type(self.func.returns)
 
+    def ptr_type(self) -> str:
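+        # Returns the C++ function pointer type, e.g., something like int (*)(bool)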
+        dispatcher_args_types_str = ', '.join(a.type for a in self.arguments())
+        return f'{self.returns_type().cpp_type()} (*)({dispatcher_args_types_str})'
+
     # Return the C++ function type, e.g., something like int(bool)
     def type(self) -> str:
         dispatcher_args_types_str = ', '.join(a.type for a in self.arguments())
@@ -438,6 +453,12 @@
     def name(self) -> str:
         return self.prefix + native.name(self.func)
 
+    def decl(self, name: Optional[str] = None) -> str:
+        args_str = ', '.join(a.decl() for a in self.arguments())
+        if name is None:
+            name = self.name()
+        return f"{native.returns_type(self.func.returns).cpp_type()} {name}({args_str})"
+
     def defn(self, name: Optional[str] = None) -> str:
         args_str = ', '.join(a.defn() for a in self.arguments())
         if name is None:
diff --git a/tools/codegen/context.py b/tools/codegen/context.py
index 39aeaa0..da49662 100644
--- a/tools/codegen/context.py
+++ b/tools/codegen/context.py
@@ -1,5 +1,6 @@
 from tools.codegen.utils import S, T, context
-from tools.codegen.model import NativeFunction, NativeFunctionsGroup
+from tools.codegen.model import (NativeFunction, NativeFunctionsGroup, ExternalBackendFunction,
+                                 ExternalBackendFunctionsGroup)
 import tools.codegen.local as local
 
 import functools
@@ -8,11 +9,25 @@
 
 # Helper functions for defining generators on things in the model
 
-F = TypeVar('F', NativeFunction, NativeFunctionsGroup, Union[NativeFunction, NativeFunctionsGroup])
+F = TypeVar(
+    'F',
+    NativeFunction,
+    NativeFunctionsGroup,
+    ExternalBackendFunction,
+    ExternalBackendFunctionsGroup,
+    Union[NativeFunction, NativeFunctionsGroup],
+    Union[ExternalBackendFunctionsGroup, ExternalBackendFunction],
+    Union[NativeFunction, NativeFunctionsGroup, ExternalBackendFunction, ExternalBackendFunctionsGroup]
+)
 
 @contextlib.contextmanager
-def native_function_manager(g: Union[NativeFunctionsGroup, NativeFunction]) -> Iterator[None]:
-    if isinstance(g, NativeFunctionsGroup):
+def native_function_manager(g: Union[
+        NativeFunctionsGroup, NativeFunction, ExternalBackendFunction, ExternalBackendFunctionsGroup]) -> Iterator[None]:
+    if isinstance(g, ExternalBackendFunctionsGroup):
+        f = g.primary.native_function
+    elif isinstance(g, ExternalBackendFunction):
+        f = g.native_function
+    elif isinstance(g, NativeFunctionsGroup):
         # By default, we associate all errors with structured native functions
         # with the out variant.  In some cases, it might be better to have
         # a more specific place to hang things; if so, use
@@ -20,7 +35,7 @@
         f = g.out
     else:
         f = g
-    with context(f'in {f.loc}:\n  {f.func}'):
+    with context(f'in native_functions.yaml line {f.loc}:\n  {f.func}'):
         with local.parametrize(use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors):
             yield
 
diff --git a/tools/codegen/dest/__init__.py b/tools/codegen/dest/__init__.py
index ab4bada..0b68bf3 100644
--- a/tools/codegen/dest/__init__.py
+++ b/tools/codegen/dest/__init__.py
@@ -1,2 +1,3 @@
 from .register_dispatch_key import RegisterDispatchKey as RegisterDispatchKey
 from .native_functions import compute_native_function_declaration as compute_native_function_declaration
+from .gen_external_aten_fallbacks import GenExternalAtenFallback as GenExternalAtenFallback
diff --git a/tools/codegen/dest/gen_external_aten_fallbacks.py b/tools/codegen/dest/gen_external_aten_fallbacks.py
new file mode 100644
index 0000000..47892f6
--- /dev/null
+++ b/tools/codegen/dest/gen_external_aten_fallbacks.py
@@ -0,0 +1,330 @@
+from typing import List, Optional, Union, Dict
+from typing_extensions import Literal
+from dataclasses import dataclass
+import re
+
+from tools.codegen.context import method_with_native_function
+from tools.codegen.utils import Target, mapMaybe
+from tools.codegen.model import (Argument, ExternalBackendFunction,
+                                 ExternalBackendFunctionsGroup, SchemaKind,
+                                 assert_never, Return, is_generic_dispatch_key,
+                                 ListType, OptionalType, BaseType, BaseTy, Variant)
+from tools.codegen.api.types import DispatcherSignature, CppSignatureGroup
+import tools.codegen.api.dispatcher as dispatcher
+import tools.codegen.api.cpp as cpp
+
+# TODO: this contains a list of regexes for ops that don't get a CPU fallback.
+# We should just register fallthroughs when we make the CPU fallback a boxed kernel.
+_FN_DENYLIST_REGEX = [
+    # ATEN functions
+    r'[^(]*cudnn',
+    r'slow_conv_transpose2d_backward.grad_output',
+    r'slow_conv_transpose3d_backward.grad_output',
+    r'slow_conv3d_backward.grad_input',
+    r'thnn_conv2d_backward.grad_input',
+    r'thnn_conv_depthwise2d_backward.grad_input',
+    # XLA/TPU functions
+]
+
+# TODO: remove this list.
+# Instead, the codegen will figure out which ops to generate _out wrappers for
+# entirely from the yaml. Maintaining the same behavior as current XLA codegen for now.
+_FN_OUT = [
+    'abs',
+    'add',
+    'acos',
+    'acosh',
+    'asin',
+    'asinh',
+    'atan',
+    'atan2',
+    'atanh',
+    'baddbmm',
+    'bernoulli',
+    'binary_cross_entropy',
+    'binary_cross_entropy_backward',
+    'clamp',
+    'div',
+    'gather',
+    'ger',
+    'hardsigmoid',
+    'kthvalue',
+    'index_select',
+    'inverse',
+    'log',
+    'masked_select',
+    'maximum',
+    'minimum',
+    'pow',
+    'prod',
+    'nonzero',
+    'round',
+    'normal',
+    'std',
+    'take',
+    'topk',
+    'var',
+]
+
+def requires_backend_wrapper(f: ExternalBackendFunction) -> bool:
+    requires_lowering = not any(is_generic_dispatch_key(k) for k in f.native_function.dispatch)
+    has_xla_lowering = f.metadata is not None
+    in_denylist = any([re.match(frx, str(f.native_function.func.name)) for frx in _FN_DENYLIST_REGEX])
+    return not in_denylist and (requires_lowering or has_xla_lowering)
+
+def xla_tensor_creation_api(
+        ret_name: str,
+        ret: Return,
+        device_param_name: str,
+        *,
+        cpu_result_name: str,
+        tuple_idx: Optional[int] = None
+) -> str:
+    if ret.type == BaseType(BaseTy.Tensor) and not ret.is_write:
+        # Only raw Tensor (non-reference) returns need to go through the XLA tensor creation API.
+        # Tensor references can be returned directly, since they've already been converted to XLA tensors.
+        # See Note [Tensor Copy Returns]
+        bridge_api = 'CreateXlaTensor'
+    elif isinstance(ret.type, ListType) and ret.type.elem == BaseType(BaseTy.Tensor):
+        bridge_api = 'CreateXlaTensors'
+    else:
+        # For non-tensor types, there's no need to wrap the output in an XLA bridge API.
+        return ret_name
+
+    return f"bridge::{bridge_api}({cpu_result_name}, bridge::GetXlaDevice({device_param_name}))"
+
+
+
+# Generates aten_xla_type_default.h and aten_xla_type_default.cpp.
+#
+#   - This generates the registrations for the external backend's kernels, and also generates CPU fallbacks.
+#     This is useful because pretty much all external backends (e.g. XLA)
+#     do not have full aten coverage.
+#     For operators not implemented by the external backend, our codegen
+#     will register these fallbacks instead.
+#   - Why do we generate fallbacks for ALL aten ops, including ops that
+#     external backends have already implemented?
+#     Many external backend kernels only work with specific input shapes,
+#     and are written to call into a cpu fallback when given inputs
+#     that they cannot handle.
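+#
+# A generated CPU fallback looks roughly like this (illustrative; logging,
+# counters and the unused-result cast are elided):
+#   at::Tensor AtenXlaTypeDefault::abs(const at::Tensor & self) {
+#     std::vector<at::Tensor> xlatens_tensors = {self};
+#     auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);
+#     auto&& x_result = at::abs(xlatens[0]);
+#     return bridge::CreateXlaTensor(x_result, bridge::GetXlaDevice(self));
+#   }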
+@dataclass(frozen=True)
+class GenExternalAtenFallback:
+    target: Union[
+        Literal[Target.NAMESPACED_DEFINITION],
+        Literal[Target.NAMESPACED_DECLARATION],
+        Literal[Target.REGISTRATION],
+    ]
+
+    @method_with_native_function
+    def __call__(self, g: Union[ExternalBackendFunctionsGroup, ExternalBackendFunction]) -> List[str]:
+
+        def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
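+            # Generates an out= wrapper: it calls the backend's functional
+            # kernel, then copies the result back into the out= argument(s).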
+            dispatcher_sig = DispatcherSignature.from_schema(g.out.native_function.func)
+            name = dispatcher_sig.name()
+
+            dispatcher_order_args = dispatcher.jit_arguments(g.out.native_function.func)
+            tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
+            print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])
+
+            func_name = f'AtenXlaTypeDefault::{name}'
+            functional_result_name = f'{name}_tmp'
+            return_names = cpp.return_names(g.out.native_function)
+            if len(return_names) > 1:
+                updates = '\n  '.join(
+                    f'bridge::XlaUpdateTensors({{{ret_name}}}, {{std::get<{i}>({functional_result_name})}}, {{0}});'
+                    for i, ret_name in enumerate(return_names))
+                returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
+            else:
+                ret_name = return_names[0]
+                updates = f'bridge::XlaUpdateTensors({{{ret_name}}}, {{{functional_result_name}}}, {{0}});'
+                returns = ret_name
+
+            functional_sig = DispatcherSignature.from_schema(g.functional.native_function.func)
+
+            return f"""\
+{dispatcher_sig.defn(name=func_name)} {{
+  XLA_FN_TRACK(3);
+  TF_VLOG(3) << "XLA {name} :"{print_args_str};
+  auto {functional_result_name} = AtenXlaType::{functional_sig.name()}({", ".join(a.name for a in functional_sig.arguments())});
+  {updates}
+  return {returns};
+}}
+
+"""
+
+        def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
+            if not requires_backend_wrapper(f):
+                return None
+
+            def get_device_param(args: List[Argument]) -> str:
+                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
+                # to use as the device argument.
+                # We should update this to be consistent with how we choose device guards.
+                const_tensor_or_self = [
+                    a for a in args if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
+                    and not a.is_write]
+                if any(const_tensor_or_self):
+                    return const_tensor_or_self[0].name
+                tensor_like = [a for a in args if a.type.is_tensor_like()]
+                if any(tensor_like):
+                    return tensor_like[0].name
+                device_like = [a for a in args if a.type == BaseType(BaseTy.Device)
+                               or a.type == OptionalType(BaseType(BaseTy.Device))]
+                if any(device_like):
+                    return device_like[0].name
+                raise AssertionError("Need a tensor-like or device argument in order to determine the output device")
+
+            # XLA appears to have used the dispatcher convention to write their kernel signatures,
+            # probably because they based their signatures off of our RegistrationDeclarations.h
+            dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
+            name = dispatcher_sig.name()
+            args = dispatcher_sig.arguments()
+
+            if self.target is Target.NAMESPACED_DECLARATION:
+                return f"  static {dispatcher_sig.decl()};"
+
+            elif self.target is Target.REGISTRATION:
+                if f.metadata is not None:
+                    # xla has their own kernel: register it
+                    namespace = 'AtenXlaType'
+                else:
+                    # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
+                    namespace = 'AtenXlaTypeDefault'
+                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
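+                # The explicit function-pointer cast disambiguates overloaded
+                # kernel names; a generated registration looks like (illustrative):
+                #   m.impl("abs", static_cast<at::Tensor (*)(const at::Tensor &)>(&AtenXlaType::abs));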
+                return f'  m.impl("{f.native_function.func.name}", {payload});\n'
+
+            if self.target is not Target.NAMESPACED_DEFINITION:
+                assert_never(self.target)
+
+            # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
+            # TODO: we should generate out wrappers for ALL valid out kernels, not just the ones in xla's hardcoded list
+            if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
+                    and isinstance(g, ExternalBackendFunctionsGroup):
+                return gen_out_wrapper(g)
+
+            # Everything below here is where we generate the CPU fallback.
+            dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)
+
+            # Map each argument to its intermediate variable name in the fallback
+            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
+            tensorlist_args: Dict[Argument, str] = {
+                a: f'l_{a.name}' for a in dispatcher_order_args
+                if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)}
+
+            opt_tensors = [
+                a for a in dispatcher_order_args
+                if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)]
+            opt_tensor_args: Dict[Argument, str] = {a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)}
+
+            tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
+            tensor_args: Dict[Argument, str] = {a: f'xlatens[{i}]' for i, a in enumerate(tensors)}
+            annotated_tensor_indices: List[int] = [
+                i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write]
+
+            print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()])
+
+
+            tensorlist_intermediates_str = ''
+            if len(tensorlist_args) > 0:
+                tensorlist_intermediates_str = '\n'.join([f'  auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
+                                                          for arg, updated_name in tensorlist_args.items()])
+
+            opt_tensor_intermediates_str = ''
+            if len(opt_tensor_args) > 0:
+                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
+                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
+                opt_tensor_intermediates_str += '\n  auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'
+
+            intermediates = ''
+            if tensorlist_intermediates_str != '':
+                intermediates += tensorlist_intermediates_str + '\n'
+            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
+            intermediates += "\n  auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
+            if opt_tensor_intermediates_str != '':
+                intermediates += opt_tensor_intermediates_str
+
+
+            is_method = Variant.function not in f.native_function.variants
+            func_name = f'AtenXlaTypeDefault::{name}'
+
+            # Gather all of the updated variable names to call into the CPU operator.
+            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
+            updated_bindings: List[str] = [
+                tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name))) for a in dispatcher_order_args]
+
+            at_call_name = CppSignatureGroup.from_native_function(
+                f.native_function, method=is_method).most_faithful_signature().name()
+
+            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
+            # to the faithful C++ API, which are carefully written to be exactly the same.
+            cpu_result_name = 'x_result'
+            if is_method:
+                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
+            else:
+                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
+            avoid_warning = ''
+            if f.native_function.func.returns:
+                at_call = f'auto&& {cpu_result_name} = {at_call}'
+                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'
+
+            collect_mutated_tensors = ''
+            update_tensors = ''
+            if len(annotated_tensor_indices) > 0:
+                indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
+                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
+                update_tensors = '\n  bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'
+
+            returns = ''
+            if f.native_function.func.returns:
+                ret_names = cpp.return_names(f.native_function, fallback_name=cpu_result_name)
+                if len(ret_names) == 1:
+                    returns = xla_tensor_creation_api(
+                        ret_names[0], f.native_function.func.returns[0],
+                        get_device_param(dispatcher_order_args), cpu_result_name=cpu_result_name)
+                else:
+                    return_args = [
+                        xla_tensor_creation_api(
+                            ret_names[i], f.native_function.func.returns[i],
+                            get_device_param(dispatcher_order_args), cpu_result_name=f'std::get<{i}>({cpu_result_name})'
+                        ) for i in range(len(f.native_function.func.returns))]
+                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
+            return_str = ''
+            if returns != '':
+                return_str = f'\n  return {returns};'
+
+            return f"""\
+{dispatcher_sig.defn(name=func_name)} {{
+  XLA_FN_TRACK(3);
+  XLA_COUNTER("aten::{name}", 1);
+  TF_VLOG(3) << "XLA {name} :"{print_args_str};
+{intermediates}
+  {at_call}{collect_mutated_tensors}{update_tensors}{avoid_warning}{return_str}
+}}
+
+"""
+        if isinstance(g, ExternalBackendFunctionsGroup):
+            if g.structured:
+                # For structured kernels, we probably only need to generate fallbacks for one of the variants
+                raise AssertionError("Not Implemented")
+            else:
+                return list(mapMaybe(gen_unstructured_external, g.functions()))
+        elif isinstance(g, ExternalBackendFunction):
+            f = g
+            x = gen_unstructured_external(f)
+            return [x] if x else []
+        else:
+            assert_never(g)
diff --git a/tools/codegen/dest/native_functions.py b/tools/codegen/dest/native_functions.py
index e1a208a..3b6bf6d 100644
--- a/tools/codegen/dest/native_functions.py
+++ b/tools/codegen/dest/native_functions.py
@@ -3,7 +3,9 @@
 from tools.codegen.context import with_native_function
 from tools.codegen.utils import concatMap
 from tools.codegen.model import (NativeFunction, NativeFunctionsGroup,
+                                 ExternalBackendFunction, ExternalBackendFunctionsGroup,
                                  is_structured_dispatch_key)
+from tools.codegen.api.types import DispatcherSignature, NativeSignature
 import tools.codegen.api.meta as meta
 import tools.codegen.api.native as native
 import tools.codegen.api.structured as structured
@@ -11,6 +13,7 @@
 @with_native_function
 def gen_unstructured(f: NativeFunction) -> List[str]:
     ns = list(f.dispatch.values())
+    native_sig = NativeSignature(f.func)
 
     rs = []
     # Sometimes a function name shows up multiple times; only generate
@@ -22,13 +25,22 @@
         if "legacy::" in n:
             continue
         seen.add(n)
-        returns_type = native.returns_type(f.func.returns).cpp_type()
-        args = native.arguments(f.func)
-        rs.append(f"TORCH_API {returns_type} {n}({', '.join(a.decl() for a in args)});")
+        rs.append(f"TORCH_API {native_sig.decl(name=n)};")
 
     return rs
 
 @with_native_function
+def gen_unstructured_external(f: ExternalBackendFunction) -> List[str]:
+    # XLA appears to have used the dispatcher convention to write their kernel signatures,
+    # probably because they based their signatures off of our RegistrationDeclarations.h
+    dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
+    if f.metadata is not None:
+        # Only generate declarations for operators that xla has defined in the yaml
+        return [f"static {dispatcher_sig.decl()};"]
+    else:
+        return []
+
+@with_native_function
 def gen_structured(g: NativeFunctionsGroup) -> List[str]:
     # only out has dispatch
     meta_name = meta.name(g)
@@ -65,8 +77,17 @@
 # Generates NativeFunctions.h, a list of forward declarations of all
 # actual kernel definitions we keep in aten/src/ATen/native/
 @with_native_function
-def compute_native_function_declaration(g: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
-    if isinstance(g, NativeFunctionsGroup):
+def compute_native_function_declaration(
+        g: Union[NativeFunctionsGroup, NativeFunction, ExternalBackendFunctionsGroup, ExternalBackendFunction]
+) -> List[str]:
+    if isinstance(g, ExternalBackendFunctionsGroup):
+        if g.structured:
+            raise AssertionError("Structured external backend functions are not implemented yet.")
+        else:
+            return list(concatMap(gen_unstructured_external, g.functions()))
+    elif isinstance(g, ExternalBackendFunction):
+        return gen_unstructured_external(g)
+    elif isinstance(g, NativeFunctionsGroup):
         if g.structured:
             return gen_structured(g)
         else:
diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py
index 8669d68..522e330 100644
--- a/tools/codegen/gen.py
+++ b/tools/codegen/gen.py
@@ -763,6 +763,26 @@
 
     return selector
 
+def get_grouped_native_functions(native_yaml_path: str) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
+    native_functions = parse_native_yaml(native_yaml_path)
+
+    pre_grouped_native_functions: Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]
+    pre_grouped_native_functions = defaultdict(dict)
+    for f in native_functions:
+        d = pre_grouped_native_functions[f.func.signature()]
+        assert f.func.kind() not in d
+        d[f.func.kind()] = f
+
+    def flatten_pre_group(d: Dict[SchemaKind, NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
+        r = NativeFunctionsGroup.from_dict(d)
+        if r is None:
+            return list(d.values())
+        else:
+            return [r]
+
+    # TODO: how come ValuesView isn't a Sequence lol
+    return list(concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())))
+
 def main() -> None:
     parser = argparse.ArgumentParser(description='Generate ATen source files')
     parser.add_argument(
@@ -817,24 +837,9 @@
         options.op_selection_yaml_path,
     )
 
-    native_functions = parse_native_yaml(os.path.join(options.source_path, 'native/native_functions.yaml'))
-
-    pre_grouped_native_functions: Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]
-    pre_grouped_native_functions = defaultdict(dict)
-    for f in native_functions:
-        d = pre_grouped_native_functions[f.func.signature()]
-        assert f.func.kind() not in d
-        d[f.func.kind()] = f
-
-    def flatten_pre_group(d: Dict[SchemaKind, NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
-        r = NativeFunctionsGroup.from_dict(d)
-        if r is None:
-            return list(d.values())
-        else:
-            return [r]
-
-    # TODO: how come ValuesView isn't a Sequence lol
-    grouped_native_functions = list(concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())))
+    native_yaml_path = os.path.join(options.source_path, 'native/native_functions.yaml')
+    native_functions = parse_native_yaml(native_yaml_path)
+    grouped_native_functions = get_grouped_native_functions(native_yaml_path)
     structured_native_functions = [g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)]
 
     template_dir = os.path.join(options.source_path, "templates")
diff --git a/tools/codegen/gen_backend_stubs.py b/tools/codegen/gen_backend_stubs.py
new file mode 100644
index 0000000..a67dbb9
--- /dev/null
+++ b/tools/codegen/gen_backend_stubs.py
@@ -0,0 +1,132 @@
+import pathlib
+import argparse
+import os
+import yaml
+from typing import List, Dict, Union, Tuple, Sequence
+from tools.codegen.gen import FileManager, get_grouped_native_functions, LineLoader, parse_native_yaml
+from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
+                                 NativeFunction, NativeFunctionsGroup, OperatorName,
+                                 ExternalBackendMetadata, assert_never)
+from tools.codegen.selective_build.selector import SelectiveBuilder
+from tools.codegen.utils import Target, concatMap
+import tools.codegen.dest as dest
+
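+# Parses a backend yaml of externally supported ops and writes the backend stub
+# files (aten_xla_type.h, aten_xla_type_default.h/.cpp) from the templates in
+# aten/src/ATen/templates.
+# Example invocation (paths are illustrative):
+#   python -m tools.codegen.gen_backend_stubs -s xla_native_functions.yaml -o out/
+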
+def parse_backend_yaml(
+        backend_yaml_path: str,
+        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
+) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
+    with open(backend_yaml_path, 'r') as f:
+        yaml_values = yaml.load(f, Loader=LineLoader)
+    assert isinstance(yaml_values, dict)
+
+    cpp_namespace = yaml_values.pop('cpp_namespace')
+    backend = yaml_values.pop('backend')
+
+    supported = yaml_values.pop('supported', [])
+    assert isinstance(supported, list), f'expected "supported" to be a list, but got: {supported}'
+    supported_autograd = yaml_values.pop('autograd', [])
+    assert isinstance(supported_autograd, list), f'expected "autograd" to be a list, but got: {supported_autograd}'
+
+    assert len(yaml_values.keys()) == 0, \
+        f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}'
+
+    metadata: Dict[OperatorName, ExternalBackendMetadata] = {}
+    for op in supported:
+        op_name = OperatorName.parse(op)
+        m = ExternalBackendMetadata(op_name, backend, is_autograd=False)
+        metadata[m.operator] = m
+    for op in supported_autograd:
+        op_name = OperatorName.parse(op)
+        m = ExternalBackendMetadata(op_name, backend, is_autograd=True)
+        metadata[m.operator] = m
+
+    native_functions_map: Dict[OperatorName, NativeFunction] = {
+        f.func.name: f
+        for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
+    }
+
+    def native_to_external(
+            g: Union[NativeFunction, NativeFunctionsGroup]
+    ) -> Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]:
+        if isinstance(g, NativeFunction):
+            f = g
+            m = metadata.get(f.func.name, None)
+            return ExternalBackendFunction(f, m)
+        elif isinstance(g, NativeFunctionsGroup):
+            return ExternalBackendFunctionsGroup.from_function_group(g, metadata)
+        else:
+            assert_never(g)
+    for op_name in metadata.keys():
+        if op_name not in native_functions_map:
+            raise AssertionError(f"Found an invalid operator name: {op_name}")
+    return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description='Generate backend stub files')
+    parser.add_argument(
+        '-s',
+        '--source_yaml',
+        help='path to source yaml file containing operator external definitions')
+    parser.add_argument(
+        '-o', '--output_dir', help='output directory')
+    parser.add_argument(
+        '--dry_run', type=bool, default=False, help='run codegen without writing any files')
+    options = parser.parse_args()
+
+    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
+    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
+    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")
+
+    def make_file_manager(install_dir: str) -> FileManager:
+        return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run)
+
+    fm = make_file_manager(options.output_dir)
+
+    native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
+    grouped_native_functions = get_grouped_native_functions(native_yaml_path)
+    cpp_namespace, external_backend_functions = parse_backend_yaml(options.source_yaml, grouped_native_functions)
+
+    native_functions = parse_native_yaml(native_yaml_path)
+
+    selector = SelectiveBuilder.get_nop_selector()
+
+
+    generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'
+    fm.write('aten_xla_type.h', lambda: {
+        'generated_comment': generated_comment,
+        'cpp_namespace': cpp_namespace,
+        'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
+    })
+
+    fm.write('aten_xla_type_default.h', lambda: {
+        'generated_comment': generated_comment,
+        'cpp_namespace': cpp_namespace,
+        'dispatch_aten_fallback_declarations': list(concatMap(
+            dest.GenExternalAtenFallback(Target.NAMESPACED_DECLARATION), external_backend_functions
+        )),
+    })
+
+    fm.write('aten_xla_type_default.cpp', lambda: {
+        'generated_comment': generated_comment,
+        'cpp_namespace': cpp_namespace,
+        # TODO: after cpu fallbacks are moved to a boxed kernel,
+        # merge registrations / definitions into RegisterDispatchKey
+        'dispatch_aten_fallback_definitions': list(concatMap(
+            dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
+        )),
+        'dispatch_registrations': list(concatMap(
+            dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if not e.is_autograd_kernel]
+        )),
+        'dispatch_autograd_registrations': list(concatMap(
+            dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if e.is_autograd_kernel]
+        )),
+    })
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/codegen/model.py b/tools/codegen/model.py
index 8acf296..1c7db29 100644
--- a/tools/codegen/model.py
+++ b/tools/codegen/model.py
@@ -1343,6 +1343,100 @@
         else:
             return f"{self.name}"
 
+@dataclass(frozen=True)
+class ExternalBackendMetadata:
+
+    operator: OperatorName
+    backend: str
+    is_autograd: bool
+
+    structured: bool = False  # TODO: this will eventually become per-op metadata in the yaml file
+
+@dataclass(frozen=True)
+class ExternalBackendFunction:
+
+    native_function: NativeFunction
+    metadata: Optional[ExternalBackendMetadata]
+
+    @property
+    def structured(self) -> bool:
+        # An external backend op is only considered structured if it's been marked structured both in-tree and out-of-tree
+        return self.native_function.structured and self.metadata is not None and self.metadata.structured
+
+    @property
+    def is_autograd_kernel(self) -> bool:
+        return self.metadata is not None and self.metadata.is_autograd
+
+    def __post_init__(self) -> None:
+        if self.metadata is not None:
+            assert self.metadata.operator == self.native_function.func.name, \
+                f'Metadata and native function names do not match: {self.metadata.operator} and {self.native_function.func.name}'
+        kind = self.native_function.func.kind()
+        if kind == SchemaKind.out or kind == SchemaKind.inplace:
+            assert self.metadata is None or not self.metadata.structured, \
+                "Found an out/inplace operator marked with the structured keyword." \
+                f" Only functional operators can be marked as structured. operator={str(self.native_function.func.name)}"
+
+@dataclass(frozen=True)
+class ExternalBackendFunctionsGroup:
+    functional: ExternalBackendFunction
+    inplace: Optional[ExternalBackendFunction]
+    out: ExternalBackendFunction
+
+    @property
+    def structured(self) -> bool:
+        return self.primary.structured
+
+    @property
+    def primary(self) -> ExternalBackendFunction:
+        # TODO: hardcoding that XLA will only implement the functional variant of structured kernels.
+        # This will eventually be toggleable per backend.
+        return self.functional
+
+    @property
+    def is_autograd_kernel(self) -> bool:
+        return self.primary.metadata is not None and self.primary.metadata.is_autograd
+
+    def __post_init__(self) -> None:
+        # Note: I didn't want to copy-paste the post_init checks that NativeFunctionsGroup performs.
+        # ExternalBackendFunctionsGroup objects should be created using `from_function_group` (below),
+        # which guarantees that the relevant checks have already been performed.
+        if self.structured:
+            for f in self.functions():
+                if f == self.primary:
+                    continue
+                # For ops marked as structured externally, we expect external backends to
+                # only include either the functional or out variant in their yaml
+                assert f.metadata is None, \
+                    f"{str(self.primary.native_function.func.name)} is marked as structured. " \
+                    f"The {str(f.native_function.func.name)} variant will be generated for you " \
+                    "and doesn't need to live in the yaml."
+
+    def functions(self) -> Iterator[ExternalBackendFunction]:
+        yield self.out
+        yield self.functional
+        if self.inplace is not None:
+            yield self.inplace
+
+    @staticmethod
+    def from_function_group(
+            g: NativeFunctionsGroup,
+            metadata: Dict[OperatorName, ExternalBackendMetadata]
+    ) -> 'ExternalBackendFunctionsGroup':
+        out_meta = metadata.get(g.out.func.name, None)
+        out = ExternalBackendFunction(g.out, out_meta)
+
+        functional_meta = metadata.get(g.functional.func.name, None)
+        functional = ExternalBackendFunction(g.functional, functional_meta)
+
+        inplace = None
+        if g.inplace:
+            inplace_meta = metadata.get(g.inplace.func.name, None)
+            inplace = ExternalBackendFunction(g.inplace, inplace_meta)
+
+        return ExternalBackendFunctionsGroup(functional, inplace, out)
+
+
 # Helper functions for parsing argument lists (both inputs and returns)
 
 def parse_returns(return_decl: str) -> Tuple[Return, ...]: