| import textwrap | 
 | from dataclasses import dataclass | 
 | from typing import List, Optional, Sequence, Tuple | 
 |  | 
 | from torchgen.api.translate import translate | 
 | from torchgen.api.types import DispatcherSignature | 
 | from torchgen.context import method_with_native_function | 
 | from torchgen.model import ( | 
 |     Argument, | 
 |     BaseTy, | 
 |     BaseType, | 
 |     FunctionSchema, | 
 |     ListType, | 
 |     NativeFunction, | 
 |     OptionalType, | 
 |     Return, | 
 |     SchemaKind, | 
 |     Type, | 
 | ) | 
 | from torchgen.utils import mapMaybe | 
 |  | 
 |  | 
 | def is_tensor(typ: Type) -> bool: | 
 |     return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor | 
 |  | 
 |  | 
 | def is_optional_tensor(typ: Type) -> bool: | 
 |     return isinstance(typ, OptionalType) and is_tensor(typ.elem) | 
 |  | 
 |  | 
 | def is_tensor_list(typ: Type) -> bool: | 
 |     return isinstance(typ, ListType) and is_tensor(typ.elem) | 
 |  | 
 |  | 
 | def unwrap_tensor(name: str, cur_level_var: str) -> List[str]: | 
 |     result = f"""\ | 
 |     Tensor {name}_value; | 
 |     optional<int64_t> {name}_bdim; | 
 |     std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}, {cur_level_var});""" | 
 |     return textwrap.dedent(result).split("\n") | 
 |  | 
 |  | 
 | def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]: | 
 |     result = f"""\ | 
 |     optional<Tensor> {name}_value; | 
 |     optional<int64_t> {name}_bdim; | 
 |     if ({name}) {{ | 
 |         std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var}); | 
 |     }}""" | 
 |     return textwrap.dedent(result).split("\n") | 
 |  | 
 |  | 
 | def gen_unwraps( | 
 |     flat_arguments: Sequence[Argument], cur_level_var: str | 
 | ) -> Tuple[str, List[str]]: | 
 |     arg_names = [a.name for a in flat_arguments] | 
 |     arg_types = [a.type for a in flat_arguments] | 
 |  | 
 |     tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)] | 
 |     optional_tensors = [ | 
 |         name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ) | 
 |     ] | 
 |  | 
 |     unwraps = [] | 
 |     for tensor in tensors: | 
 |         unwraps += unwrap_tensor(tensor, cur_level_var) | 
 |  | 
 |     for opt_tensor in optional_tensors: | 
 |         unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var) | 
 |     unwrap_code = "\n".join(unwraps) | 
 |  | 
 |     unwrapped_arg_list = [] | 
 |     for arg in arg_names: | 
 |         if arg in tensors or arg in optional_tensors: | 
 |             unwrapped_arg_list += [f"{arg}_value", f"{arg}_bdim"] | 
 |         else: | 
 |             unwrapped_arg_list.append(arg) | 
 |     return unwrap_code, unwrapped_arg_list | 
 |  | 
 |  | 
 | def gen_case_where_all_bdims_are_none( | 
 |     outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str | 
 | ) -> str: | 
 |     conditions = [] | 
 |     flat_args = schema.arguments.flat_all | 
 |     for arg in flat_args: | 
 |         if not arg.type.is_tensor_like(): | 
 |             continue | 
 |         conditions.append(f"!isBatchedAtLevel({arg.name}, {cur_level_var})") | 
 |  | 
 |     sig = DispatcherSignature.from_schema(schema) | 
 |     translated_args = ", ".join( | 
 |         e.expr for e in translate(outer_sig.arguments(), sig.arguments()) | 
 |     ) | 
 |     return f"""\ | 
 | if ({' && '.join(conditions)}) {{ | 
 |   return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args}); | 
 | }}""" | 
 |  | 
 |  | 
 | def gen_returns( | 
 |     returns: Tuple[Return, ...], cur_level_var: str, results_var: str | 
 | ) -> str: | 
 |     idx = 0 | 
 |     wrapped_returns = [] | 
 |     for ret in returns: | 
 |         if is_tensor(ret.type): | 
 |             wrapped_returns.append( | 
 |                 f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})" | 
 |             ) | 
 |             idx += 2 | 
 |         elif is_tensor_list(ret.type): | 
 |             wrapped_returns.append( | 
 |                 f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx+1}>({results_var}), {cur_level_var})" | 
 |             ) | 
 |             idx += 2 | 
 |         else: | 
 |             wrapped_returns.append(f"std::get<{idx}>({results_var})") | 
 |             idx += 1 | 
 |     if len(wrapped_returns) == 1: | 
 |         result = f"return {wrapped_returns[0]};" | 
 |     else: | 
 |         result = f'return std::make_tuple({", ".join(wrapped_returns)});' | 
 |     return result | 
 |  | 
 |  | 
 | def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool: | 
 |     return any(a.type.is_tensor_like() for a in schema.arguments.flat_all) | 
 |  | 
 |  | 
 | def is_mutated_arg(argument: Argument) -> bool: | 
 |     return argument.annotation is not None and argument.annotation.is_write | 
 |  | 
 |  | 
 | def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]: | 
 |     # Assumptions: | 
 |     # - only one argument is being modified in-place | 
 |     # - the argument that is being modified in-place is the first argument | 
 |     # - all returns are either Tensor, tuple of Tensor, or TensorList | 
 |     schema = native_function.func | 
 |     sig = DispatcherSignature.from_schema(schema) | 
 |     returns = schema.returns | 
 |  | 
 |     # Check assumptions. If these are invalid we return None | 
 |     # and punt the work to handle them to the future. | 
 |     assert schema.kind() == SchemaKind.inplace | 
 |     if not is_mutated_arg(schema.arguments.flat_all[0]): | 
 |         return None | 
 |     if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1: | 
 |         return None | 
 |  | 
 |     # Only support cases where all returns are Tensors or vector<Tensor> | 
 |     if len(returns) == 0: | 
 |         return None | 
 |     if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns): | 
 |         return None | 
 |     if not accepts_at_least_one_tensor_input(schema): | 
 |         return None | 
 |  | 
 |     cur_level_var = "cur_level" | 
 |  | 
 |     unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var) | 
 |     bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var) | 
 |  | 
 |     return f"""\ | 
 | template <typename batch_rule_t, batch_rule_t batch_rule> | 
 | {sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{ | 
 |   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); | 
 |   auto maybe_layer = maybeCurrentDynamicLayer(); | 
 |   vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); | 
 |   int64_t {cur_level_var} = maybe_layer->layerId(); | 
 | {textwrap.indent(bdims_all_none_case, "  ")} | 
 | {textwrap.indent(unwraps, "  ")} | 
 |   batch_rule({', '.join(unwrapped_arg_list)}); | 
 |   return {schema.arguments.flat_all[0].name}; | 
 | }}""" | 
 |  | 
 |  | 
 | def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str: | 
 |     schema = native_function.func | 
 |     sig = DispatcherSignature.from_schema(schema) | 
 |     cur_level_var = "cur_level" | 
 |  | 
 |     unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var) | 
 |     bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var) | 
 |  | 
 |     return f"""\ | 
 | template <typename batch_rule_t, batch_rule_t batch_rule> | 
 | {sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{ | 
 |   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); | 
 |   auto maybe_layer = maybeCurrentDynamicLayer(); | 
 |   vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); | 
 |   int64_t {cur_level_var} = maybe_layer->layerId(); | 
 | {textwrap.indent(bdims_all_none_case, "  ")} | 
 | {textwrap.indent(unwraps, "  ")} | 
 |   batch_rule({', '.join(unwrapped_arg_list)}); | 
 | }}""" | 
 |  | 
 |  | 
 | def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]: | 
 |     schema = native_function.func | 
 |     sig = DispatcherSignature.from_schema(schema) | 
 |     returns = schema.returns | 
 |  | 
 |     # Only support cases where all returns are Tensors or vector<Tensor> | 
 |     if not accepts_at_least_one_tensor_input(schema): | 
 |         return None | 
 |     if len(returns) == 0: | 
 |         return gen_vmap_plumbing_no_returns(native_function) | 
 |     if not all(ret.type.is_tensor_like() for ret in returns): | 
 |         return None | 
 |     # in-place views need special handling | 
 |     if "inplace_view" in native_function.tags: | 
 |         return None | 
 |  | 
 |     if schema.kind() == SchemaKind.inplace: | 
 |         return gen_vmap_inplace_plumbing(native_function) | 
 |  | 
 |     # Don't support these (mutable, out, scratch) | 
 |     if schema.kind() != SchemaKind.functional: | 
 |         return None | 
 |  | 
 |     results_var = "results" | 
 |     cur_level_var = "cur_level" | 
 |  | 
 |     unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var) | 
 |     bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var) | 
 |  | 
 |     wrapped_returns = gen_returns(returns, cur_level_var, results_var) | 
 |     return f"""\ | 
 | template <typename batch_rule_t, batch_rule_t batch_rule> | 
 | {sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{ | 
 |   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); | 
 |   auto maybe_layer = maybeCurrentDynamicLayer(); | 
 |   vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); | 
 |   int64_t {cur_level_var} = maybe_layer->layerId(); | 
 | {textwrap.indent(bdims_all_none_case, "  ")} | 
 | {textwrap.indent(unwraps, "  ")} | 
 |   auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)}); | 
 |   {wrapped_returns} | 
 | }}""" | 
 |  | 
 |  | 
 | @dataclass(frozen=True) | 
 | class ComputeBatchRulePlumbing: | 
 |     @method_with_native_function | 
 |     def __call__(self, f: NativeFunction) -> Optional[str]: | 
 |         opname = str(f.func.name) | 
 |         result = gen_vmap_plumbing(f) | 
 |         return result | 
 |  | 
 |  | 
 | def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str: | 
 |     body = "\n".join(list(mapMaybe(ComputeBatchRulePlumbing(), native_functions))) | 
 |     return f""" | 
 | #pragma once | 
 | #include <ATen/Operators.h> | 
 | #include <ATen/functorch/PlumbingHelper.h> | 
 |  | 
 | namespace at {{ namespace functorch {{ | 
 |  | 
 | {body} | 
 |  | 
 | }}}} // namespace at::functorch | 
 | """ |