Move (most) generated return statements for TH functions out of the switch. (#38073)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38073

Most of the generated return statements don't depend on the scalar type and it saves ~900 lines of generated code.

Test Plan: Imported from OSS

Differential Revision: D21476010

Pulled By: gchanan

fbshipit-source-id: 3fcc4db466d697c90abafb9da6c3f3644621810b
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 501592c..c4c7dc3 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -44,6 +44,7 @@
     default:
         AT_ERROR("${api_name} not supported on ${Type} for ", dispatch_scalar_type);
 }
+${switch_epilogue}
 """)
 
 LEGACY_TH_DEFINITION_CASE = CodeTemplate("""\
@@ -1414,19 +1415,7 @@
 
                 if ret['kind'] == 'arguments':
                     case_body.extend([call + ';' for call in calls])
-                    arguments_indices = ret['arguments']
-                    arguments = [option['arguments'][argi]
-                                 for argi in arguments_indices]
-                    if len(arguments_indices) == 1:
-                        arg = arguments[0]
-                        case_body.append("return {};".format(arg['name']))
-                    else:
-                        types = [to_return_type(arg, option)['type']
-                                 for arg in arguments]
-                        # TODO: check for move semantics...
-                        names = [arg['name'] for arg in arguments]
-                        case_body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute(
-                            types=types, names=names))
+                    # return handled later
                 elif ret['kind'] == 'type':
                     assert len(calls) == 1
                     call = calls[0]
@@ -1444,7 +1433,24 @@
                     raise Exception("NYI - return handling")
 
                 cases.append(LEGACY_TH_DEFINITION_CASE.substitute(case_env, case_body=case_body))
-        body.append(LEGACY_TH_DEFINITION_SWITCH_STATEMENT.substitute(env, cases=cases, switch_prologue=switch_prologue))
+        switch_epilogue = ''
+        if ret['kind'] == 'arguments':
+            arguments_indices = ret['arguments']
+            arguments = [option['arguments'][argi]
+                         for argi in arguments_indices]
+            if len(arguments_indices) == 1:
+                arg = arguments[0]
+                switch_epilogue = "return {};".format(arg['name'])
+            else:
+                types = [to_return_type(arg, option)['type']
+                         for arg in arguments]
+                # TODO: check for move semantics...
+                names = [arg['name'] for arg in arguments]
+                switch_epilogue = CodeTemplate("return std::tuple<${types}>(${names});").substitute(
+                    types=types, names=names)
+        body.append(LEGACY_TH_DEFINITION_SWITCH_STATEMENT.substitute(env, cases=cases,
+                                                                     switch_prologue=switch_prologue,
+                                                                     switch_epilogue=switch_epilogue))
         return body
 
     def process_legacy_th_option(option):