Move (most) generated return statements for TH functions out of the switch. (#38073)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38073
Most of the generated return statements don't depend on the scalar type and it saves ~900 lines of generated code.
Test Plan: Imported from OSS
Differential Revision: D21476010
Pulled By: gchanan
fbshipit-source-id: 3fcc4db466d697c90abafb9da6c3f3644621810b
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 501592c..c4c7dc3 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -44,6 +44,7 @@
default:
AT_ERROR("${api_name} not supported on ${Type} for ", dispatch_scalar_type);
}
+${switch_epilogue}
""")
LEGACY_TH_DEFINITION_CASE = CodeTemplate("""\
@@ -1414,19 +1415,7 @@
if ret['kind'] == 'arguments':
case_body.extend([call + ';' for call in calls])
- arguments_indices = ret['arguments']
- arguments = [option['arguments'][argi]
- for argi in arguments_indices]
- if len(arguments_indices) == 1:
- arg = arguments[0]
- case_body.append("return {};".format(arg['name']))
- else:
- types = [to_return_type(arg, option)['type']
- for arg in arguments]
- # TODO: check for move semantics...
- names = [arg['name'] for arg in arguments]
- case_body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute(
- types=types, names=names))
+ # return handled later
elif ret['kind'] == 'type':
assert len(calls) == 1
call = calls[0]
@@ -1444,7 +1433,24 @@
raise Exception("NYI - return handling")
cases.append(LEGACY_TH_DEFINITION_CASE.substitute(case_env, case_body=case_body))
- body.append(LEGACY_TH_DEFINITION_SWITCH_STATEMENT.substitute(env, cases=cases, switch_prologue=switch_prologue))
+ switch_epilogue = ''
+ if ret['kind'] == 'arguments':
+ arguments_indices = ret['arguments']
+ arguments = [option['arguments'][argi]
+ for argi in arguments_indices]
+ if len(arguments_indices) == 1:
+ arg = arguments[0]
+ switch_epilogue = "return {};".format(arg['name'])
+ else:
+ types = [to_return_type(arg, option)['type']
+ for arg in arguments]
+ # TODO: check for move semantics...
+ names = [arg['name'] for arg in arguments]
+ switch_epilogue = CodeTemplate("return std::tuple<${types}>(${names});").substitute(
+ types=types, names=names)
+ body.append(LEGACY_TH_DEFINITION_SWITCH_STATEMENT.substitute(env, cases=cases,
+ switch_prologue=switch_prologue,
+ switch_epilogue=switch_epilogue))
return body
def process_legacy_th_option(option):