[state_dict][5/N] Add submodules save and load support (#111110)
It is not easy for users to save and load submodules (e.g., for fine-tuning) because FSDP requires access to the root module. This PR adds support for saving and loading submodule state_dicts.
Differential Revision: [D50209727](https://our.internmc.facebook.com/intern/diff/D50209727/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111110
Approved by: https://github.com/wz337
ghstack dependencies: #111106, #111107, #111275, #111109
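
A minimal usage sketch of the new ``submodules`` flow (the toy ``nn.Sequential`` is a stand-in for a real FSDP-wrapped model; the API names mirror the ones touched in this diff):

```python
import copy

import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    load_state_dict,
    state_dict,
)

# Toy stand-in for an FSDP-wrapped model; only the first layer is saved/loaded.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
pretrain = model[0]

# Save only the submodule's parameters. Keys keep the "0." prefix by default.
sub_sd, _ = state_dict(model, submodules={pretrain})

# Drop the prefix to get keys relative to the submodule (e.g. "weight").
rel_sd, _ = state_dict(
    model,
    submodules={pretrain},
    options=StateDictOptions(keep_submodule_prefixes=False),
)

# Load back. Keying the dict by the submodule re-attaches the prefix.
load_state_dict(
    model,
    model_state_dict={pretrain: copy.deepcopy(rel_sd)},
    options=StateDictOptions(strict=False),
)
```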
diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py
index f510aa1..ff71646 100644
--- a/test/distributed/checkpoint/test_state_dict.py
+++ b/test/distributed/checkpoint/test_state_dict.py
@@ -374,3 +374,48 @@
)
with self.assertRaisesRegex(RuntimeError, "Missing key"):
load_state_dict(model, model_state_dict=model_state_dict)
+
+ @skip_if_lt_x_gpu(1)
+ def test_partial(self) -> None:
+ model = CompositeParamModel(device=torch.device("cuda"))
+
+ model_state_dict1, _ = state_dict(model)
+ model_state_dict1 = copy.deepcopy(model_state_dict1)
+ model_state_dict2, _ = state_dict(model, submodules={model.l})
+ model_state_dict2 = copy.deepcopy(model_state_dict2)
+ model_state_dict3, _ = state_dict(
+ model,
+ submodules={model.l},
+ options=StateDictOptions(keep_submodule_prefixes=False),
+ )
+ model_state_dict3 = copy.deepcopy(model_state_dict3)
+ self.assertEqual(len(model_state_dict2), 2)
+ self.assertEqual(len(model_state_dict3), 2)
+ for key in model_state_dict3.keys():
+ full_fqn = f"l.{key}"
+ value1 = model_state_dict1[full_fqn]
+ value2 = model_state_dict2[full_fqn]
+ value3 = model_state_dict3[key]
+ self.assertEqual(value1, value2)
+ self.assertEqual(value2, value3)
+
+ zeros_state_dict = {
+ k: torch.zeros_like(v) for k, v in model_state_dict1.items()
+ }
+ model.load_state_dict(zeros_state_dict)
+ load_state_dict(
+ model,
+ model_state_dict=model_state_dict2,
+ options=StateDictOptions(strict=False),
+ )
+ self.assertEqual(model.l.weight, model_state_dict1["l.weight"])
+ self.assertEqual(model.l.bias, model_state_dict1["l.bias"])
+
+ model.load_state_dict(zeros_state_dict)
+ load_state_dict(
+ model,
+ model_state_dict={model.l: model_state_dict3},
+ options=StateDictOptions(strict=False),
+ )
+ self.assertEqual(model.l.weight, model_state_dict1["l.weight"])
+ self.assertEqual(model.l.bias, model_state_dict1["l.bias"])
diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py
index 3236b42..cf301a1 100644
--- a/torch/distributed/checkpoint/state_dict.py
+++ b/torch/distributed/checkpoint/state_dict.py
@@ -90,6 +90,15 @@
# Whether to ignore the frozen parameters when getting the state_dict.
# The default is False.
ignore_frozen_params: bool = False
+ # When asking to return only the submodule state_dict (submodules != None),
+ # whether to keep the submodule prefixes in the state_dict keys.
+ # For example, if the submodule is ``module.pretrain`` and the full FQN of
+ # the parameter is ``pretrain.layer1.weight``, setting this option to
+ # False will return ``layer1.weight``; otherwise the full FQN will be
+ # returned.
+ # Note that if ``keep_submodule_prefixes`` is False, there may be conflicting
+ # FQNs, hence there should be only one submodule in ``submodules``.
+ keep_submodule_prefixes: bool = True
# The `strict` option for model.load_state_dict() call.
strict: bool = True
@@ -100,6 +109,7 @@
Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
] = field(default_factory=dict)
all_fqns: Set[str] = field(default_factory=set)
+ submodule_prefixes: Set[str] = field(default_factory=set)
handle_model: bool = True
handle_optim: bool = True
fsdp_context: Callable = contextlib.nullcontext
@@ -156,6 +166,8 @@
optims: Tuple[torch.optim.Optimizer, ...],
model_only: bool,
optim_only: bool,
+ *,
+ submodules: Optional[Set[nn.Module]] = None,
options: Optional[StateDictOptions] = None,
) -> _StateDictInfo:
"""
@@ -187,6 +199,17 @@
fqn_param_mapping[fqn] = param
all_fqns.add(fqn)
+ submodule_prefixes = set()
+ if submodules:
+ submodules = set(submodules)
+ for name, module in model.named_modules():
+ if module not in submodules:
+ continue
+ fqns = _get_fqns(model, name)
+ assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
+ for fqn in fqns:
+ submodule_prefixes.add(f"{fqn}.")
+
fsdp_modules = FSDP.fsdp_modules(model)
state_dict_config: StateDictConfig
optim_state_dict_config: OptimStateDictConfig
@@ -222,6 +245,7 @@
**asdict(options),
fqn_param_mapping=fqn_param_mapping,
all_fqns=all_fqns,
+ submodule_prefixes=submodule_prefixes,
fsdp_context=fsdp_context,
fsdp_modules=cast(List[nn.Module], fsdp_modules),
handle_model=model_only or not optim_only,
@@ -247,7 +271,13 @@
# Verify if the model_state_dict and optim_state_dict are valid. This API
# should give the users an explicit error message to debug or report.
- if info.handle_model and not model_state_dict:
+ if (
+ info.handle_model
+ and not model_state_dict
+ and not info.submodule_prefixes
+ and not info.ignore_frozen_params
+ and info.strict
+ ):
raise RuntimeError(
"The option indicates that model state_dict is required to save "
"or load, but model state_dict is empty."
@@ -312,6 +342,20 @@
raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}")
state_dict[fqn] = state_dict.pop(key)
+ if info.submodule_prefixes:
+ new_state_dict: Dict[str, ValueType] = {}
+ # TODO: make this faster.
+ for fqn in state_dict.keys():
+ for prefix in info.submodule_prefixes:
+ if not fqn.startswith(prefix):
+ continue
+ if info.keep_submodule_prefixes:
+ new_state_dict[fqn] = state_dict[fqn]
+ else:
+ new_fqn = fqn[len(prefix) :]
+ new_state_dict[new_fqn] = state_dict[fqn]
+ state_dict = new_state_dict
+
if info.ignore_frozen_params:
for key, param in model.named_parameters():
if param.requires_grad:
@@ -327,7 +371,7 @@
state_dict: Dict[str, ValueType],
info: _StateDictInfo,
) -> None:
- if not info.handle_model:
+ if not info.handle_model or not state_dict:
return
for key, _ in model.named_parameters():
@@ -494,6 +538,7 @@
] = None,
model_only: bool = False,
optim_only: bool = False,
+ submodules: Optional[Set[nn.Module]] = None,
options: Optional[StateDictOptions] = None,
) -> Tuple[Dict[str, ValueType], OptimizerStateType]:
"""
@@ -554,6 +599,9 @@
optim_only (bool): if optim_only is True, the returned model state_dict
will be empty (default: False)
+ submodules (Optional[Set[nn.Module]]): only return the model parameters
+ that belong to the given submodules.
+
options (StateDictOptions): the options to control how
model state_dict and optimizer state_dict should be returned. See
`StateDictOptions` for the details.
@@ -575,20 +623,55 @@
else tuple(optimizers)
)
)
- info = _verify_options(model, optimizers, model_only, optim_only, options)
+ info = _verify_options(
+ model,
+ optimizers,
+ model_only,
+ optim_only,
+ submodules=submodules,
+ options=options,
+ )
model_state_dict = _get_model_state_dict(model, info)
optim_state_dict = _get_optim_state_dict(model, optimizers, info)
_verify_state_dict(model_state_dict, optim_state_dict, info)
return model_state_dict, optim_state_dict
+def _unflatten_model_state_dict(
+ model: nn.Module,
+ state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]],
+) -> Dict[str, ValueType]:
+ if not state_dict:
+ return {}
+
+ if isinstance(next(iter(state_dict.keys())), nn.Module):
+ cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict)
+ new_state_dict: Dict[str, ValueType] = {}
+ for submodule, sub_state_dict in cast_state_dict.items():
+ for name, m in model.named_modules():
+ if m != submodule:
+ continue
+
+ fqns = _get_fqns(model, name)
+ assert len(fqns) == 1, "FQNs for a submodule should only have 1 element"
+ prefix = f"{next(iter(fqns))}."
+ new_state_dict.update(
+ {prefix + subfqn: value for subfqn, value in sub_state_dict.items()}
+ )
+ return new_state_dict
+ else:
+ return cast(Dict[str, ValueType], state_dict)
+
+
def load_state_dict(
model: nn.Module,
*,
optimizers: Union[
None, torch.optim.Optimizer, Iterable[torch.optim.Optimizer]
] = None,
- model_state_dict: Optional[Dict[str, ValueType]] = None,
+ model_state_dict: Union[
+ None, Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]
+ ] = None,
optim_state_dict: Optional[OptimizerStateType] = None,
model_only: bool = False,
optim_only: bool = False,
@@ -609,6 +692,14 @@
model (nn.Module): the nn.Module to the model.
optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
The optimizers that are used to optimize ``model``.
+ model_state_dict (Union[
+ None, Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):
+ the model state_dict to load. If a key of ``model_state_dict``
+ is an nn.Module, the key is a submodule of ``model`` and the value
+ should be the state_dict of the submodule. When loading, the prefix
+ of the submodule will be appended to the keys of the state_dict.
+ optim_state_dict (Optional[OptimizerStateType]):
+ the optimizer state_dict to load.
model_only (bool): if model_only is True, only the model state_dict will
be loaded (default: False)
optim_only (bool): if optim_only is True, only the optimizer state_dict
@@ -621,7 +712,9 @@
None
"""
- model_state_dict = model_state_dict if model_state_dict else {}
+ model_state_dict: Dict[str, ValueType] = (
+ _unflatten_model_state_dict(model, model_state_dict) if model_state_dict else {}
+ )
optim_state_dict = optim_state_dict if optim_state_dict else {}
with gc_context():
optimizers = (
@@ -633,7 +726,10 @@
else tuple(optimizers)
)
)
- info = _verify_options(model, optimizers, model_only, optim_only, options)
+ info = _verify_options(
+ model, optimizers, model_only, optim_only, options=options
+ )
+
_verify_state_dict(model_state_dict, optim_state_dict, info)
_load_model_state_dict(model, model_state_dict, info)
_load_optim_state_dict(model, optimizers, optim_state_dict, info)
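
For reference, the ``nn.Module``-keyed form accepted by ``load_state_dict`` is first unflattened into plain FQN keys. A rough sketch of the equivalent transformation (hypothetical toy model, not the exact implementation):

```python
import torch
import torch.nn as nn

model = nn.Module()
model.l = nn.Linear(2, 2)

# Submodule-keyed input, as passed via load_state_dict(model_state_dict=...).
sub_sd = {model.l: {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}}

# Equivalent flat form: the submodule's FQN ("l" here) is prepended to each key.
flat_sd = {
    f"{name}.{k}": v
    for name, module in model.named_modules()
    for submodule, d in sub_sd.items()
    if module is submodule
    for k, v in d.items()
}
assert set(flat_sd) == {"l.weight", "l.bias"}
```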