[pytorch][codegen] migrate gen_variable_factories.py to the new data model (#47818)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47818
This is another relatively small codegen migration.
Ideally we should use CppSignature.decl() to generate the C++ function declaration.
We didn't because we need to add 'at::' to the types defined in the ATen namespace.
E.g.:
- standard declaration:
```
Tensor eye(int64_t n, int64_t m, const TensorOptions & options={})
```
- expected:
```
at::Tensor eye(int64_t n, int64_t m, const at::TensorOptions & options = {})
```
Kept the hacky fully_qualified_type() method to maintain compatibility with the old codegen.
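For illustration, this is roughly how fully_qualified_type() rewrites argument types (the inputs below are hypothetical examples; the actual implementation is in the diff further down):
```
# Illustrative calls only, based on the regex patterns used by fully_qualified_type():
fully_qualified_type("const Tensor &")             # -> "const at::Tensor &"
fully_qualified_type("c10::optional<ScalarType>")  # -> "c10::optional<at::ScalarType>"
fully_qualified_type("int64_t")                    # -> "int64_t"  (non-ATen types are left alone)
```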
We could clean this up later by:
- using these types in the torch namespace - but this is a user-facing header file, so it's unclear whether that would cause problems;
- updating the cpp.argument_type() method to take an optional namespace argument (see the sketch after this list).
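A minimal, self-contained sketch of what that second option could look like - the function and parameter names here are hypothetical, not the real cpp API, and it ignores the c10::optional wrapping that fully_qualified_type() handles:
```
import re

ATEN_TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")

# Hypothetical helper, not the real cpp.argument_type(): prefix an ATen type
# name with a caller-supplied namespace such as 'at::'.
def qualified_argument_type(cpp_type: str, *, namespace: str = '') -> str:
    match = ATEN_TYPE_PATTERN.match(cpp_type)
    if not namespace or match is None:
        return cpp_type
    index = match.start(1)
    return f'{cpp_type[:index]}{namespace}{cpp_type[index:]}'
```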
Confirmed the output is byte-for-byte compatible with the old codegen:
```
Run it before and after this PR:
.jenkins/pytorch/codegen-test.sh <baseline_output_dir>
.jenkins/pytorch/codegen-test.sh <test_output_dir>
Then run diff to compare the generated files:
diff -Naur <baseline_output_dir> <test_output_dir>
```
Test Plan: Imported from OSS
Reviewed By: bhosmer
Differential Revision: D24909478
Pulled By: ljk53
fbshipit-source-id: a0ceaa60cc765c526908fee39f151cd7ed5ec923
diff --git a/mypy-strict.ini b/mypy-strict.ini
index dc17faa..f5755e9 100644
--- a/mypy-strict.ini
+++ b/mypy-strict.ini
@@ -33,6 +33,7 @@
tools/autograd/gen_annotated_fn_args.py,
tools/autograd/gen_python_functions.py,
tools/autograd/gen_trace_type.py,
+ tools/autograd/gen_variable_factories.py,
torch/utils/benchmark/utils/common.py,
torch/utils/benchmark/utils/timer.py,
torch/utils/benchmark/utils/valgrind_wrapper/*.py,
diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py
index cbeffcd..57cbfcc 100644
--- a/tools/autograd/gen_autograd.py
+++ b/tools/autograd/gen_autograd.py
@@ -191,7 +191,7 @@
# Generate variable_factories.h
from .gen_variable_factories import gen_variable_factories
# Some non-selectable ops (e.g. prim ops) need factory methods so we pass in `full_aten_decls` here.
- gen_variable_factories(out, full_aten_decls, template_path)
+ gen_variable_factories(out, native_functions_path, template_path)
def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
diff --git a/tools/autograd/gen_variable_factories.py b/tools/autograd/gen_variable_factories.py
index 72fdad7..a8c07ae 100644
--- a/tools/autograd/gen_variable_factories.py
+++ b/tools/autograd/gen_variable_factories.py
@@ -3,77 +3,81 @@
# This writes one file: variable_factories.h
import re
+from typing import Optional, List
-from .utils import CodeTemplate, write
-
-
-FUNCTION_TEMPLATE = CodeTemplate("""\
-inline at::Tensor ${name}(${formals}) {
- at::Tensor tensor = ([&]() {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
- return at::${name}(${actuals});
- })();
- at::Tensor result =
- autograd::make_variable(std::move(tensor), /*requires_grad=*/${requires_grad});
- return result;
-}
-""")
-
+from tools.codegen.api.types import *
+import tools.codegen.api.cpp as cpp
+import tools.codegen.api.python as python
+from tools.codegen.gen import with_native_function, parse_native_yaml, FileManager, mapMaybe
+from tools.codegen.model import *
OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")
-
-def fully_qualified_type(argument_type):
- def maybe_optional_type(t, opt_match):
- return 'c10::optional<{}>'.format(t) if opt_match else t
+# Add 'at::' to types defined in the ATen namespace, e.g. Tensor, TensorList, IntArrayRef, etc.
+# TODO: maybe update the cpp argument API to take optional namespace argument?
+def fully_qualified_type(argument_type: str) -> str:
+ def maybe_optional_type(type: str, is_opt: bool) -> str:
+ return f'c10::optional<{type}>' if is_opt else type
opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type)
+ is_opt = opt_match is not None
if opt_match:
argument_type = argument_type[opt_match.start(1):opt_match.end(1)]
match = TYPE_PATTERN.match(argument_type)
if match is None:
- return maybe_optional_type(argument_type, opt_match)
+ return maybe_optional_type(argument_type, is_opt)
index = match.start(1)
- qualified_type = "{}at::{}".format(argument_type[:index], argument_type[index:])
- return maybe_optional_type(qualified_type, opt_match)
+ qualified_type = f'{argument_type[:index]}at::{argument_type[index:]}'
+ return maybe_optional_type(qualified_type, is_opt)
+def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None:
+ native_functions = parse_native_yaml(native_yaml_path)
+ fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+ fm.write_with_template('variable_factories.h', 'variable_factories.h', lambda: {
+ 'generated_comment': '@' + f'generated from {fm.template_dir}/variable_factories.h',
+ 'function_definitions': list(mapMaybe(process_function, native_functions)),
+ })
-def gen_variable_factories(out, declarations, template_path):
- function_definitions = []
- for decl in declarations:
- has_tensor_options = any(a["simple_type"] == "TensorOptions" for a in decl["arguments"])
- is_namespace_fn = 'namespace' in decl['method_of']
- if (has_tensor_options or decl["name"].endswith("_like")) and is_namespace_fn:
- function_definitions.append(
- process_function(
- decl,
- has_tensor_options,
- )
- )
- write(out,
- "variable_factories.h",
- CodeTemplate.from_file(template_path + "/variable_factories.h"),
- {"function_definitions": function_definitions})
+@with_native_function
+def process_function(f: NativeFunction) -> Optional[str]:
+ name = cpp.name(f.func)
+ has_tensor_options = python.has_tensor_options(f)
+ is_factory = has_tensor_options or name.endswith("_like")
+ if Variant.function not in f.variants or not is_factory:
+ return None
-def process_function(decl, has_tensor_options):
- formals = []
- actuals = []
- for argument in decl["arguments"]:
- type = fully_qualified_type(argument["type"])
- default = " = {}".format(argument["default"]) if "default" in argument else ""
- formals.append("{} {}{}".format(type, argument["name"], default))
- actual = argument["name"]
- if argument["simple_type"] == "TensorOptions":
+ sig = CppSignatureGroup.from_schema(f.func, method=False).signature
+ formals: List[str] = []
+ exprs: List[str] = []
+ requires_grad = 'false'
+ for arg in sig.arguments():
+ qualified_type = fully_qualified_type(arg.type)
+ if arg.default:
+ formals.append(f'{qualified_type} {arg.name} = {arg.default}')
+ else:
+ formals.append(f'{qualified_type} {arg.name}')
+
+ if isinstance(arg.argument, TensorOptionsArguments):
# note: we remove the requires_grad setting from the TensorOptions because
# it is ignored anyways (and we actually have an assertion that it isn't set
# which would fail otherwise). We handle requires_grad explicitly here
# instead of passing it through to the kernel.
- actual = "at::TensorOptions({}).requires_grad(c10::nullopt)".format(actual)
- actuals.append(actual)
- requires_grad = "options.requires_grad()" if has_tensor_options else "false"
+ exprs.append(f'at::TensorOptions({arg.name}).requires_grad(c10::nullopt)')
+ # Manually set the requires_grad bit on the result tensor.
+ requires_grad = f'{arg.name}.requires_grad()'
+ else:
+ exprs.append(arg.name)
- return FUNCTION_TEMPLATE.substitute(
- name=decl["name"], formals=formals, actuals=actuals, requires_grad=requires_grad
- )
+ return f"""\
+inline at::Tensor {name}({', '.join(formals)}) {{
+ at::Tensor tensor = ([&]() {{
+ at::AutoNonVariableTypeMode non_var_type_mode(true);
+ return at::{name}({', '.join(exprs)});
+ }})();
+ at::Tensor result =
+ autograd::make_variable(std::move(tensor), /*requires_grad=*/{requires_grad});
+ return result;
+}}
+"""