Move compute_native_function_declaration to its own dest module (#54419)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54419

I'm planning to break it into some helper functions, so let's put it in its own module first.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: ailzhang

Differential Revision: D27235378

Pulled By: ezyang

fbshipit-source-id: c03c5440d2d753859e2c5ec2b2c8b1b82870f03a
diff --git a/tools/codegen/dest/__init__.py b/tools/codegen/dest/__init__.py
index 6db59b7..ab4bada 100644
--- a/tools/codegen/dest/__init__.py
+++ b/tools/codegen/dest/__init__.py
@@ -1 +1,2 @@
 from .register_dispatch_key import RegisterDispatchKey as RegisterDispatchKey
+from .native_functions import compute_native_function_declaration as compute_native_function_declaration
diff --git a/tools/codegen/dest/native_functions.py b/tools/codegen/dest/native_functions.py
new file mode 100644
index 0000000..6c3e842
--- /dev/null
+++ b/tools/codegen/dest/native_functions.py
@@ -0,0 +1,66 @@
+from typing import List, Union, Set, Any
+
+from tools.codegen.context import *
+from tools.codegen.utils import *
+from tools.codegen.model import *
+from tools.codegen.api.types import *
+import tools.codegen.api.meta as meta
+import tools.codegen.api.native as native
+import tools.codegen.api.structured as structured
+
+# Generates NativeFunctions.h, a list of forward declarations of all
+# actual kernel definitions we keep in aten/src/ATen/native/
+@with_native_function
+def compute_native_function_declaration(g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
+    if isinstance(g, StructuredNativeFunctions):
+        # only out has dispatch
+        meta_name = meta.name(g)
+        rs = []
+        seen: Set[Any] = set()
+        out_args = structured.impl_arguments(g)
+        for k, n in g.out.dispatch.items():
+            if n in seen:
+                continue
+            if not is_structured_dispatch_key(k):
+                continue
+            seen.add(n)
+            rs.append(f"""\
+struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
+    void impl({', '.join(a.decl() for a in out_args)});
+}};
+""")
+
+        seen = set()
+        for f in g.functions():
+            returns_type = native.returns_type(f.func.returns)
+            args = native.arguments(f.func)
+            for k, n in f.dispatch.items():
+                if n in seen:
+                    continue
+                if is_structured_dispatch_key(k):
+                    continue
+                seen.add(n)
+                args_str = ', '.join(a.decl() for a in args)
+                rs.append(f"TORCH_API {returns_type} {n}({args_str});")
+
+        return rs
+
+    else:
+        f = g
+        ns = list(f.dispatch.values())
+
+        rs = []
+        # Sometimes a function name shows up multiple times; only generate
+        # it once!
+        seen = set()
+        for n in ns:
+            if n in seen:
+                continue
+            if "legacy::" in n:
+                continue
+            seen.add(n)
+            returns_type = native.returns_type(f.func.returns)
+            args = native.arguments(f.func)
+            rs.append(f"TORCH_API {returns_type} {n}({', '.join(a.decl() for a in args)});")
+
+        return rs
diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py
index 5009363..a2e660f 100644
--- a/tools/codegen/gen.py
+++ b/tools/codegen/gen.py
@@ -319,63 +319,6 @@
 def compute_aten_op(f: NativeFunction) -> str:
     return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},'
 
-# Generates NativeFunctions.h, a list of forward declarations of all
-# actual kernel definitions we keep in aten/src/ATen/native/
-@with_native_function
-def compute_native_function_declaration(g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
-    if isinstance(g, StructuredNativeFunctions):
-        # only out has dispatch
-        meta_name = meta.name(g)
-        rs = []
-        seen: Set[Any] = set()
-        out_args = structured.impl_arguments(g)
-        for k, n in g.out.dispatch.items():
-            if n in seen:
-                continue
-            if not is_structured_dispatch_key(k):
-                continue
-            seen.add(n)
-            rs.append(f"""\
-struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
-    void impl({', '.join(a.decl() for a in out_args)});
-}};
-""")
-
-        seen = set()
-        for f in g.functions():
-            returns_type = native.returns_type(f.func.returns)
-            args = native.arguments(f.func)
-            for k, n in f.dispatch.items():
-                if n in seen:
-                    continue
-                if is_structured_dispatch_key(k):
-                    continue
-                seen.add(n)
-                args_str = ', '.join(a.decl() for a in args)
-                rs.append(f"TORCH_API {returns_type} {n}({args_str});")
-
-        return rs
-
-    else:
-        f = g
-        ns = list(f.dispatch.values())
-
-        rs = []
-        # Sometimes a function name shows up multiple times; only generate
-        # it once!
-        seen = set()
-        for n in ns:
-            if n in seen:
-                continue
-            if "legacy::" in n:
-                continue
-            seen.add(n)
-            returns_type = native.returns_type(f.func.returns)
-            args = native.arguments(f.func)
-            rs.append(f"TORCH_API {returns_type} {n}({', '.join(a.decl() for a in args)});")
-
-        return rs
-
 # Generates MetaFunctions.h
 def compute_meta_function_declaration(g: StructuredNativeFunctions) -> str:
     with native_function_manager(g.out):
@@ -1027,7 +970,7 @@
         'aten_ops': list(mapMaybe(compute_aten_op, native_functions)),
     })
     cpu_fm.write('NativeFunctions.h', lambda: {
-        'native_function_declarations': list(concatMap(compute_native_function_declaration, grouped_native_functions)),
+        'native_function_declarations': list(concatMap(dest.compute_native_function_declaration, grouped_native_functions)),
     })
 
     cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]))