| from typing import List, Optional, Union |
| import itertools |
| from typing_extensions import Literal |
| from dataclasses import dataclass |
| |
| from tools.codegen.context import * |
| from tools.codegen.utils import * |
| from tools.codegen.model import * |
| from tools.codegen.api.types import * |
| import tools.codegen.api.meta as meta |
| import tools.codegen.api.structured as structured |
| from tools.codegen.api.translate import translate |
| import tools.codegen.local as local |
| from tools.codegen.selective_build.selector import SelectiveBuilder |
| |
| # Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp). |
| # |
| # - The primary function of this file is to register all of the |
| # implementations for the given dispatch key to the dispatcher, |
| # so they are available for use in PyTorch. If dispatch is |
| # None, we generate schema (def) registrations and catchall |
| # registrations. |
| # - The secondary function of this file is to generate a wrapper |
| # around functions. In CPUType these wrappers do nothing |
| # (and should be removed), but in other cases they handle |
| # DeviceGuard. A small extra benefit of wrappers is they |
| # are not overloaded, so they can be used in the registration |
| # API without having to disambiguate which overload you want |
| # (as would be the case if you directly registered native:: |
| # functions). |
| # - The tertiary function of this file is to generate *static* |
| # cpp API bindings which can be used to bypass dispatcher |
| # directly to kernels, but with user-friendly cpp-style API |
@dataclass(frozen=True)
class RegisterDispatchKey:
    """Code generator for Register{dispatch}.cpp for a single dispatch key.

    Depending on ``target``, this emits the namespaced declarations or
    definitions (the static cpp API bindings), the anonymous wrapper
    definitions, or the dispatcher registrations themselves.  See the
    module-level comment above for the three kinds of output.
    """

    # The dispatch key (e.g. CPU, CUDA, Meta) we are generating code for.
    dispatch_key: DispatchKey

    # Which flavor of output to emit; each instance handles exactly one
    # of the four targets.
    target: Union[
        Literal[Target.ANONYMOUS_DEFINITION],
        Literal[Target.NAMESPACED_DEFINITION],
        Literal[Target.NAMESPACED_DECLARATION],
        Literal[Target.REGISTRATION]
    ]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    # Whether or not we are actually code-genning for ROCm
    rocm: bool

    @method_with_native_function
    def __call__(self, f: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
        # Entry point: route to the structured or unstructured generator
        # depending on what we were handed, normalizing the result to a
        # list of generated code snippets.
        if isinstance(f, StructuredNativeFunctions):
            return self.gen_structured(f)
        elif isinstance(f, NativeFunction):
            r = self.gen_unstructured(f)
            return [] if r is None else [r]
        else:
            assert_never(f)

    def gen_structured(self, g: StructuredNativeFunctions) -> List[str]:
        """Generate code for a structured function group.

        Meta/DefaultBackend kernels are auto-generated, so the group must
        not explicitly list those dispatch keys.  Dispatch keys without
        structured support fall back to the unstructured path per function.
        """
        if self.dispatch_key == DispatchKey.Meta:
            assert self.dispatch_key not in g.out.dispatch, \
                "Do not explicitly specify Meta dispatch key on structured " \
                "functions, they will be automatically generated for you"
        elif self.dispatch_key == DispatchKey.DefaultBackend:
            assert self.dispatch_key not in g.out.dispatch, \
                "Do not explicitly specify DefaultBackend dispatch key on structured " \
                "functions, they will be automatically generated for you"
        elif not is_structured_dispatch_key(self.dispatch_key):
            # No structured codegen for this key; treat each function in
            # the group as an ordinary unstructured function.
            return list(mapMaybe(self.gen_unstructured, g.functions()))
        elif self.dispatch_key not in g.out.dispatch:
            # The backend provided no kernel for the out variant of this
            # group, so there is nothing to generate here.
            return []

        # Delegate the real work to the structured generator, which also
        # carries the function group g.
        structured_gen = StructuredRegisterDispatchKey(
            self.dispatch_key,
            self.target,
            self.selector,
            self.rocm,
            g
        )
        return list(mapMaybe(structured_gen.gen_one, g.functions()))

    @method_with_native_function
    def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
        """Generate code for one unstructured NativeFunction.

        Returns None when nothing should be emitted for this function
        under the current dispatch key / target / selector combination.
        """
        if self.dispatch_key not in f.dispatch:
            return None
        if f.manual_kernel_registration:
            return None

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
            return None

        # Signature of the wrapper function we generate and register.
        sig = NativeSignature(f.func, prefix='wrapper_')

        name = sig.name()
        returns_type = sig.returns_type()
        args = sig.arguments()
        args_str = ', '.join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result
        elif self.target is Target.NAMESPACED_DEFINITION:
            def generate_defn(cpp_sig: CppSignature) -> str:
                # The namespaced binding simply forwards to the wrapper.
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result
        elif self.target is Target.ANONYMOUS_DEFINITION:
            # The wrapper calls straight into the backend kernel.
            impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

            args_exprs_str = ', '.join(a.name for a in args)

            return_kw = " return "

            cuda_guard = ""
            if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key):
                self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []

                # There is an order of precedence for which argument we use
                # to do the device guard.  This chain describes that order.
                candidate_args = itertools.chain(
                    self_arg,
                    f.func.arguments.out,
                    f.func.arguments.flat_positional
                )

                # Only tensor like arguments are eligible
                device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)

                has_tensor_options = any(isinstance(a.argument, TensorOptionsArguments) for a in args)

                # The legacy (hacky_wrapper) calling convention still bundles
                # a TensorOptions; the full convention passes device directly.
                if local.use_c10_dispatcher() == UseC10Dispatcher.full:
                    cuda_guard_from_tensor_options = """\
const DeviceGuard device_guard(device_or_default(device));
"""
                else:
                    assert local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures
                    cuda_guard_from_tensor_options = """\
const DeviceGuard device_guard(options.device());
"""

                # TODO: There is probably a simpler version of this that
                # works just as well.
                if f.device_guard and is_generic_dispatch_key(self.dispatch_key) and has_tensor_options:
                    cuda_guard = cuda_guard_from_tensor_options
                elif f.device_guard and is_cuda_dispatch_key(self.dispatch_key) and has_tensor_options:
                    cuda_guard = f"""\
globalContext().lazyInitCUDA();
{cuda_guard_from_tensor_options}
"""
                elif f.device_guard and device_of is not None:
                    cuda_guard = f"""\
const OptionalDeviceGuard device_guard(device_of({device_of}));
"""
                else:
                    cuda_guard = """\
// DeviceGuard omitted
"""

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
{cuda_guard}{return_kw}{impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

        elif self.target is Target.REGISTRATION:
            # NOTE(review): this inner check is unreachable -- the
            # manual_kernel_registration early-return above already fired.
            if f.manual_kernel_registration:
                return None
            else:
                dispatcher_sig = DispatcherSignature.from_schema(f.func)

                # Figure out which signature the function is
                if local.use_c10_dispatcher() is UseC10Dispatcher.full:
                    payload = f"TORCH_FN({name})"
                else:
                    assert local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures
                    payload = f"""
c10::impl::hacky_wrapper_for_legacy_signatures<
{dispatcher_sig.type()},
{len(f.func.arguments.out)}
>(TORCH_FN({name}))
"""

                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)
| |
| |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # |
| # |
| # STRUCTURED |
| # |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # |
| |
@dataclass(frozen=True)
class StructuredRegisterDispatchKey(RegisterDispatchKey):
    """Structured-kernel code generation for one function group.

    Per operator this emits a small C++ class (deriving from the
    corresponding meta/native class) whose set_output allocates, resizes
    or restrides the outputs, plus the wrapper functions and dispatcher
    registrations that drive meta()/impl().
    """

    # The structured function group (functional / inplace / out) to generate.
    g: StructuredNativeFunctions

    def gen_class_set_output(self, k: SchemaKind, parent_class: str, generate_super: bool) -> str:
        """Render the set_output override for the generated class."""
        if generate_super:
            set_output_super = f"{parent_class}::set_output(output_idx, sizes, strides, options, names);"
        else:
            set_output_super = ""
        return f"""
void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
TensorOptions options, DimnameList names) override {{
{self.gen_class_set_output_body(k)}
if (!names.empty()) namedinference::propagate_names(outputs_[output_idx], names);
// super must happen after, so that downstream can use maybe_get_output
// to retrieve the output
{set_output_super}
}}
"""

    def gen_class_set_output_body(self, k: SchemaKind) -> str:
        """Render the body of set_output for schema kind ``k``.

        Functional kernels allocate fresh outputs; inplace kernels only
        (possibly) set the device guard; out kernels resize/restride the
        caller-provided output tensor.
        """
        # CUDA and DefaultBackend classes carry a guard_ member (see
        # gen_class) that we use to pin the output device.
        if self.dispatch_key in [DispatchKey.CUDA, DispatchKey.DefaultBackend]:
            maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
TORCH_INTERNAL_ASSERT(*current_device == options.device(),
"structured kernels don't support multi-device outputs");
} else {
guard_.reset_device(options.device());
}
"""
        else:
            maybe_set_guard = ''

        if k is SchemaKind.functional:
            if self.dispatch_key == DispatchKey.Meta:
                return """
if (strides.empty()) {
outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta));
} else {
TORCH_INTERNAL_ASSERT(0, "not implemented yet");
}
"""
            else:
                expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \
                    "options.device_opt(), options.pinned_memory_opt()"
                # Pick the allocation entry points for this backend.
                if self.dispatch_key == DispatchKey.CPU:
                    empty_impl = "at::native::empty_cpu"
                    empty_strided_impl = "at::native::empty_strided_cpu"
                elif self.dispatch_key == DispatchKey.CUDA:
                    empty_impl = "at::native::empty_cuda"
                    empty_strided_impl = "at::native::empty_strided_cuda"
                elif self.dispatch_key == DispatchKey.DefaultBackend:
                    empty_impl = "at::empty"
                    empty_strided_impl = "at::empty_strided"
                else:
                    raise AssertionError("unsupported dispatch key")
                return f"""
{maybe_set_guard}
if (strides.empty()) {{
outputs_[output_idx] = {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt());
}} else {{
outputs_[output_idx] = {empty_strided_impl}(sizes, strides, {expanded_topts});
}}
"""
        elif k is SchemaKind.inplace:
            # In-place: the output is self, already allocated; only the
            # device guard (if any) needs setting.
            return maybe_set_guard
        elif k is SchemaKind.out:
            if self.dispatch_key == DispatchKey.CPU:
                resize_impl = "resize_output_cpu"
            else:
                # Only bothering to include a resize_output fastpath for CPU for now.
                # We can add one in for the perf if we need to. But it'll be easier when external backends
                # have access to meta functions, and we can write one for resize_.
                resize_impl = "resize_output"
            return f"""
{maybe_set_guard}
at::native::{resize_impl}(outputs_[output_idx], sizes);
if (!strides.empty()) {{
TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
at::native::as_strided_(outputs_[output_idx], sizes, strides);
}} else if (options.memory_format_opt().has_value()) {{
outputs_[output_idx].get().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
}}
"""
        else:
            assert_never(k)

    # returns the definition of a ctor, as well as how to construct
    # this class to a variable named op
    def gen_class_ctor(self, k: SchemaKind, class_name: str) -> str:
        if k is SchemaKind.functional:
            # Functional variants default-construct; outputs are allocated
            # later by set_output.
            return ""
        elif k is SchemaKind.inplace:
            # TODO: Make sure out argument is guaranteed to be self
            return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
        elif k is SchemaKind.out:
            # TODO: Stop hardcoding out here
            return f"{class_name}(Tensor& out) : outputs_{{std::ref(out)}} {{}}"
        else:
            assert_never(k)

    def gen_class(
        self, f: NativeFunction, k: SchemaKind, *, class_name: str, parent_class: str, generate_super: bool
    ) -> str:
        """Render the full generated per-operator class definition."""
        # Functional classes own their outputs; inplace/out classes hold
        # references to caller-provided tensors.
        if k is SchemaKind.functional:
            assert len(f.func.returns) == 1, "multi-return not supported yet"
            output_type = "Tensor"
        elif k is SchemaKind.inplace:
            output_type = "std::reference_wrapper<Tensor>"
        elif k is SchemaKind.out:
            assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet"
            output_type = "std::reference_wrapper<Tensor>"

        # Dispatch keys whose set_output sets a device guard (see
        # gen_class_set_output_body) need a guard_ member.
        if self.dispatch_key == DispatchKey.CUDA:
            if self.rocm:
                guard_field = 'c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;'
            else:
                guard_field = 'c10::cuda::OptionalCUDAGuard guard_;'
        elif self.dispatch_key == DispatchKey.DefaultBackend:
            guard_field = 'c10::OptionalDeviceGuard guard_;'
        else:
            guard_field = ''

        return f"""
struct {class_name} final : public {parent_class} {{
{self.gen_class_ctor(k, class_name)}
{self.gen_class_set_output(k, parent_class, generate_super)}
const Tensor& maybe_get_output(int64_t output_idx) override {{
return outputs_[output_idx];
}}
std::array<{output_type}, {len(f.func.returns)}> outputs_;
{guard_field}
}};
"""

    @method_with_native_function
    def gen_one(self, f: NativeFunction) -> Optional[str]:
        """Generate code for a single function of the structured group.

        Returns None when this function should produce no output under
        the current dispatch key / target / selector.
        """
        assert not f.manual_kernel_registration

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
            return None

        # TODO: Now, there is something interesting going on here. In the code below,
        # we generate DefaultBackend implementations of functional and inplace
        # based on the out implementation. But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor. How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other. We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG! So it's not implemented yet.
        if self.dispatch_key == DispatchKey.DefaultBackend and f.func.kind() is SchemaKind.out:
            # Never generate a default implementation for out, that's what you
            # have to define as a backend implementor
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add). We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

        # Signature of the wrapper function we'll register to the dispatcher
        sig = NativeSignature(f.func, prefix="wrapper_")

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:
            def generate_defn(cpp_sig: CppSignature) -> str:
                # The namespaced binding forwards straight to the wrapper.
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:

            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body: List[str] = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: List[Union[Binding, Expr]] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            elif self.dispatch_key is DispatchKey.DefaultBackend:
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            else:
                class_name = f"structured_{self.g.out.dispatch[self.dispatch_key]}_{k.name}"
                parent_class = f"at::native::structured_{self.g.out.dispatch[self.dispatch_key]}"

            if k is SchemaKind.functional:
                assert len(f.func.returns) == 1, "multi-return not supported yet"
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet"
                sig_body.append(f"{class_name} op({f.func.arguments.out[0].name});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ', '.join(
                e.expr for e in translate(
                    context,
                    structured.meta_arguments(self.g),
                    method=False
                )
            )
            sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            # TODO: handle multi-return
            assert ConstRefCType(BaseCType("Tensor", structured.out_arguments(self.g)[0].ctype.name)) == \
                structured.out_arguments(self.g)[0].ctype
            context.append(Expr(
                expr="op.outputs_[0]",
                # TODO: Stop hardcoding that the output type is a Tensor. Note
                # that for the codegen here this is fine because outputs_ is
                # hardcoded to be tensor already
                type=MutRefCType(BaseCType("Tensor", structured.out_arguments(self.g)[0].ctype.name)),
            ))

            # With the expanded context, do the impl call (if not a meta
            # function)
            if self.dispatch_key == DispatchKey.DefaultBackend:
                # DefaultBackend has no impl of its own; it calls the
                # out-variant C++ API instead.
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out, method=False, fallback_binding=f.manual_cpp_binding)
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ', '.join(
                    e.expr for e in translate(
                        context,
                        out_sig.arguments(),
                        method=False
                    )
                )
                # TODO: I think this means structured won't work with method
                # only functions (but maybe you're saved by faithful? iunno.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set. I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.dispatch_key != DispatchKey.Meta:
                impl_exprs = ', '.join(
                    e.expr for e in translate(
                        context,
                        structured.impl_arguments(self.g),
                        method=False
                    )
                )
                sig_body.append(f"op.impl({impl_exprs});")

            # Destructively return the final tensors
            if k is SchemaKind.functional:
                assert len(f.func.returns) == 1, "multi-return not supported yet"
                ret_expr = "std::move(op.outputs_[0])"  # small optimization
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet"
                ret_expr = f.func.arguments.out[0].name
            sig_body.append(f"return {ret_expr};")

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
    f, k,
    class_name=class_name,
    parent_class=parent_class,
    generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            # Structured kernels are only registered under the full c10
            # dispatcher calling convention.
            assert local.use_c10_dispatcher() is UseC10Dispatcher.full
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None