| """ | 
 | The APIs in this file are exposed as `functorch.*`. They are thin wrappers | 
 | around the torch.func.* APIs that have deprecation warnings -- we're trying | 
 | to move people to the torch.func.* equivalents. | 
 |  | 
 | NB: We don't use *args, **kwargs in the signatures because that changes the | 
 | documentation. | 
 | """ | 
 |  | 
 | import textwrap | 
 | import warnings | 
 | from typing import Any, Callable, Optional, Tuple, Union | 
 |  | 
 | import torch._functorch.apis as apis | 
 | import torch._functorch.eager_transforms as _impl | 
 | import torch._functorch.make_functional as _nn_impl | 
 | import torch.nn as nn | 
 | from torch._functorch.eager_transforms import argnums_t | 
 | from torch._functorch.vmap import in_dims_t, out_dims_t | 
 |  | 
 |  | 
 | def get_warning(api, new_api=None, replace_newlines=False): | 
 |     if new_api is None: | 
 |         new_api = f"torch.func.{api}" | 
 |     warning = ( | 
 |         f"We've integrated functorch into PyTorch. As the final step of the \n" | 
 |         f"integration, `functorch.{api}` is deprecated as of PyTorch \n" | 
 |         f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n" | 
 |         f"Please use `{new_api}` instead; see the PyTorch 2.0 release notes \n" | 
 |         f"and/or the `torch.func` migration guide for more details \n" | 
 |         f"https://pytorch.org/docs/main/func.migrating.html" | 
 |     ) | 
 |     if replace_newlines: | 
 |         warning = warning.replace("\n", "") | 
 |     return warning | 
 |  | 
 |  | 
 | def warn_deprecated(api, new_api=None): | 
 |     warning = get_warning(api, new_api, replace_newlines=True) | 
 |     warnings.warn(warning, FutureWarning, stacklevel=2) | 
 |  | 
 |  | 
 | def setup_docs(functorch_api, torch_func_api=None, new_api_name=None): | 
 |     api_name = functorch_api.__name__ | 
 |     if torch_func_api is None: | 
 |         torch_func_api = getattr(_impl, api_name) | 
 |     # See https://docs.python.org/3/using/cmdline.html#cmdoption-OO | 
 |     if torch_func_api.__doc__ is None: | 
 |         return | 
 |  | 
 |     warning = get_warning(api_name, new_api_name) | 
 |     warning_note = "\n.. warning::\n\n" + textwrap.indent(warning, "    ") | 
 |     warning_note = textwrap.indent(warning_note, "    ") | 
 |     functorch_api.__doc__ = torch_func_api.__doc__ + warning_note | 
 |  | 
 |  | 
 | def vmap( | 
 |     func: Callable, | 
 |     in_dims: in_dims_t = 0, | 
 |     out_dims: out_dims_t = 0, | 
 |     randomness: str = "error", | 
 |     *, | 
 |     chunk_size=None, | 
 | ) -> Callable: | 
 |     warn_deprecated("vmap", "torch.vmap") | 
 |     return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size) | 
 |  | 
 |  | 
 | def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: | 
 |     warn_deprecated("grad") | 
 |     return apis.grad(func, argnums, has_aux) | 
 |  | 
 |  | 
 | def grad_and_value( | 
 |     func: Callable, argnums: argnums_t = 0, has_aux: bool = False | 
 | ) -> Callable: | 
 |     warn_deprecated("grad_and_value") | 
 |     return apis.grad_and_value(func, argnums, has_aux) | 
 |  | 
 |  | 
 | def vjp(func: Callable, *primals, has_aux: bool = False): | 
 |     warn_deprecated("vjp") | 
 |     return _impl.vjp(func, *primals, has_aux=has_aux) | 
 |  | 
 |  | 
 | def jvp( | 
 |     func: Callable, | 
 |     primals: Any, | 
 |     tangents: Any, | 
 |     *, | 
 |     strict: bool = False, | 
 |     has_aux: bool = False, | 
 | ): | 
 |     warn_deprecated("jvp") | 
 |     return _impl.jvp(func, primals, tangents, strict=strict, has_aux=has_aux) | 
 |  | 
 |  | 
 | def jacrev( | 
 |     func: Callable, | 
 |     argnums: Union[int, Tuple[int]] = 0, | 
 |     *, | 
 |     has_aux=False, | 
 |     chunk_size: Optional[int] = None, | 
 |     _preallocate_and_copy=False, | 
 | ): | 
 |     warn_deprecated("jacrev") | 
 |     return _impl.jacrev( | 
 |         func, | 
 |         argnums, | 
 |         has_aux=has_aux, | 
 |         chunk_size=chunk_size, | 
 |         _preallocate_and_copy=_preallocate_and_copy, | 
 |     ) | 
 |  | 
 |  | 
 | def jacfwd( | 
 |     func: Callable, | 
 |     argnums: argnums_t = 0, | 
 |     has_aux: bool = False, | 
 |     *, | 
 |     randomness: str = "error", | 
 | ): | 
 |     warn_deprecated("jacfwd") | 
 |     return _impl.jacfwd(func, argnums, has_aux, randomness=randomness) | 
 |  | 
 |  | 
 | def hessian(func, argnums=0): | 
 |     warn_deprecated("hessian") | 
 |     return _impl.hessian(func, argnums=argnums) | 
 |  | 
 |  | 
 | def functionalize(func: Callable, *, remove: str = "mutations") -> Callable: | 
 |     warn_deprecated("functionalize") | 
 |     return _impl.functionalize(func, remove=remove) | 
 |  | 
 |  | 
 | def make_functional(model: nn.Module, disable_autograd_tracking: bool = False): | 
 |     warn_deprecated("make_functional", "torch.func.functional_call") | 
 |     return _nn_impl.make_functional(model, disable_autograd_tracking) | 
 |  | 
 |  | 
 | def make_functional_with_buffers( | 
 |     model: nn.Module, disable_autograd_tracking: bool = False | 
 | ): | 
 |     warn_deprecated("make_functional_with_buffers", "torch.func.functional_call") | 
 |     return _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking) | 
 |  | 
 |  | 
 | def combine_state_for_ensemble(models): | 
 |     warn_deprecated("combine_state_for_ensemble", "torch.func.stack_module_state") | 
 |     return _nn_impl.combine_state_for_ensemble(models) | 
 |  | 
 |  | 
 | setup_docs(vmap, apis.vmap, "torch.vmap") | 
 | setup_docs(grad, apis.grad) | 
 | setup_docs(grad_and_value, apis.grad_and_value) | 
 | setup_docs(vjp) | 
 | setup_docs(jvp) | 
 | setup_docs(jacrev) | 
 | setup_docs(jacfwd) | 
 | setup_docs(hessian) | 
 | setup_docs(functionalize) | 
 | setup_docs(make_functional, _nn_impl.make_functional, "torch.func.functional_call") | 
 | setup_docs( | 
 |     make_functional_with_buffers, _nn_impl.make_functional, "torch.func.functional_call" | 
 | ) | 
 | setup_docs( | 
 |     combine_state_for_ensemble, | 
 |     _nn_impl.combine_state_for_ensemble, | 
 |     "torch.func.stack_module_state", | 
 | ) |