Move compute_native_function_declaration to its own dest module (#54419)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54419
I'm planning to break it into some helper functions, so let's put it in its own module first.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D27235378
Pulled By: ezyang
fbshipit-source-id: c03c5440d2d753859e2c5ec2b2c8b1b82870f03a
diff --git a/tools/codegen/dest/__init__.py b/tools/codegen/dest/__init__.py
index 6db59b7..ab4bada 100644
--- a/tools/codegen/dest/__init__.py
+++ b/tools/codegen/dest/__init__.py
@@ -1 +1,2 @@
from .register_dispatch_key import RegisterDispatchKey as RegisterDispatchKey
+from .native_functions import compute_native_function_declaration as compute_native_function_declaration
diff --git a/tools/codegen/dest/native_functions.py b/tools/codegen/dest/native_functions.py
new file mode 100644
index 0000000..6c3e842
--- /dev/null
+++ b/tools/codegen/dest/native_functions.py
@@ -0,0 +1,66 @@
+from typing import List, Union, Set, Any
+
+from tools.codegen.context import *
+from tools.codegen.utils import *
+from tools.codegen.model import *
+from tools.codegen.api.types import *
+import tools.codegen.api.meta as meta
+import tools.codegen.api.native as native
+import tools.codegen.api.structured as structured
+
+# Generates NativeFunctions.h, a list of forward declarations of all
+# actual kernel definitions we keep in aten/src/ATen/native/
+@with_native_function
+def compute_native_function_declaration(g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
+ if isinstance(g, StructuredNativeFunctions):
+ # only out has dispatch
+ meta_name = meta.name(g)
+ rs = []
+ seen: Set[Any] = set()
+ out_args = structured.impl_arguments(g)
+ for k, n in g.out.dispatch.items():
+ if n in seen:
+ continue
+ if not is_structured_dispatch_key(k):
+ continue
+ seen.add(n)
+ rs.append(f"""\
+struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
+ void impl({', '.join(a.decl() for a in out_args)});
+}};
+""")
+
+ seen = set()
+ for f in g.functions():
+ returns_type = native.returns_type(f.func.returns)
+ args = native.arguments(f.func)
+ for k, n in f.dispatch.items():
+ if n in seen:
+ continue
+ if is_structured_dispatch_key(k):
+ continue
+ seen.add(n)
+ args_str = ', '.join(a.decl() for a in args)
+ rs.append(f"TORCH_API {returns_type} {n}({args_str});")
+
+ return rs
+
+ else:
+ f = g
+ ns = list(f.dispatch.values())
+
+ rs = []
+ # Sometimes a function name shows up multiple times; only generate
+ # it once!
+ seen = set()
+ for n in ns:
+ if n in seen:
+ continue
+ if "legacy::" in n:
+ continue
+ seen.add(n)
+ returns_type = native.returns_type(f.func.returns)
+ args = native.arguments(f.func)
+ rs.append(f"TORCH_API {returns_type} {n}({', '.join(a.decl() for a in args)});")
+
+ return rs
diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py
index 5009363..a2e660f 100644
--- a/tools/codegen/gen.py
+++ b/tools/codegen/gen.py
@@ -319,63 +319,6 @@
def compute_aten_op(f: NativeFunction) -> str:
return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},'
-# Generates NativeFunctions.h, a list of forward declarations of all
-# actual kernel definitions we keep in aten/src/ATen/native/
-@with_native_function
-def compute_native_function_declaration(g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
- if isinstance(g, StructuredNativeFunctions):
- # only out has dispatch
- meta_name = meta.name(g)
- rs = []
- seen: Set[Any] = set()
- out_args = structured.impl_arguments(g)
- for k, n in g.out.dispatch.items():
- if n in seen:
- continue
- if not is_structured_dispatch_key(k):
- continue
- seen.add(n)
- rs.append(f"""\
-struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
- void impl({', '.join(a.decl() for a in out_args)});
-}};
-""")
-
- seen = set()
- for f in g.functions():
- returns_type = native.returns_type(f.func.returns)
- args = native.arguments(f.func)
- for k, n in f.dispatch.items():
- if n in seen:
- continue
- if is_structured_dispatch_key(k):
- continue
- seen.add(n)
- args_str = ', '.join(a.decl() for a in args)
- rs.append(f"TORCH_API {returns_type} {n}({args_str});")
-
- return rs
-
- else:
- f = g
- ns = list(f.dispatch.values())
-
- rs = []
- # Sometimes a function name shows up multiple times; only generate
- # it once!
- seen = set()
- for n in ns:
- if n in seen:
- continue
- if "legacy::" in n:
- continue
- seen.add(n)
- returns_type = native.returns_type(f.func.returns)
- args = native.arguments(f.func)
- rs.append(f"TORCH_API {returns_type} {n}({', '.join(a.decl() for a in args)});")
-
- return rs
-
# Generates MetaFunctions.h
def compute_meta_function_declaration(g: StructuredNativeFunctions) -> str:
with native_function_manager(g.out):
@@ -1027,7 +970,7 @@
'aten_ops': list(mapMaybe(compute_aten_op, native_functions)),
})
cpu_fm.write('NativeFunctions.h', lambda: {
- 'native_function_declarations': list(concatMap(compute_native_function_declaration, grouped_native_functions)),
+ 'native_function_declarations': list(concatMap(dest.compute_native_function_declaration, grouped_native_functions)),
})
cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]))