Fix and update type hints for `make_functional.py` (#91579)
Changes in detail:
- Fix and update some out-of-date type hints in `_functorch/make_functional.py`.
- ~Explicitly use `OrderedDict` for order-sensitive mappings.~
In `create_names_map()`, `_swap_state()`, and `FunctionalModuleWithBuffers.__init__()`, a plain `dict` is used for an order-sensitive mapping: the iteration order of `dict.items()` must line up with the tuple of `params`/`buffers` it is `zip`ped with (see the sketch below). Since Python 3.6 the built-in `dict` preserves insertion order ([PEP 468](https://peps.python.org/pep-0468)), so the current code works, but explicit is better than implicit.
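
A minimal sketch (not part of this patch) of why that ordering matters; the module layout, names, and tensor shapes are made up for illustration:

```python
import torch

# Hypothetical names map in the shape produced by create_names_map():
# each key maps to the list of dotted attribute paths that share the same
# tensor, e.g. a weight tied between "linear" and "tied".
names_map = {
    "linear.weight": [["linear", "weight"], ["tied", "weight"]],
    "linear.bias": [["linear", "bias"]],
}

# Flat tuple of tensors, in the same order the names were extracted.
params = (torch.randn(3, 3), torch.randn(3))

# _swap_state() pairs the i-th item of names_map.items() with the i-th tensor,
# so the dict's iteration order has to match the order of `params`.
for (name, attr_paths), tensor in zip(names_map.items(), params):
    for path in attr_paths:
        print(f"{'.'.join(path)} <- tensor for {name} with shape {tuple(tensor.shape)}")
```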
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91579
Approved by: https://github.com/zou3519
diff --git a/.lintrunner.toml b/.lintrunner.toml
index b1efdee..bb075d4 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -108,7 +108,6 @@
'torch/_functorch/eager_transforms.py',
'torch/_functorch/fx_minifier.py',
'torch/_functorch/partitioners.py',
- 'torch/_functorch/make_functional.py',
'torch/_functorch/top_operators_github_usage.py',
'torch/_functorch/vmap.py',
'torch/distributed/elastic/agent/server/api.py',
@@ -830,6 +829,7 @@
'torchgen/**/*.py',
'functorch/functorch/_src/aot_autograd.py',
'functorch/functorch/_src/compilers.py',
+ 'torch/_functorch/make_functional.py',
'torch/testing/*.py',
'torch/distributed/fsdp/*.py',
'test/distributed/fsdp/*.py',
diff --git a/torch/_functorch/make_functional.py b/torch/_functorch/make_functional.py
index abb3f07..cd7db82 100644
--- a/torch/_functorch/make_functional.py
+++ b/torch/_functorch/make_functional.py
@@ -4,12 +4,24 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+import copy
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ NoReturn,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
+
import torch
import torch.nn as nn
from torch import Tensor
-from typing import List, Tuple
-from .named_members_polyfill import _named_parameters, _named_buffers
-import copy
+from .named_members_polyfill import _named_buffers, _named_parameters
# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
@@ -40,69 +52,79 @@
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
-def _get_nested_attr(obj: nn.Module, names: List[str]) -> None:
+def _get_nested_attr(obj: nn.Module, names: List[str]) -> Tensor:
if len(names) == 1:
return getattr(obj, names[0])
else:
return _get_nested_attr(getattr(obj, names[0]), names[1:])
-def raise_parameter_tying_error():
+def raise_parameter_tying_error() -> NoReturn:
raise RuntimeError(
"make_functional(module): we don't yet support models that "
"do parameter tying (also sometimes known as weight sharing). "
"Please try to rewrite your model by replacing all instances of the "
"tied parameter with another and/or comment your support in "
- "https://github.com/pytorch/functorch/issues/446")
+ "https://github.com/pytorch/functorch/issues/446"
+ )
-def create_names_map(named_params, tied_named_params):
+def create_names_map(
+ named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
+ tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
+) -> Dict[str, List[List[str]]]:
"""
named_params is a dictionary of tensors: {'A': A, 'B': B}
tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
with potentially tied (or 'duplicated') tensors
This function creates a mapping from the names in named_params to the
- names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
+ names in tied_named_params: {'A': [['A']], 'B': [['B'], ['B_tied']]}.
"""
- named_params = {k: v for k, v in named_params}
- tied_named_params = {k: v for k, v in tied_named_params}
+ named_params = dict(named_params)
+ tied_named_params = dict(tied_named_params)
tensors_dict_keys = set(named_params.keys())
tied_tensors_dict_keys = set(tied_named_params.keys())
assert tensors_dict_keys.issubset(tied_tensors_dict_keys)
- tensor_to_mapping = {}
+ tensor_to_mapping: Dict[Tensor, Tuple[str, List[List[str]]]] = {}
for key, tensor in named_params.items():
tensor_to_mapping[tensor] = (key, [])
for key, tensor in tied_named_params.items():
assert tensor in tensor_to_mapping
- tensor_to_mapping[tensor][1].append(key.split('.'))
- result = {key: value for key, value in tensor_to_mapping.values()}
- return result
+ tensor_to_mapping[tensor][1].append(key.split("."))
+ return dict(tensor_to_mapping.values())
-def _extract_members(mod: nn.Module, _named_members, named_members, subclass):
+def _extract_members(
+ mod: nn.Module,
+ _named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
+ named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
+ subclass: Callable[[Tensor], Tensor],
+) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
all_named_members = tuple(_named_members(mod, remove_duplicate=False))
- named_members = tuple(named_members())
- names_map = create_names_map(named_members, all_named_members)
+ unique_named_members = tuple(named_members())
+ names_map = create_names_map(unique_named_members, all_named_members)
# Remove all the members in the model
memo = {}
for name, p in all_named_members:
if p not in memo:
- memo[p] = subclass(torch.empty_like(p, device='meta'))
+ memo[p] = subclass(torch.empty_like(p, device="meta"))
replacement = memo[p]
_set_nested_attr(mod, name.split("."), replacement)
- if len(named_members) == 0:
+ if len(unique_named_members) == 0:
names, params = (), ()
else:
- names, params = zip(*named_members)
+ names, params = zip(*unique_named_members) # type: ignore[assignment]
return params, names, names_map
-def extract_weights(mod: nn.Module):
+def extract_weights(
+ mod: nn.Module,
+) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
"""
This function removes all the Parameters from the model and
return them as a tuple as well as their original attribute names.
@@ -114,11 +136,18 @@
return _extract_members(mod, _named_parameters, mod.named_parameters, nn.Parameter)
-def extract_buffers(mod: nn.Module):
+def extract_buffers(
+ mod: nn.Module,
+) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[List[str]]]]:
return _extract_members(mod, _named_buffers, mod.named_buffers, lambda x: x)
-def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
+def load_weights(
+ mod: nn.Module,
+ names: Sequence[str],
+ params: Sequence[Tensor],
+ as_params: bool = False,
+) -> None:
"""
Reload a set of weights so that `mod` can be used again to perform a forward pass.
Note that the `params` are regular Tensors (that can have history) and so are left
@@ -131,8 +160,10 @@
_set_nested_attr(mod, name.split("."), p)
-def _swap_state(mod: nn.Module, names_map: List[str], elems):
- result = []
+def _swap_state(
+ mod: nn.Module, names_map: Dict[str, List[List[str]]], elems: Iterable[Tensor]
+) -> List[Tensor]:
+ result: List[Tensor] = []
for (_, attr_names), elem in zip(names_map.items(), elems):
for i, attr_name in enumerate(attr_names):
if i == 0:
@@ -142,15 +173,23 @@
return result
-def load_buffers(mod: nn.Module, names: List[str], buffers: Tuple[Tensor, ...], as_params=False) -> None:
+def load_buffers(
+ mod: nn.Module,
+ names: Sequence[str],
+ buffers: Sequence[Tensor],
+ as_params: bool = False,
+) -> None:
for name, p in zip(names, buffers):
_set_nested_attr(mod, name.split("."), p)
def load_state(
- model: nn.Module,
- weights: List[Tensor], weight_names: List[str],
- buffers=(), buffer_names=()):
+ model: nn.Module,
+ weights: Sequence[Tensor],
+ weight_names: Sequence[str],
+ buffers: Sequence[Tensor] = (),
+ buffer_names: Sequence[str] = (),
+) -> nn.Module:
"""load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model
load_state takes `weights` and `buffers` and assigns them to the model.
@@ -192,8 +231,10 @@
"""
buffers = list(model.buffers())
if len(buffers) > 0:
- raise RuntimeError('make_functional_deprecated_v1(model): `model` has buffers. Please use '
- 'make_functional_with_buffers_deprecated_v1(model) instead.')
+ raise RuntimeError(
+ "make_functional_deprecated_v1(model): `model` has buffers. Please use "
+ "make_functional_with_buffers_deprecated_v1(model) instead."
+ )
weights, descriptors, _ = extract_weights(model)
def fun(weights, data):
@@ -246,8 +287,14 @@
This is the callable object returned by :func:`make_functional_with_buffers`.
"""
- def __init__(self, stateless_model, param_names, buffer_names,
- param_names_map, buffer_names_map):
+ def __init__(
+ self,
+ stateless_model: nn.Module,
+ param_names: Tuple[str, ...],
+ buffer_names: Tuple[str, ...],
+ param_names_map: Dict[str, List[List[str]]],
+ buffer_names_map: Dict[str, List[List[str]]],
+ ) -> None:
super(FunctionalModuleWithBuffers, self).__init__()
self.stateless_model = stateless_model
self.param_names = param_names
@@ -257,7 +304,9 @@
self.all_names_map.update(buffer_names_map)
@staticmethod
- def _create_from(model, disable_autograd_tracking=False):
+ def _create_from(
+ model: nn.Module, disable_autograd_tracking: bool = False
+ ) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]:
# TODO: We don't need to copy the model to create a stateless copy
model_copy = copy.deepcopy(model)
params, param_names, param_names_map = extract_weights(model_copy)
@@ -266,18 +315,22 @@
for param in params:
param.requires_grad_(False)
return (
- FunctionalModuleWithBuffers(model_copy, param_names, buffer_names,
- param_names_map, buffer_names_map),
+ FunctionalModuleWithBuffers(
+ model_copy, param_names, buffer_names, param_names_map, buffer_names_map
+ ),
params,
buffers,
)
- def forward(self, params, buffers, *args, **kwargs):
+ def forward(
+ self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs
+ ) -> Any:
# Temporarily load the state back onto self.stateless_model
old_state = _swap_state(
self.stateless_model,
self.all_names_map,
- list(params) + list(buffers))
+ tuple(params) + tuple(buffers),
+ )
try:
return self.stateless_model(*args, **kwargs)
finally:
@@ -290,14 +343,21 @@
This is the callable object returned by :func:`make_functional`.
"""
- def __init__(self, stateless_model, param_names, names_map):
+ def __init__(
+ self,
+ stateless_model: nn.Module,
+ param_names: Tuple[str, ...],
+ names_map: Dict[str, List[List[str]]],
+ ) -> None:
super(FunctionalModule, self).__init__()
self.stateless_model = stateless_model
self.param_names = param_names
self.names_map = names_map
@staticmethod
- def _create_from(model, disable_autograd_tracking=False):
+ def _create_from(
+ model: nn.Module, disable_autograd_tracking: bool = False
+ ) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]:
# TODO: We don't need to copy the model to create a stateless copy
model_copy = copy.deepcopy(model)
params, param_names, names_map = extract_weights(model_copy)
@@ -306,7 +366,7 @@
param.requires_grad_(False)
return FunctionalModule(model_copy, param_names, names_map), params
- def forward(self, params, *args, **kwargs):
+ def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any:
# Temporarily load the state back onto self.stateless_model
old_state = _swap_state(self.stateless_model, self.names_map, params)
try:
@@ -316,7 +376,9 @@
_swap_state(self.stateless_model, self.names_map, old_state)
-def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
+def make_functional(
+ model: nn.Module, disable_autograd_tracking: bool = False
+) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]:
"""make_functional(model, disable_autograd_tracking=False) -> func, params
Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
@@ -375,12 +437,18 @@
"""
buffers = list(model.buffers())
if len(buffers) > 0:
- raise RuntimeError('make_functional(model): `model` has buffers. Please use '
- 'make_functional_with_buffers(model) instead.')
- return FunctionalModule._create_from(model, disable_autograd_tracking=disable_autograd_tracking)
+ raise RuntimeError(
+ "make_functional(model): `model` has buffers. Please use "
+ "make_functional_with_buffers(model) instead."
+ )
+ return FunctionalModule._create_from(
+ model, disable_autograd_tracking=disable_autograd_tracking
+ )
-def make_functional_with_buffers(model: nn.Module, disable_autograd_tracking: bool = False):
+def make_functional_with_buffers(
+ model: nn.Module, disable_autograd_tracking: bool = False
+) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
"""make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers
Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
@@ -434,16 +502,24 @@
history with PyTorch autograd.
"""
- return FunctionalModuleWithBuffers._create_from(model, disable_autograd_tracking=disable_autograd_tracking)
+ return FunctionalModuleWithBuffers._create_from(
+ model, disable_autograd_tracking=disable_autograd_tracking
+ )
-def transpose_stack(tuple_of_tuple_of_tensors):
+def transpose_stack(
+ tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...]
+) -> Tuple[Tensor, ...]:
tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
- results = tuple(torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors)
+ results = tuple(
+ torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors
+ )
return results
-def combine_state_for_ensemble(models):
+def combine_state_for_ensemble(
+ models: Sequence[nn.Module],
+) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
"""combine_state_for_ensemble(models) -> func, params, buffers
Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
@@ -482,62 +558,90 @@
create ensembles and would love your feedback how to improve this.
"""
if len(models) == 0:
- raise RuntimeError('combine_state_for_ensemble: Expected at least one model, got 0.')
+ raise RuntimeError(
+ "combine_state_for_ensemble: Expected at least one model, got 0."
+ )
if not (all(m.training for m in models) or all(not m.training for m in models)):
- raise RuntimeError('combine_state_for_ensemble: Expected all models to '
- 'have the same training/eval mode.')
+ raise RuntimeError(
+ "combine_state_for_ensemble: Expected all models to "
+ "have the same training/eval mode."
+ )
model0_typ = type(models[0])
if not all(type(m) == model0_typ for m in models):
- raise RuntimeError('combine_state_for_ensemble: Expected all models to '
- 'be of the same class.')
- funcs, params, buffers = zip(*[make_functional_with_buffers(model)
- for model in models])
+ raise RuntimeError(
+ "combine_state_for_ensemble: Expected all models to "
+ "be of the same class."
+ )
+ funcs, params, buffers = zip(
+ *[make_functional_with_buffers(model) for model in models]
+ )
params = transpose_stack(params)
buffers = transpose_stack(buffers)
return funcs[0], params, buffers
-def functional_init(model_class, ensemble_shape=(), device='cpu'):
+def functional_init(
+ model_class: Type[nn.Module],
+ ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
+ device: torch.types.Device = "cpu",
+):
def wrapped(*args, **kwargs):
if len(ensemble_shape) >= 2:
- raise ValueError('NYI: ensemble_shape with more than 1 element')
+ raise ValueError("NYI: ensemble_shape with more than 1 element")
if len(ensemble_shape) == 0:
model = model_class(*args, **kwargs).to(device)
return make_functional_deprecated_v1(model)
- num_models = ensemble_shape[0]
+ num_models = ensemble_shape[0] # type: ignore[misc]
if num_models <= 0:
raise ValueError(f"num_models {num_models} should be > 0")
# NB: Not very efficient, more of a POC
- models = tuple(model_class(*args, **kwargs).to(device)
- for _ in range(num_models))
+ models = tuple(
+ model_class(*args, **kwargs).to(device) for _ in range(num_models)
+ )
_, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs))
weights = tuple(make_functional_deprecated_v1(model)[0] for model in models)
weights = tuple(zip(*weights))
weights = tuple(torch.stack(shards).detach() for shards in weights)
return weights, fn, names
+
return wrapped
-def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'):
+def functional_init_with_buffers(
+ model_class: Type[nn.Module],
+ ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
+ device: torch.types.Device = "cpu",
+):
def wrapped(*args, **kwargs):
if len(ensemble_shape) >= 2:
- raise ValueError('NYI: ensemble_shape with more than 1 element')
+ raise ValueError("NYI: ensemble_shape with more than 1 element")
if len(ensemble_shape) == 0:
model = model_class(*args, **kwargs).to(device)
return make_functional_deprecated_v1(model)
- num_models = ensemble_shape[0]
+ num_models = ensemble_shape[0] # type: ignore[misc]
if num_models <= 0:
raise ValueError(f"num_models {num_models} should be > 0")
# NB: Not very efficient, more of a POC
- models = tuple(model_class(*args, **kwargs).to(device)
- for _ in range(num_models))
- _, _, fn, weight_names, buffer_names = \
- make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
- weights, buffers = zip(*tuple(make_functional_with_buffers_deprecated_v1(model)[:2]
- for model in models))
+ models = tuple(
+ model_class(*args, **kwargs).to(device) for _ in range(num_models)
+ )
+ (
+ _,
+ _,
+ fn,
+ weight_names,
+ buffer_names,
+ ) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
+ weights, buffers = zip(
+ *tuple(
+ make_functional_with_buffers_deprecated_v1(model)[:2]
+ for model in models
+ )
+ )
weights = tuple(zip(*weights))
weights = tuple(torch.stack(shards).detach() for shards in weights)
buffers = tuple(zip(*buffers))
buffers = tuple(torch.stack(shards).detach() for shards in buffers)
return weights, buffers, fn, weight_names, buffer_names
+
return wrapped