Reland of "D27708346: generate xla codegen in-tree" (#56601)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56601
Updating it to ensure that RegistrationDeclarations.yaml is completely
unchanged
This reverts commit 90e532f3ef17a9611e9e7a9f1f6189d4168bf084.
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D27915305
Pulled By: bdhirsh
fbshipit-source-id: 491a025c44221690dad849f9a2166934130c0fec
diff --git a/aten/src/ATen/templates/aten_xla_type.h b/aten/src/ATen/templates/aten_xla_type.h
new file mode 100644
index 0000000..4dc34bc
--- /dev/null
+++ b/aten/src/ATen/templates/aten_xla_type.h
@@ -0,0 +1,22 @@
+#pragma once
+// ${generated_comment}
+
+#include <ATen/Tensor.h>
+
+namespace ${cpp_namespace} {
+
+// Base ATEN Type class where the XLA specific overrides should be defined.
+class AtenXlaType {
+ public:
+ static void InitializeAtenBindings();
+
+ //////////////////////////////////////////////////////////////////////////////
+ // ATEN API overrides in alphabetical order.
+ // Note: The C++ signatures must match the ones listed within the following
+ // pytorch folder file:
+ // torch/csrc/autograd/generated/RegistrationDeclarations.h
+ /////////////////////////////////////////////////////////////////////////////
+${dispatch_xla_declarations}
+};
+
+} // namespace ${cpp_namespace}
diff --git a/aten/src/ATen/templates/aten_xla_type_default.cpp b/aten/src/ATen/templates/aten_xla_type_default.cpp
new file mode 100644
index 0000000..56c29166
--- /dev/null
+++ b/aten/src/ATen/templates/aten_xla_type_default.cpp
@@ -0,0 +1,30 @@
+// ${generated_comment}
+#include <torch_xla/csrc/aten_xla_type_default.h>
+
+#include <ATen/Context.h>
+#include <torch/library.h>
+#include <ATen/CPUGeneratorImpl.h>
+
+#include <tensorflow/compiler/xla/xla_client/debug_macros.h>
+#include <tensorflow/compiler/xla/xla_client/metrics.h>
+#include <tensorflow/compiler/xla/xla_client/tf_logging.h>
+#include <torch_xla/csrc/aten_xla_bridge.h>
+#include <torch_xla/csrc/aten_xla_type.h>
+#include <torch_xla/csrc/function_call_tracker.h>
+
+namespace ${cpp_namespace} {
+
+${dispatch_aten_fallback_definitions}
+
+
+
+TORCH_LIBRARY_IMPL(aten, XLA, m) {
+${dispatch_registrations}
+
+}
+TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {
+${dispatch_autograd_registrations}
+
+}
+
+} // namespace ${cpp_namespace}
diff --git a/aten/src/ATen/templates/aten_xla_type_default.h b/aten/src/ATen/templates/aten_xla_type_default.h
new file mode 100644
index 0000000..6d1e84b
--- /dev/null
+++ b/aten/src/ATen/templates/aten_xla_type_default.h
@@ -0,0 +1,19 @@
+// ${generated_comment}
+
+#include <ATen/Tensor.h>
+#include <c10/core/Stream.h>
+
+using c10::Stream;
+
+namespace ${cpp_namespace} {
+
+class AtenXlaTypeDefault {
+ public:
+${dispatch_aten_fallback_declarations}
+
+};
+
+// TODO: maybe kill this, doesn't look like XLA actually calls it anywhere
+void RegisterAtenTypeFunctions();
+
+} // namespace ${cpp_namespace}
diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py
index 673f01c..54ae467 100644
--- a/tools/codegen/api/cpp.py
+++ b/tools/codegen/api/cpp.py
@@ -127,6 +127,10 @@
else:
return MutRefCType(BaseCType(tensorT))
else:
+ # Note [Tensor Copy Returns]
+ # Currently, we use "Argument.is_write" to determine
+ # whether or not Tensor return types should be copies or references.
+ # If that ever changes, take a look at other locations of this note!
return BaseCType(tensorT)
elif t.name == BaseTy.Scalar:
return BaseCType(scalarT)
@@ -150,7 +154,7 @@
else:
return TupleCType([return_type(r) for r in rs])
-def return_names(f: NativeFunction) -> Sequence[str]:
+def return_names(f: NativeFunction, *, fallback_name: str = 'result') -> Sequence[str]:
returns: List[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
@@ -171,11 +175,11 @@
name = f'{r.name}_return'
else:
name = r.name
- # If there is no explicit name, we just name the output result,
+ # If there is no explicit name and no fallback name was passed in, we just name the output result,
# unless it's a multi-return, in which case it's result0,
# result1, etc (zero-indexed)
else:
- name = 'result' if len(f.func.returns) == 1 else f'result{i}'
+ name = fallback_name if len(f.func.returns) == 1 else f'{fallback_name}{i}'
returns.append(name)
return returns
diff --git a/tools/codegen/api/dispatcher.py b/tools/codegen/api/dispatcher.py
index a9ca124..be51c4a 100644
--- a/tools/codegen/api/dispatcher.py
+++ b/tools/codegen/api/dispatcher.py
@@ -4,6 +4,7 @@
from tools.codegen.api.types import ArgName, Binding, NamedCType, CType
from tools.codegen.api import cpp
+from tools.codegen.utils import concatMap
import itertools
from typing import Sequence, List, Union
@@ -40,27 +41,25 @@
# At present, there is no difference. But there could be!
return cpp.returns_type(rs)
-def argument(
- a: Union[Argument, TensorOptionsArguments, SelfArgument]
-) -> List[Binding]:
- if isinstance(a, Argument):
- return [Binding(
- nctype=argument_type(a, binds=a.name),
- name=a.name,
- argument=a,
- )]
- elif isinstance(a, SelfArgument):
- return argument(a.argument)
- elif isinstance(a, TensorOptionsArguments):
- return argument(a.dtype) + argument(a.layout) + argument(a.device) + argument(a.pin_memory)
- else:
- assert_never(a)
+def jit_arguments(func: FunctionSchema) -> List[Argument]:
+ def to_argument(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Argument]:
+ if isinstance(a, Argument):
+ return [a]
+ elif isinstance(a, SelfArgument):
+ return [a.argument]
+ elif isinstance(a, TensorOptionsArguments):
+ return [a.dtype, a.layout, a.device, a.pin_memory]
+ else:
+ assert_never(a)
+ return list(concatMap(to_argument, itertools.chain(
+ func.arguments.positional,
+ func.arguments.kwarg_only,
+ func.arguments.out)))
def arguments(func: FunctionSchema) -> List[Binding]:
return [
- r for a in itertools.chain(
- func.arguments.positional,
- func.arguments.kwarg_only,
- func.arguments.out
- ) for r in argument(a)
- ]
+ Binding(
+ nctype=argument_type(a, binds=a.name),
+ name=a.name,
+ argument=a,
+ ) for a in jit_arguments(func)]
diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py
index 93214a3..b2ef83a 100644
--- a/tools/codegen/api/types.py
+++ b/tools/codegen/api/types.py
@@ -266,11 +266,16 @@
argument=self.argument,
)
- def decl(self) -> str:
+ def decl(self, *, func_ptr_cast: bool = False) -> str:
mb_default = ""
if self.default is not None:
mb_default = f"={self.default}"
- return f"{self.type} {self.name}{mb_default}"
+
+ # casting only needs to know the type
+ if func_ptr_cast:
+ return f"{self.type}"
+ else:
+ return f"{self.type} {self.name}{mb_default}"
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
@@ -407,6 +412,12 @@
def name(self) -> str:
return dispatcher.name(self.func)
+ def decl(self, name: Optional[str] = None) -> str:
+ args_str = ', '.join(a.decl() for a in self.arguments())
+ if name is None:
+ name = self.name()
+ return f"{self.returns_type().cpp_type()} {name}({args_str})"
+
def defn(self, name: Optional[str] = None) -> str:
args_str = ', '.join(a.defn() for a in self.arguments())
if name is None:
@@ -419,6 +430,10 @@
def returns_type(self) -> CType:
return dispatcher.returns_type(self.func.returns)
+ def ptr_type(self) -> str:
+ dispatcher_args_types_str = ', '.join(a.type for a in self.arguments())
+ return f'{self.returns_type().cpp_type()} (*)({dispatcher_args_types_str})'
+
# Return the C++ function type, e.g., something like int(bool)
def type(self) -> str:
dispatcher_args_types_str = ', '.join(a.type for a in self.arguments())
@@ -438,6 +453,12 @@
def name(self) -> str:
return self.prefix + native.name(self.func)
+ def decl(self, name: Optional[str] = None) -> str:
+ args_str = ', '.join(a.decl() for a in self.arguments())
+ if name is None:
+ name = self.name()
+ return f"{native.returns_type(self.func.returns).cpp_type()} {name}({args_str})"
+
def defn(self, name: Optional[str] = None) -> str:
args_str = ', '.join(a.defn() for a in self.arguments())
if name is None:
diff --git a/tools/codegen/context.py b/tools/codegen/context.py
index 39aeaa0..da49662 100644
--- a/tools/codegen/context.py
+++ b/tools/codegen/context.py
@@ -1,5 +1,6 @@
from tools.codegen.utils import S, T, context
-from tools.codegen.model import NativeFunction, NativeFunctionsGroup
+from tools.codegen.model import (NativeFunction, NativeFunctionsGroup, ExternalBackendFunction,
+ ExternalBackendFunctionsGroup)
import tools.codegen.local as local
import functools
@@ -8,11 +9,25 @@
# Helper functions for defining generators on things in the model
-F = TypeVar('F', NativeFunction, NativeFunctionsGroup, Union[NativeFunction, NativeFunctionsGroup])
+F = TypeVar(
+ 'F',
+ NativeFunction,
+ NativeFunctionsGroup,
+ ExternalBackendFunction,
+ ExternalBackendFunctionsGroup,
+ Union[NativeFunction, NativeFunctionsGroup],
+ Union[ExternalBackendFunctionsGroup, ExternalBackendFunction],
+ Union[NativeFunction, NativeFunctionsGroup, ExternalBackendFunction, ExternalBackendFunctionsGroup]
+)
@contextlib.contextmanager
-def native_function_manager(g: Union[NativeFunctionsGroup, NativeFunction]) -> Iterator[None]:
- if isinstance(g, NativeFunctionsGroup):
+def native_function_manager(g: Union[
+ NativeFunctionsGroup, NativeFunction, ExternalBackendFunction, ExternalBackendFunctionsGroup]) -> Iterator[None]:
+ if isinstance(g, ExternalBackendFunctionsGroup):
+ f = g.primary.native_function
+ elif isinstance(g, ExternalBackendFunction):
+ f = g.native_function
+ elif isinstance(g, NativeFunctionsGroup):
# By default, we associate all errors with structured native functions
# with the out variant. In some cases, it might be better to have
# a more specific place to hang things; if so, use
@@ -20,7 +35,7 @@
f = g.out
else:
f = g
- with context(f'in {f.loc}:\n {f.func}'):
+ with context(f'in native_functions.yaml line {f.loc}:\n {f.func}'):
with local.parametrize(use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors):
yield
diff --git a/tools/codegen/dest/__init__.py b/tools/codegen/dest/__init__.py
index ab4bada..0b68bf3 100644
--- a/tools/codegen/dest/__init__.py
+++ b/tools/codegen/dest/__init__.py
@@ -1,2 +1,3 @@
from .register_dispatch_key import RegisterDispatchKey as RegisterDispatchKey
from .native_functions import compute_native_function_declaration as compute_native_function_declaration
+from .gen_external_aten_fallbacks import GenExternalAtenFallback as GenExternalAtenFallback
diff --git a/tools/codegen/dest/gen_external_aten_fallbacks.py b/tools/codegen/dest/gen_external_aten_fallbacks.py
new file mode 100644
index 0000000..47892f6
--- /dev/null
+++ b/tools/codegen/dest/gen_external_aten_fallbacks.py
@@ -0,0 +1,316 @@
+from typing import List, Optional, Union, Dict
+from typing_extensions import Literal
+from dataclasses import dataclass
+import re
+
+from tools.codegen.context import method_with_native_function
+from tools.codegen.utils import Target, mapMaybe
+from tools.codegen.model import (Argument, ExternalBackendFunction,
+ ExternalBackendFunctionsGroup, SchemaKind,
+ assert_never, Return, is_generic_dispatch_key,
+ ListType, OptionalType, BaseType, BaseTy, Variant)
+from tools.codegen.api.types import DispatcherSignature, CppSignatureGroup
+import tools.codegen.api.dispatcher as dispatcher
+import tools.codegen.api.cpp as cpp
+
+# TODO: this contains a list of regex for ops that don't get a CPU fallback.
+# We should just register fallthroughs when we make the CPU fallback a boxed kernel.
+_FN_DENYLIST_REGEX = [
+ # ATEN functions
+ r'[^(]*cudnn',
+ r'slow_conv_transpose2d_backward.grad_output',
+ r'slow_conv_transpose3d_backward.grad_output',
+ r'slow_conv3d_backward.grad_input',
+ r'thnn_conv2d_backward.grad_input',
+ r'thnn_conv_depthwise2d_backward.grad_input',
+ # XLA/TPU functions
+]
+
+# TODO: remove this list.
+# Instead, the codegen will figure out which ops to generate _out wrappers for
+# entirely from the yaml. Maintaining the same behavior as current XLA codegen for now.
+_FN_OUT = [
+ 'abs',
+ 'add',
+ 'acos',
+ 'acosh',
+ 'asin',
+ 'asinh',
+ 'atan',
+ 'atan2',
+ 'atanh',
+ 'baddbmm',
+ 'bernoulli',
+ 'binary_cross_entropy',
+ 'binary_cross_entropy_backward',
+ 'clamp',
+ 'div',
+ 'gather',
+ 'ger',
+ 'hardsigmoid',
+ 'kthvalue',
+ 'index_select',
+ 'inverse',
+ 'log',
+ 'masked_select',
+ 'maximum',
+ 'minimum',
+ 'pow',
+ 'prod',
+ 'nonzero',
+ 'round',
+ 'normal',
+ 'std',
+ 'take',
+ 'topk',
+ 'var',
+]
+
+def requires_backend_wrapper(f: ExternalBackendFunction) -> bool:
+ requires_lowering = not any(is_generic_dispatch_key(k) for k in f.native_function.dispatch)
+ has_xla_lowering = f.metadata is not None
+ in_denylist = any([re.match(frx, str(f.native_function.func.name)) for frx in _FN_DENYLIST_REGEX])
+ return not in_denylist and (requires_lowering or has_xla_lowering)
+
+def xla_tensor_creation_api(
+ ret_name: str,
+ ret: Return,
+ device_param_name: str,
+ *,
+ cpu_result_name: str,
+ tuple_idx: Optional[int] = None
+) -> str:
+ if ret.type == BaseType(BaseTy.Tensor) and not ret.is_write:
+ # Only raw Tensor (non-reference) returns need to go through the XLA tensor creation API.
+ # Tensor references can be returned directly, since they've already been converted to XLA tensors.
+ # See Note [Tensor Copy Returns]
+ bridge_api = 'CreateXlaTensor'
+ elif isinstance(ret.type, ListType) and ret.type.elem == BaseType(BaseTy.Tensor):
+ bridge_api = 'CreateXlaTensors'
+ else:
+ # for non tensor-types, there's no need to wrap the output in an xla bridge api.
+ return ret_name
+
+ return f"bridge::{bridge_api}({cpu_result_name}, bridge::GetXlaDevice({device_param_name}))"
+
+
+
+# Generates aten_xla_type_default.h and aten_xla_type_default.cpp.
+#
+# - This function registers external backend kernels, and also generates fallbacks to CPU.
+# This is useful because pretty much all external backends (e.g. XLA)
+# do not have full aten coverage.
+# For operators not implemented by the external backend, our codegen
+# will register these fallbacks instead.
+# - Why do we generate fallback for ALL aten ops, including ops that
+# external backends have already implemented?
+# Many external backend kernels only work with specific input shapes,
+# and are written to call into a cpu fallback when given inputs
+# that they cannot handle.
+@dataclass(frozen=True)
+class GenExternalAtenFallback:
+ target: Union[
+ Literal[Target.NAMESPACED_DEFINITION],
+ Literal[Target.NAMESPACED_DECLARATION],
+ Literal[Target.REGISTRATION],
+ ]
+
+ @method_with_native_function
+ def __call__(self, g: Union[ExternalBackendFunctionsGroup, ExternalBackendFunction]) -> List[str]:
+
+ def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
+ dispatcher_sig = DispatcherSignature.from_schema(g.out.native_function.func)
+ name = dispatcher_sig.name()
+
+ dispatcher_order_args = dispatcher.jit_arguments(g.out.native_function.func)
+ tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
+ print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])
+
+ func_name = f'AtenXlaTypeDefault::{name}'
+ functional_result_name = f'{name}_tmp'
+ return_names = cpp.return_names(g.out.native_function)
+ if len(return_names) > 1:
+ updates = '\n '.join(
+ f'bridge::XlaUpdateTensors({{{ret_name}}}, {{std::get<{i}>({functional_result_name})}}, {{0}});'
+ for i, ret_name in enumerate(return_names))
+ returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
+ else:
+ ret_name = return_names[0]
+ updates = f'bridge::XlaUpdateTensors({{{ret_name}}}, {{{functional_result_name}}}, {{0}});'
+ returns = ret_name
+
+ functional_sig = DispatcherSignature.from_schema(g.functional.native_function.func)
+
+ return f"""\
+{dispatcher_sig.defn(name=func_name)} {{
+ XLA_FN_TRACK(3);
+ TF_VLOG(3) << "XLA {name} :"{print_args_str};
+ auto {functional_result_name} = AtenXlaType::{functional_sig.name()}({", ".join(a.name for a in functional_sig.arguments())});
+ {updates}
+ return {returns};
+}}
+
+"""
+
+ def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
+ if not requires_backend_wrapper(f):
+ return None
+
+ def get_device_param(args: List[Argument]) -> str:
+ # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
+ # to use as the device argument.
+ # We should update this to be consistent with how we choose device guards.
+ const_tensor_or_self = [
+ a for a in args if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
+ and not a.is_write]
+ if any(const_tensor_or_self):
+ return const_tensor_or_self[0].name
+ tensor_like = [a for a in args if a.type.is_tensor_like()]
+ if any(tensor_like):
+ return tensor_like[0].name
+ device_like = [a for a in args if a.type == BaseType(BaseTy.Device)
+ or a.type == OptionalType(BaseType(BaseTy.Device))]
+ if any(device_like):
+ return device_like[0].name
+ raise AssertionError("Need a tensor-like or device argument in order to determine the output device")
+
+ # XLA appears to have used the dispatcher convention to write their kernel signatures,
+ # probably because they based their signatures off of our RegistrationDeclarations.h
+ dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
+ name = dispatcher_sig.name()
+ args = dispatcher_sig.arguments()
+
+ if self.target is Target.NAMESPACED_DECLARATION:
+ return f" static {dispatcher_sig.decl()};"
+
+ elif self.target is Target.REGISTRATION:
+ if f.metadata is not None:
+ # xla has their own kernel: register it
+ namespace = 'AtenXlaType'
+ else:
+ # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
+ namespace = 'AtenXlaTypeDefault'
+ payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
+ return f' m.impl("{f.native_function.func.name}", {payload});\n'
+
+ if self.target is not Target.NAMESPACED_DEFINITION:
+ assert_never(self.target)
+
+ # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
+ # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
+ if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
+ and isinstance(g, ExternalBackendFunctionsGroup):
+ return gen_out_wrapper(g)
+
+ # Everything below here is where we generate the CPU fallback.
+ dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)
+
+ # Map each argument to its intermediate variable name in the fallback
+ # We have to do it separately for TensorList/Optional<Tensor>/Tensor
+ tensorlist_args: Dict[Argument, str] = {
+ a: f'l_{a.name}' for a in dispatcher_order_args
+ if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)}
+
+ opt_tensors = [
+ a for a in dispatcher_order_args
+ if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)]
+ opt_tensor_args: Dict[Argument, str] = {a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)}
+
+ tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
+ tensor_args: Dict[Argument, str] = {a: f'xlatens[{i}]' for i, a in enumerate(tensors)}
+ annotated_tensor_indices: List[int] = [
+ i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write]
+
+ print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()])
+
+
+ tensorlist_intermediates_str = ''
+ if len(tensorlist_args) > 0:
+ tensorlist_intermediates_str = '\n'.join([f' auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
+ for arg, updated_name in tensorlist_args.items()])
+
+ opt_tensor_intermediates_str = ''
+ if len(opt_tensor_args) > 0:
+ arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
+ opt_tensor_intermediates_str = f'\n std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
+ opt_tensor_intermediates_str += '\n auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'
+
+ intermediates = ''
+ if tensorlist_intermediates_str != '':
+ intermediates += tensorlist_intermediates_str + '\n'
+ intermediates += f" std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
+ intermediates += "\n auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
+ if opt_tensor_intermediates_str != '':
+ intermediates += opt_tensor_intermediates_str
+
+
+ is_method = Variant.function not in f.native_function.variants
+ func_name = f'AtenXlaTypeDefault::{name}'
+
+ # Gather all of the updated variable names to call into the CPU operator.
+ # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
+ updated_bindings: List[str] = [
+ tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name))) for a in dispatcher_order_args]
+
+ at_call_name = CppSignatureGroup.from_native_function(
+ f.native_function, method=is_method).most_faithful_signature().name()
+
+ # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
+ # to the faithful C++ API, which are carefully written to be exactly the same.
+ cpu_result_name = 'x_result'
+ if is_method:
+ at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
+ else:
+ at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
+ avoid_warning = ''
+ if f.native_function.func.returns:
+ at_call = f'auto&& {cpu_result_name} = {at_call}'
+ avoid_warning = f'\n static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'
+
+ collect_mutated_tensors = ''
+ update_tensors = ''
+ if len(annotated_tensor_indices) > 0:
+ indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
+ collect_mutated_tensors = f'\n std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
+ update_tensors = '\n bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'
+
+ returns = ''
+ if f.native_function.func.returns:
+ ret_names = cpp.return_names(f.native_function, fallback_name=cpu_result_name)
+ if len(ret_names) == 1:
+ returns = xla_tensor_creation_api(
+ ret_names[0], f.native_function.func.returns[0],
+ get_device_param(dispatcher_order_args), cpu_result_name=cpu_result_name)
+ else:
+ return_args = [
+ xla_tensor_creation_api(
+ ret_names[i], f.native_function.func.returns[i],
+ get_device_param(dispatcher_order_args), cpu_result_name=f'std::get<{i}>({cpu_result_name})'
+ ) for i in range(len(f.native_function.func.returns))]
+ returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
+ return_str = ''
+ if returns != '':
+ return_str = f'\n return {returns};'
+
+ return f"""\
+{dispatcher_sig.defn(name=func_name)} {{
+ XLA_FN_TRACK(3);
+ XLA_COUNTER("aten::{name}", 1);
+ TF_VLOG(3) << "XLA {name} :"{print_args_str};
+{intermediates}
+ {at_call}{collect_mutated_tensors}{update_tensors}{avoid_warning}{return_str}
+}}
+
+"""
+        if isinstance(g, ExternalBackendFunctionsGroup):
+            if g.structured:
+                # We can probably only bother generating fallbacks for one of the variants, for structured
+                raise AssertionError("Not Implemented")
+            else:
+                return list(mapMaybe(gen_unstructured_external, g.functions()))
+        elif isinstance(g, ExternalBackendFunction):
+            x = gen_unstructured_external(g)
+            return [x] if x else []
+        else:
+            # Exhaustiveness check on the dispatched argument itself.
+            assert_never(g)
diff --git a/tools/codegen/dest/native_functions.py b/tools/codegen/dest/native_functions.py
index e1a208a..3b6bf6d 100644
--- a/tools/codegen/dest/native_functions.py
+++ b/tools/codegen/dest/native_functions.py
@@ -3,7 +3,9 @@
from tools.codegen.context import with_native_function
from tools.codegen.utils import concatMap
from tools.codegen.model import (NativeFunction, NativeFunctionsGroup,
+ ExternalBackendFunction, ExternalBackendFunctionsGroup,
is_structured_dispatch_key)
+from tools.codegen.api.types import DispatcherSignature, NativeSignature
import tools.codegen.api.meta as meta
import tools.codegen.api.native as native
import tools.codegen.api.structured as structured
@@ -11,6 +13,7 @@
@with_native_function
def gen_unstructured(f: NativeFunction) -> List[str]:
ns = list(f.dispatch.values())
+ native_sig = NativeSignature(f.func)
rs = []
# Sometimes a function name shows up multiple times; only generate
@@ -22,13 +25,22 @@
if "legacy::" in n:
continue
seen.add(n)
- returns_type = native.returns_type(f.func.returns).cpp_type()
- args = native.arguments(f.func)
- rs.append(f"TORCH_API {returns_type} {n}({', '.join(a.decl() for a in args)});")
+ rs.append(f"TORCH_API {native_sig.decl(name=n)};")
return rs
@with_native_function
+def gen_unstructured_external(f: ExternalBackendFunction) -> List[str]:
+ # XLA appears to have used the dispatcher convention to write their kernel signatures,
+ # probably because they based their signatures off of our RegistrationDeclarations.h
+ dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
+ if f.metadata is not None:
+ # Only generate declarations for operators that xla has defined in the yaml
+ return [f"static {dispatcher_sig.decl()};"]
+ else:
+ return []
+
+@with_native_function
def gen_structured(g: NativeFunctionsGroup) -> List[str]:
# only out has dispatch
meta_name = meta.name(g)
@@ -65,8 +77,17 @@
# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function
-def compute_native_function_declaration(g: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
- if isinstance(g, NativeFunctionsGroup):
+def compute_native_function_declaration(
+ g: Union[NativeFunctionsGroup, NativeFunction, ExternalBackendFunctionsGroup, ExternalBackendFunction]
+) -> List[str]:
+ if isinstance(g, ExternalBackendFunctionsGroup):
+ if g.structured:
+ raise AssertionError("Structured external backend functions are not implemented yet.")
+ else:
+ return list(concatMap(gen_unstructured_external, g.functions()))
+ elif isinstance(g, ExternalBackendFunction):
+ return gen_unstructured_external(g)
+ elif isinstance(g, NativeFunctionsGroup):
if g.structured:
return gen_structured(g)
else:
diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py
index 8669d68..522e330 100644
--- a/tools/codegen/gen.py
+++ b/tools/codegen/gen.py
@@ -763,6 +763,26 @@
return selector
+def get_grouped_native_functions(native_yaml_path: str) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
+ native_functions = parse_native_yaml(native_yaml_path)
+
+ pre_grouped_native_functions: Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]
+ pre_grouped_native_functions = defaultdict(dict)
+ for f in native_functions:
+ d = pre_grouped_native_functions[f.func.signature()]
+ assert f.func.kind() not in d
+ d[f.func.kind()] = f
+
+ def flatten_pre_group(d: Dict[SchemaKind, NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
+ r = NativeFunctionsGroup.from_dict(d)
+ if r is None:
+ return list(d.values())
+ else:
+ return [r]
+
+ # TODO: how come ValuesView isn't a Sequence lol
+ return list(concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())))
+
def main() -> None:
parser = argparse.ArgumentParser(description='Generate ATen source files')
parser.add_argument(
@@ -817,24 +837,9 @@
options.op_selection_yaml_path,
)
- native_functions = parse_native_yaml(os.path.join(options.source_path, 'native/native_functions.yaml'))
-
- pre_grouped_native_functions: Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]
- pre_grouped_native_functions = defaultdict(dict)
- for f in native_functions:
- d = pre_grouped_native_functions[f.func.signature()]
- assert f.func.kind() not in d
- d[f.func.kind()] = f
-
- def flatten_pre_group(d: Dict[SchemaKind, NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
- r = NativeFunctionsGroup.from_dict(d)
- if r is None:
- return list(d.values())
- else:
- return [r]
-
- # TODO: how come ValuesView isn't a Sequence lol
- grouped_native_functions = list(concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())))
+ native_yaml_path = os.path.join(options.source_path, 'native/native_functions.yaml')
+ native_functions = parse_native_yaml(native_yaml_path)
+ grouped_native_functions = get_grouped_native_functions(native_yaml_path)
structured_native_functions = [g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)]
template_dir = os.path.join(options.source_path, "templates")
diff --git a/tools/codegen/gen_backend_stubs.py b/tools/codegen/gen_backend_stubs.py
new file mode 100644
index 0000000..a67dbb9
--- /dev/null
+++ b/tools/codegen/gen_backend_stubs.py
@@ -0,0 +1,126 @@
+import pathlib
+import argparse
+import os
+import yaml
+from typing import List, Dict, Union, Tuple, Sequence
+from tools.codegen.gen import FileManager, get_grouped_native_functions, LineLoader, parse_native_yaml
+from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
+ NativeFunction, NativeFunctionsGroup, OperatorName,
+ ExternalBackendMetadata, assert_never)
+from tools.codegen.selective_build.selector import SelectiveBuilder
+from tools.codegen.utils import Target, concatMap
+import tools.codegen.dest as dest
+
+def parse_backend_yaml(
+    backend_yaml_path: str,
+    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
+) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
+    # Reads the backend yaml and returns (cpp_namespace, one external wrapper per entry of
+    # grouped_native_functions). Asserts on unexpected yaml keys and unknown operator names.
+    with open(backend_yaml_path, 'r') as f:
+        yaml_values = yaml.load(f, Loader=LineLoader)
+    assert isinstance(yaml_values, dict)
+
+    cpp_namespace = yaml_values.pop('cpp_namespace')
+    backend = yaml_values.pop('backend')
+
+    supported = yaml_values.pop('supported', [])
+    assert isinstance(supported, list), f'expected "supported" to be a list, but got: {supported}'
+    supported_autograd = yaml_values.pop('autograd', [])
+    assert isinstance(supported_autograd, list), f'expected "autograd" to be a list, but got: {supported_autograd}'
+    # Every recognised key has been popped above, so anything still present is unexpected.
+    assert len(yaml_values.keys()) == 0, \
+        f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}'
+
+    metadata: Dict[OperatorName, ExternalBackendMetadata] = {}
+    for op in supported:
+        op_name = OperatorName.parse(op)
+        m = ExternalBackendMetadata(op_name, backend, is_autograd=False)
+        metadata[m.operator] = m
+    for op in supported_autograd:
+        op_name = OperatorName.parse(op)
+        m = ExternalBackendMetadata(op_name, backend, is_autograd=True)
+        metadata[m.operator] = m
+
+    native_functions_map: Dict[OperatorName, NativeFunction] = {
+        f.func.name: f
+        for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
+    }
+
+    def native_to_external(
+        g: Union[NativeFunction, NativeFunctionsGroup]
+    ) -> Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]:
+        if isinstance(g, NativeFunction):
+            return ExternalBackendFunction(g, metadata.get(g.func.name, None))
+        elif isinstance(g, NativeFunctionsGroup):
+            return ExternalBackendFunctionsGroup.from_function_group(g, metadata)
+        else:
+            assert_never(g)
+    for op_name in metadata.keys():
+        if op_name not in native_functions_map:
+            raise AssertionError(f"Found an invalid operator name: {op_name}")
+    return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description='Generate backend stub files')
+    parser.add_argument(
+        '-s',
+        '--source_yaml',
+        help='path to source yaml file containing operator external definitions')
+    parser.add_argument('-o', '--output_dir', help='output directory')
+    # NB: type=bool would turn every non-empty string (including "False") into True.
+    parser.add_argument(
+        '--dry_run', type=lambda s: s.lower() not in ('', '0', 'false', 'no', 'off'), default=False,
+        help='forwarded to FileManager as dry_run; pass false/0/no/off to disable')
+    options = parser.parse_args()
+
+    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
+    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
+    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")
+
+    def make_file_manager(install_dir: str) -> FileManager:
+        return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run)
+
+    fm = make_file_manager(options.output_dir)
+
+    native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
+    grouped_native_functions = get_grouped_native_functions(native_yaml_path)
+    cpp_namespace, external_backend_functions = parse_backend_yaml(options.source_yaml, grouped_native_functions)
+
+    native_functions = parse_native_yaml(native_yaml_path)
+
+    selector = SelectiveBuilder.get_nop_selector()
+
+
+    generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'
+    fm.write('aten_xla_type.h', lambda: {
+        'generated_comment': generated_comment,
+        'cpp_namespace': cpp_namespace,
+        'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
+    })
+
+    fm.write('aten_xla_type_default.h', lambda: {
+        'generated_comment': generated_comment,
+        'cpp_namespace': cpp_namespace,
+        'dispatch_aten_fallback_declarations': list(concatMap(
+            dest.GenExternalAtenFallback(Target.NAMESPACED_DECLARATION), external_backend_functions
+        )),
+    })
+
+    fm.write('aten_xla_type_default.cpp', lambda: {
+        'generated_comment': generated_comment,
+        'cpp_namespace': cpp_namespace,
+        # TODO: after cpu fallbacks are moved to a boxed kernel,
+        # merge registrations / definitions into RegisterDispatchKey
+        'dispatch_aten_fallback_definitions': list(concatMap(
+            dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
+        )),
+        'dispatch_registrations': list(concatMap(
+            dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if not e.is_autograd_kernel]
+        )),
+        'dispatch_autograd_registrations': list(concatMap(
+            dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if e.is_autograd_kernel]
+        )),
+    })
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/codegen/model.py b/tools/codegen/model.py
index 8acf296..1c7db29 100644
--- a/tools/codegen/model.py
+++ b/tools/codegen/model.py
@@ -1343,6 +1343,100 @@
else:
return f"{self.name}"
+@dataclass(frozen=True)
+class ExternalBackendMetadata:
+ # Immutable record of the information supplied for one operator of an external backend.
+ # ExternalBackendFunction pairs one of these with the matching NativeFunction.
+
+ # Operator name; ExternalBackendFunction asserts that it equals the name of the
+ # native function it is paired with.
+ operator: OperatorName
+ # Backend identifier (presumably the backend name from the backend yaml -- confirm against the parser).
+ backend: str
+ # Whether the kernel is an autograd kernel (surfaced through `is_autograd_kernel`).
+ is_autograd: bool
+
+ structured: bool = False # TODO: this will eventually become per-op metadata in the yaml file
+
+@dataclass(frozen=True)
+class ExternalBackendFunction:
+    # Pairs an in-tree NativeFunction with the metadata an external backend
+    # supplied for that operator. `metadata` is None when there is no metadata
+    # entry for the operator.
+
+    native_function: NativeFunction
+    metadata: Optional[ExternalBackendMetadata]
+
+    @property
+    def structured(self) -> bool:
+        # An external backend op is only considered structured if it's been marked
+        # structured both in-tree and out-of-tree.
+        external = self.metadata
+        return self.native_function.structured and external is not None and external.structured
+
+    @property
+    def is_autograd_kernel(self) -> bool:
+        # Without metadata nothing marks this function as an autograd kernel.
+        external = self.metadata
+        if external is None:
+            return False
+        return external.is_autograd
+
+    def __post_init__(self) -> None:
+        # Validate the pairing right after construction:
+        #   1. metadata (when present) must name the same operator as the native function;
+        #   2. out/inplace operators must not be marked structured by the backend.
+        schema = self.native_function.func
+        external = self.metadata
+
+        if external is not None:
+            assert external.operator == schema.name, \
+                f'Metadata and native function names do not match: {external.operator} and {schema.name}'
+
+        if schema.kind() in (SchemaKind.out, SchemaKind.inplace):
+            assert external is None or not external.structured, \
+                "Found an out/inplace operator marked with the structured keyword." \
+                f" Only functional operators can be marked as structured. operator={schema.name!s}"
+
+@dataclass(frozen=True)
+class ExternalBackendFunctionsGroup:
+    # Groups the functional, inplace (optional) and out variants of one operator,
+    # each wrapped as an ExternalBackendFunction. Instances should be built with
+    # `from_function_group` rather than by calling the constructor directly.
+    functional: ExternalBackendFunction
+    inplace: Optional[ExternalBackendFunction]
+    out: ExternalBackendFunction
+
+    @property
+    def structured(self) -> bool:
+        # The group is structured exactly when its primary variant is.
+        return self.primary.structured
+
+    @property
+    def primary(self) -> ExternalBackendFunction:
+        # TODO: hardcoding that XLA will only implement functional variants of structured kernel.
+        # This will eventually be toggleable per backend.
+        return self.functional
+
+    @property
+    def is_autograd_kernel(self) -> bool:
+        # Delegate to the primary variant, the same way `structured` does, instead of
+        # re-implementing the metadata check here.
+        return self.primary.is_autograd_kernel
+
+    def __post_init__(self) -> None:
+        # Note: I didn't want to copy-paste the post_init checks that NativeFunctionsGroup performs.
+        # ExternalBackendFunctionsGroup objects should be created using `from_function_group` (below),
+        # which guarantees that the relevant checks have already been performed.
+        if not self.structured:
+            return
+        for f in self.functions():
+            if f == self.primary:
+                continue
+            # For ops marked as structured externally, only the primary variant may
+            # carry backend metadata; every other variant is generated and must not
+            # appear in the backend's yaml.
+            assert f.metadata is None, \
+                f"{self.primary.native_function.func.name} is marked as structured. " \
+                f"The variant {f.native_function.func.name} will be generated for you " \
+                "and doesn't need to live in the yaml."
+
+    def functions(self) -> Iterator[ExternalBackendFunction]:
+        # Yields the variants in a fixed order: out, functional, then inplace (if any).
+        yield self.out
+        yield self.functional
+        if self.inplace is not None:
+            yield self.inplace
+
+    @staticmethod
+    def from_function_group(
+        g: NativeFunctionsGroup,
+        metadata: Dict[OperatorName, ExternalBackendMetadata]
+    ) -> 'ExternalBackendFunctionsGroup':
+        # Wraps every variant of `g` in an ExternalBackendFunction, attaching the
+        # metadata registered under that variant's operator name (None if absent).
+        # Variants are wrapped in the order out, functional, inplace, so that the
+        # per-function validation fires in a stable order.
+        out = ExternalBackendFunction(g.out, metadata.get(g.out.func.name))
+        functional = ExternalBackendFunction(g.functional, metadata.get(g.functional.func.name))
+
+        inplace = None
+        if g.inplace is not None:
+            inplace = ExternalBackendFunction(g.inplace, metadata.get(g.inplace.func.name))
+
+        return ExternalBackendFunctionsGroup(functional, inplace, out)
+
+
# Helper functions for parsing argument lists (both inputs and returns)
def parse_returns(return_decl: str) -> Tuple[Return, ...]: