| from typing import List, Union, Set, Any |
| |
| from tools.codegen.context import with_native_function |
| from tools.codegen.utils import concatMap |
| from tools.codegen.model import (NativeFunction, NativeFunctionsGroup, |
| is_structured_dispatch_key) |
| import tools.codegen.api.meta as meta |
| import tools.codegen.api.native as native |
| import tools.codegen.api.structured as structured |
| |
@with_native_function
def gen_unstructured(f: NativeFunction) -> List[str]:
    """Return the ``TORCH_API`` forward declarations for the kernels of ``f``.

    One declaration string is produced for each distinct value found in
    ``f.dispatch``, in first-seen order.  Values containing ``"legacy::"``
    are skipped.  Every declaration shares the same return type and
    argument list, both derived from ``f.func``.
    """
    # Sometimes a function name shows up multiple times; only generate
    # it once!
    seen: Set[Any] = set()
    kernel_names = []
    for n in f.dispatch.values():
        if n in seen or "legacy::" in n:
            continue
        seen.add(n)
        kernel_names.append(n)

    # Nothing to declare: return before computing the signature, so the
    # native API helpers are only invoked when a declaration is emitted.
    if not kernel_names:
        return []

    # The signature depends only on ``f``, so build it once rather than
    # once per kernel name.
    returns_type = native.returns_type(f.func.returns).cpp_type()
    args_str = ', '.join(a.decl() for a in native.arguments(f.func))
    return [f"TORCH_API {returns_type} {n}({args_str});" for n in kernel_names]
| |
@with_native_function
def gen_structured(g: NativeFunctionsGroup) -> List[str]:
    """Return the declarations for a structured native function group.

    Two kinds of strings are produced, in this order:

    * for each distinct kernel name registered on ``g.out`` under a key
      accepted by ``is_structured_dispatch_key``, a
      ``structured_<name>`` struct deriving from the group's meta class
      and declaring ``impl``;
    * for each distinct kernel name registered on any function of the
      group under a key *not* accepted by ``is_structured_dispatch_key``,
      a plain ``TORCH_API`` function declaration.
    """
    meta_name = meta.name(g)
    decls = []

    # Pass 1: structured kernels.  Only the out variant's dispatch table
    # is consulted here.
    emitted_structs: Set[Any] = set()
    out_args = structured.impl_arguments(g)
    for key, kernel in g.out.dispatch.items():
        if kernel not in emitted_structs and is_structured_dispatch_key(key):
            emitted_structs.add(kernel)
            decls.append(f"""\
struct TORCH_API structured_{kernel} : public at::meta::{meta_name} {{
    void impl({', '.join(a.decl() for a in out_args)});
}};
""")

    # Pass 2: every remaining (non-structured) kernel of every function
    # in the group.  Names are de-duplicated across the whole group.
    emitted_functions: Set[Any] = set()
    for fn in g.functions():
        returns_type = native.returns_type(fn.func.returns).cpp_type()
        fn_args = native.arguments(fn.func)
        for key, kernel in fn.dispatch.items():
            if kernel not in emitted_functions and not is_structured_dispatch_key(key):
                emitted_functions.add(kernel)
                params = ', '.join(a.decl() for a in fn_args)
                decls.append(f"TORCH_API {returns_type} {kernel}({params});")

    return decls
| |
@with_native_function
def compute_native_function_declaration(g: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
    """Generate the NativeFunctions.h entries for ``g``.

    NativeFunctions.h is a list of forward declarations of all actual
    kernel definitions we keep in aten/src/ATen/native/.

    A lone ``NativeFunction`` is handled by ``gen_unstructured``.  A
    ``NativeFunctionsGroup`` is handled by ``gen_structured`` when its
    ``structured`` attribute is truthy; otherwise ``gen_unstructured`` is
    applied to each of its functions and the results are concatenated.
    """
    if not isinstance(g, NativeFunctionsGroup):
        return gen_unstructured(g)
    if g.structured:
        return gen_structured(g)
    return list(concatMap(gen_unstructured, g.functions()))