Fix and update type hints for `make_functional.py` (#91579)

Changes in detail:

- Fix and update some out-of-date type hints in `_functorch/make_functional.py`.
- ~Explicitly use `OrderedDict` for order-sensitive mappings.~

	In `create_names_map()`, `_swap_state()`, and `FunctionalModuleWithBuffers.__init__()`, a plain `dict` is used for mappings whose key order matters: `dict.items()` is later `zip`ped with a flat tuple of `params`/`buffers`, so the i-th key must correspond to the i-th tensor. Since Python 3.6 the built-in dictionary preserves insertion order ([PEP 468](https://peps.python.org/pep-0468)), so the plain `dict` is correct, but explicit is better than implicit (see the sketch below).
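A minimal, self-contained sketch of that ordering concern (the parameter names below are made up for illustration; only the `dict`-building logic mirrors `create_names_map()`):

```python
import torch

# Illustrative names only -- a "module" with one tied weight:
# `tied.weight` is the same tensor object as `fc.weight`.
weight, bias = torch.randn(3, 3), torch.randn(3)
named_params = {"fc.weight": weight, "fc.bias": bias}
tied_named_params = {"fc.weight": weight, "fc.bias": bias, "tied.weight": weight}

# What create_names_map() builds: each unique tensor -> (name, list of dotted paths).
tensor_to_mapping = {}
for key, tensor in named_params.items():
    tensor_to_mapping[tensor] = (key, [])
for key, tensor in tied_named_params.items():
    tensor_to_mapping[tensor][1].append(key.split("."))
names_map = dict(tensor_to_mapping.values())
# {'fc.weight': [['fc', 'weight'], ['tied', 'weight']], 'fc.bias': [['fc', 'bias']]}

# _swap_state() zips names_map.items() with a flat tuple of tensors, so the
# i-th key of names_map must line up with the i-th tensor -- this only holds
# because dict preserves insertion order.
params = (weight, bias)
for (name, _), param in zip(names_map.items(), params):
    assert named_params[name] is param
```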

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91579
Approved by: https://github.com/zou3519
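
For reference, a minimal usage sketch of the annotated entry points, assuming the public `functorch` re-exports of `make_functional`/`make_functional_with_buffers`; it only illustrates the return shapes that the updated hints describe:

```python
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers

# make_functional(model) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]
func, params = make_functional(nn.Linear(3, 3))
out = func(params, torch.randn(2, 3))

# make_functional_with_buffers(model)
#   -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]
fmodel, params, buffers = make_functional_with_buffers(nn.BatchNorm1d(3))
out = fmodel(params, buffers, torch.randn(2, 3))
```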
diff --git a/.lintrunner.toml b/.lintrunner.toml
index b1efdee..bb075d4 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -108,7 +108,6 @@
     'torch/_functorch/eager_transforms.py',
     'torch/_functorch/fx_minifier.py',
     'torch/_functorch/partitioners.py',
-    'torch/_functorch/make_functional.py',
     'torch/_functorch/top_operators_github_usage.py',
     'torch/_functorch/vmap.py',
     'torch/distributed/elastic/agent/server/api.py',
@@ -830,6 +829,7 @@
     'torchgen/**/*.py',
     'functorch/functorch/_src/aot_autograd.py',
     'functorch/functorch/_src/compilers.py',
+    'torch/_functorch/make_functional.py',
     'torch/testing/*.py',
     'torch/distributed/fsdp/*.py',
     'test/distributed/fsdp/*.py',
diff --git a/torch/_functorch/make_functional.py b/torch/_functorch/make_functional.py
index abb3f07..cd7db82 100644
--- a/torch/_functorch/make_functional.py
+++ b/torch/_functorch/make_functional.py
@@ -4,12 +4,24 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    NoReturn,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
+
 import torch
 import torch.nn as nn
 from torch import Tensor
-from typing import List, Tuple
-from .named_members_polyfill import _named_parameters, _named_buffers
-import copy
+from .named_members_polyfill import _named_buffers, _named_parameters
 
 # Utilities to make nn.Module "functional"
 # In particular the goal is to be able to provide a function that takes as input
@@ -40,69 +52,79 @@
         _set_nested_attr(getattr(obj, names[0]), names[1:], value)
 
 
-def _get_nested_attr(obj: nn.Module, names: List[str]) -> None:
+def _get_nested_attr(obj: nn.Module, names: List[str]) -> Tensor:
     if len(names) == 1:
         return getattr(obj, names[0])
     else:
         return _get_nested_attr(getattr(obj, names[0]), names[1:])
 
 
-def raise_parameter_tying_error():
+def raise_parameter_tying_error() -> NoReturn:
     raise RuntimeError(
         "make_functional(module): we don't yet support models that "
         "do parameter tying (also sometimes known as weight sharing). "
         "Please try to rewrite your model by replacing all instances of the "
         "tied parameter with another and/or comment your support in "
-        "https://github.com/pytorch/functorch/issues/446")
+        "https://github.com/pytorch/functorch/issues/446"
+    )
 
 
-def create_names_map(named_params, tied_named_params):
+def create_names_map(
+    named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
+    tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
+) -> Dict[str, List[List[str]]]:
     """
     named_params is a dictionary of tensors: {'A': A, 'B': B}
     tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
     with potentially tied (or 'duplicated') tensors
 
     This function creates a mapping from the names in named_params to the
-    names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
+    names in tied_named_params: {'A': [['A']], 'B': [['B'], ['B_tied']]}.
     """
-    named_params = {k: v for k, v in named_params}
-    tied_named_params = {k: v for k, v in tied_named_params}
+    named_params = dict(named_params)
+    tied_named_params = dict(tied_named_params)
 
     tensors_dict_keys = set(named_params.keys())
     tied_tensors_dict_keys = set(tied_named_params.keys())
     assert tensors_dict_keys.issubset(tied_tensors_dict_keys)
 
-    tensor_to_mapping = {}
+    tensor_to_mapping: Dict[Tensor, Tuple[str, List[List[str]]]] = {}
     for key, tensor in named_params.items():
         tensor_to_mapping[tensor] = (key, [])
     for key, tensor in tied_named_params.items():
         assert tensor in tensor_to_mapping
-        tensor_to_mapping[tensor][1].append(key.split('.'))
-    result = {key: value for key, value in tensor_to_mapping.values()}
-    return result
+        tensor_to_mapping[tensor][1].append(key.split("."))
+    return dict(tensor_to_mapping.values())
 
 
-def _extract_members(mod: nn.Module, _named_members, named_members, subclass):
+def _extract_members(
+    mod: nn.Module,
+    _named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
+    named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
+    subclass: Callable[[Tensor], Tensor],
+) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
     all_named_members = tuple(_named_members(mod, remove_duplicate=False))
-    named_members = tuple(named_members())
-    names_map = create_names_map(named_members, all_named_members)
+    unique_named_members = tuple(named_members())
+    names_map = create_names_map(unique_named_members, all_named_members)
 
     # Remove all the members in the model
     memo = {}
     for name, p in all_named_members:
         if p not in memo:
-            memo[p] = subclass(torch.empty_like(p, device='meta'))
+            memo[p] = subclass(torch.empty_like(p, device="meta"))
         replacement = memo[p]
         _set_nested_attr(mod, name.split("."), replacement)
 
-    if len(named_members) == 0:
+    if len(unique_named_members) == 0:
         names, params = (), ()
     else:
-        names, params = zip(*named_members)
+        names, params = zip(*unique_named_members)  # type: ignore[assignment]
     return params, names, names_map
 
 
-def extract_weights(mod: nn.Module):
+def extract_weights(
+    mod: nn.Module,
+) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
     """
     This function removes all the Parameters from the model and
     return them as a tuple as well as their original attribute names.
@@ -114,11 +136,18 @@
     return _extract_members(mod, _named_parameters, mod.named_parameters, nn.Parameter)
 
 
-def extract_buffers(mod: nn.Module):
+def extract_buffers(
+    mod: nn.Module,
+) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
     return _extract_members(mod, _named_buffers, mod.named_buffers, lambda x: x)
 
 
-def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
+def load_weights(
+    mod: nn.Module,
+    names: Sequence[str],
+    params: Sequence[Tensor],
+    as_params: bool = False,
+) -> None:
     """
     Reload a set of weights so that `mod` can be used again to perform a forward pass.
     Note that the `params` are regular Tensors (that can have history) and so are left
@@ -131,8 +160,10 @@
         _set_nested_attr(mod, name.split("."), p)
 
 
-def _swap_state(mod: nn.Module, names_map: List[str], elems):
-    result = []
+def _swap_state(
+    mod: nn.Module, names_map: Dict[str, List[List[str]]], elems: Iterable[Tensor]
+) -> List[Tensor]:
+    result: List[Tensor] = []
     for (_, attr_names), elem in zip(names_map.items(), elems):
         for i, attr_name in enumerate(attr_names):
             if i == 0:
@@ -142,15 +173,23 @@
     return result
 
 
-def load_buffers(mod: nn.Module, names: List[str], buffers: Tuple[Tensor, ...], as_params=False) -> None:
+def load_buffers(
+    mod: nn.Module,
+    names: Sequence[str],
+    buffers: Sequence[Tensor],
+    as_params: bool = False,
+) -> None:
     for name, p in zip(names, buffers):
         _set_nested_attr(mod, name.split("."), p)
 
 
 def load_state(
-        model: nn.Module,
-        weights: List[Tensor], weight_names: List[str],
-        buffers=(), buffer_names=()):
+    model: nn.Module,
+    weights: Sequence[Tensor],
+    weight_names: Sequence[str],
+    buffers: Sequence[Tensor] = (),
+    buffer_names: Sequence[str] = (),
+) -> nn.Module:
     """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model
 
     load_state takes `weights` and `buffers` and assigns them to the model.
@@ -192,8 +231,10 @@
     """
     buffers = list(model.buffers())
     if len(buffers) > 0:
-        raise RuntimeError('make_functional_deprecated_v1(model): `model` has buffers. Please use '
-                           'make_functional_with_buffers_deprecated_v1(model) instead.')
+        raise RuntimeError(
+            "make_functional_deprecated_v1(model): `model` has buffers. Please use "
+            "make_functional_with_buffers_deprecated_v1(model) instead."
+        )
     weights, descriptors, _ = extract_weights(model)
 
     def fun(weights, data):
@@ -246,8 +287,14 @@
     This is the callable object returned by :func:`make_functional_with_buffers`.
     """
 
-    def __init__(self, stateless_model, param_names, buffer_names,
-                 param_names_map, buffer_names_map):
+    def __init__(
+        self,
+        stateless_model: nn.Module,
+        param_names: Tuple[str, ...],
+        buffer_names: Tuple[str, ...],
+        param_names_map: Dict[str, List[List[str]]],
+        buffer_names_map: Dict[str, List[List[str]]],
+    ) -> None:
         super(FunctionalModuleWithBuffers, self).__init__()
         self.stateless_model = stateless_model
         self.param_names = param_names
@@ -257,7 +304,9 @@
         self.all_names_map.update(buffer_names_map)
 
     @staticmethod
-    def _create_from(model, disable_autograd_tracking=False):
+    def _create_from(
+        model: nn.Module, disable_autograd_tracking: bool = False
+    ) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]:
         # TODO: We don't need to copy the model to create a stateless copy
         model_copy = copy.deepcopy(model)
         params, param_names, param_names_map = extract_weights(model_copy)
@@ -266,18 +315,22 @@
             for param in params:
                 param.requires_grad_(False)
         return (
-            FunctionalModuleWithBuffers(model_copy, param_names, buffer_names,
-                                        param_names_map, buffer_names_map),
+            FunctionalModuleWithBuffers(
+                model_copy, param_names, buffer_names, param_names_map, buffer_names_map
+            ),
             params,
             buffers,
         )
 
-    def forward(self, params, buffers, *args, **kwargs):
+    def forward(
+        self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs
+    ) -> Any:
         # Temporarily load the state back onto self.stateless_model
         old_state = _swap_state(
             self.stateless_model,
             self.all_names_map,
-            list(params) + list(buffers))
+            tuple(params) + tuple(buffers),
+        )
         try:
             return self.stateless_model(*args, **kwargs)
         finally:
@@ -290,14 +343,21 @@
     This is the callable object returned by :func:`make_functional`.
     """
 
-    def __init__(self, stateless_model, param_names, names_map):
+    def __init__(
+        self,
+        stateless_model: nn.Module,
+        param_names: Tuple[str, ...],
+        names_map: Dict[str, List[List[str]]],
+    ) -> None:
         super(FunctionalModule, self).__init__()
         self.stateless_model = stateless_model
         self.param_names = param_names
         self.names_map = names_map
 
     @staticmethod
-    def _create_from(model, disable_autograd_tracking=False):
+    def _create_from(
+        model: nn.Module, disable_autograd_tracking: bool = False
+    ) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]:
         # TODO: We don't need to copy the model to create a stateless copy
         model_copy = copy.deepcopy(model)
         params, param_names, names_map = extract_weights(model_copy)
@@ -306,7 +366,7 @@
                 param.requires_grad_(False)
         return FunctionalModule(model_copy, param_names, names_map), params
 
-    def forward(self, params, *args, **kwargs):
+    def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any:
         # Temporarily load the state back onto self.stateless_model
         old_state = _swap_state(self.stateless_model, self.names_map, params)
         try:
@@ -316,7 +376,9 @@
             _swap_state(self.stateless_model, self.names_map, old_state)
 
 
-def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
+def make_functional(
+    model: nn.Module, disable_autograd_tracking: bool = False
+) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]:
     """make_functional(model, disable_autograd_tracking=False) -> func, params
 
     Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
@@ -375,12 +437,18 @@
     """
     buffers = list(model.buffers())
     if len(buffers) > 0:
-        raise RuntimeError('make_functional(model): `model` has buffers. Please use '
-                           'make_functional_with_buffers(model) instead.')
-    return FunctionalModule._create_from(model, disable_autograd_tracking=disable_autograd_tracking)
+        raise RuntimeError(
+            "make_functional(model): `model` has buffers. Please use "
+            "make_functional_with_buffers(model) instead."
+        )
+    return FunctionalModule._create_from(
+        model, disable_autograd_tracking=disable_autograd_tracking
+    )
 
 
-def make_functional_with_buffers(model: nn.Module, disable_autograd_tracking: bool = False):
+def make_functional_with_buffers(
+    model: nn.Module, disable_autograd_tracking: bool = False
+) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
     """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers
 
     Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
@@ -434,16 +502,24 @@
             history with PyTorch autograd.
 
     """
-    return FunctionalModuleWithBuffers._create_from(model, disable_autograd_tracking=disable_autograd_tracking)
+    return FunctionalModuleWithBuffers._create_from(
+        model, disable_autograd_tracking=disable_autograd_tracking
+    )
 
 
-def transpose_stack(tuple_of_tuple_of_tensors):
+def transpose_stack(
+    tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...]
+) -> Tuple[Tensor, ...]:
     tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
-    results = tuple(torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors)
+    results = tuple(
+        torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors
+    )
     return results
 
 
-def combine_state_for_ensemble(models):
+def combine_state_for_ensemble(
+    models: Sequence[nn.Module],
+) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
     """combine_state_for_ensemble(models) -> func, params, buffers
 
     Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
@@ -482,62 +558,90 @@
         create ensembles and would love your feedback how to improve this.
     """
     if len(models) == 0:
-        raise RuntimeError('combine_state_for_ensemble: Expected at least one model, got 0.')
+        raise RuntimeError(
+            "combine_state_for_ensemble: Expected at least one model, got 0."
+        )
     if not (all(m.training for m in models) or all(not m.training for m in models)):
-        raise RuntimeError('combine_state_for_ensemble: Expected all models to '
-                           'have the same training/eval mode.')
+        raise RuntimeError(
+            "combine_state_for_ensemble: Expected all models to "
+            "have the same training/eval mode."
+        )
     model0_typ = type(models[0])
     if not all(type(m) == model0_typ for m in models):
-        raise RuntimeError('combine_state_for_ensemble: Expected all models to '
-                           'be of the same class.')
-    funcs, params, buffers = zip(*[make_functional_with_buffers(model)
-                                   for model in models])
+        raise RuntimeError(
+            "combine_state_for_ensemble: Expected all models to "
+            "be of the same class."
+        )
+    funcs, params, buffers = zip(
+        *[make_functional_with_buffers(model) for model in models]
+    )
     params = transpose_stack(params)
     buffers = transpose_stack(buffers)
     return funcs[0], params, buffers
 
 
-def functional_init(model_class, ensemble_shape=(), device='cpu'):
+def functional_init(
+    model_class: Type[nn.Module],
+    ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
+    device: torch.types.Device = "cpu",
+):
     def wrapped(*args, **kwargs):
         if len(ensemble_shape) >= 2:
-            raise ValueError('NYI: ensemble_shape with more than 1 element')
+            raise ValueError("NYI: ensemble_shape with more than 1 element")
         if len(ensemble_shape) == 0:
             model = model_class(*args, **kwargs).to(device)
             return make_functional_deprecated_v1(model)
-        num_models = ensemble_shape[0]
+        num_models = ensemble_shape[0]  # type: ignore[misc]
         if num_models <= 0:
             raise ValueError(f"num_models {num_models} should be > 0")
         # NB: Not very efficient, more of a POC
-        models = tuple(model_class(*args, **kwargs).to(device)
-                       for _ in range(num_models))
+        models = tuple(
+            model_class(*args, **kwargs).to(device) for _ in range(num_models)
+        )
         _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs))
         weights = tuple(make_functional_deprecated_v1(model)[0] for model in models)
         weights = tuple(zip(*weights))
         weights = tuple(torch.stack(shards).detach() for shards in weights)
         return weights, fn, names
+
     return wrapped
 
 
-def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'):
+def functional_init_with_buffers(
+    model_class: Type[nn.Module],
+    ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
+    device: torch.types.Device = "cpu",
+):
     def wrapped(*args, **kwargs):
         if len(ensemble_shape) >= 2:
-            raise ValueError('NYI: ensemble_shape with more than 1 element')
+            raise ValueError("NYI: ensemble_shape with more than 1 element")
         if len(ensemble_shape) == 0:
             model = model_class(*args, **kwargs).to(device)
             return make_functional_deprecated_v1(model)
-        num_models = ensemble_shape[0]
+        num_models = ensemble_shape[0]  # type: ignore[misc]
         if num_models <= 0:
             raise ValueError(f"num_models {num_models} should be > 0")
         # NB: Not very efficient, more of a POC
-        models = tuple(model_class(*args, **kwargs).to(device)
-                       for _ in range(num_models))
-        _, _, fn, weight_names, buffer_names = \
-            make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
-        weights, buffers = zip(*tuple(make_functional_with_buffers_deprecated_v1(model)[:2]
-                                      for model in models))
+        models = tuple(
+            model_class(*args, **kwargs).to(device) for _ in range(num_models)
+        )
+        (
+            _,
+            _,
+            fn,
+            weight_names,
+            buffer_names,
+        ) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
+        weights, buffers = zip(
+            *tuple(
+                make_functional_with_buffers_deprecated_v1(model)[:2]
+                for model in models
+            )
+        )
         weights = tuple(zip(*weights))
         weights = tuple(torch.stack(shards).detach() for shards in weights)
         buffers = tuple(zip(*buffers))
         buffers = tuple(torch.stack(shards).detach() for shards in buffers)
         return weights, buffers, fn, weight_names, buffer_names
+
     return wrapped