[ZeRO] (Reland) Add ctor support for multiple param groups (#72932)
Summary:
Reland of https://github.com/pytorch/pytorch/pull/72578.
**Overview**
Windows CI was failing in the multi-rank, single-GPU case (see [here](https://github.com/pytorch/pytorch/runs/5204906995?check_suite_focus=true)).
To address this, I
- added `common_distributed.skip_if_no_gpu` to `test_multiple_param_groups()` so that each rank can safely call `to(self.device)` -- this targets the expected SPSD use case, where each rank has its own GPU;
- moved `test_constructor()` back to `TestZeroRedundancyOptimizerSingleRank` to check that constructing with multiple parameter groups works even on a single rank (a usage sketch follows this list).
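With this change, `ZeroRedundancyOptimizer` accepts parameter groups directly at construction time. A minimal usage sketch, mirroring the tests below (the model, hyperparameters, and the assumption that the default process group is already initialized are illustrative):

```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim import AdamW

# Assumes torch.distributed has already been initialized on each rank.
model = torch.nn.Sequential(
    torch.nn.Linear(5, 10),
    torch.nn.Linear(10, 10),
    torch.nn.Linear(10, 5),
)

# Pass a list of parameter-group dicts straight to the ZeRO constructor;
# per-group options such as `weight_decay` are forwarded to the local optimizer.
optim = ZeroRedundancyOptimizer(
    [
        {"params": [l.weight for l in model], "weight_decay": 0.0},
        {"params": [l.bias for l in model], "weight_decay": 0.01},
    ],
    optimizer_class=AdamW,
    lr=0.01,
)
```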
**Test Plan**
- I ran both tests on CPU and with 1, 2, 4, and 8 GPUs.
- I added the `ciflow/win` label to run the failing Windows CI test.
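For reviewers skimming the patch: the constructor-side change boils down to `_verify_and_init_params` now accepting either an iterable of tensors or an iterable of parameter-group dicts and returning `params` as a list for the parent `Optimizer` constructor. A condensed, illustrative sketch of that check follows (the helper name here is made up; the authoritative version is `_verify_and_init_params` in the diff below):

```python
from typing import Any, List, Union

import torch


def _verify_params_sketch(params: Any) -> Union[List[torch.Tensor], List[dict]]:
    # Reject a bare tensor and anything that is not iterable.
    if isinstance(params, torch.Tensor):
        raise TypeError("`params` should be an iterable of Tensors or dicts")
    all_params = list(params)
    if len(all_params) == 0:
        raise ValueError("got an empty parameter list")
    # `params` must be homogeneously tensors or homogeneously dicts
    # (parameter groups); mixing the two is rejected.
    if all(isinstance(p, torch.Tensor) for p in all_params):
        pass  # flat list of parameters
    elif all(isinstance(p, dict) for p in all_params):
        for group in all_params:
            if "params" not in group:
                raise ValueError("each parameter group needs a 'params' key")
    else:
        raise TypeError("`params` should be an iterable of Tensors or dicts")
    # Return a list so the parent `Optimizer.__init__` can iterate it again.
    return all_params
```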
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72932
Reviewed By: rohan-varma
Differential Revision: D34281482
Pulled By: awgu
fbshipit-source-id: c4fe604ddd9d2c123c3071249741e6b8a6454b6e
(cherry picked from commit 6bea9bcc6349ff1aad403563206fb170a3af0c70)
diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py
index de8ea51..67c2745 100644
--- a/test/distributed/optim/test_zero_redundancy_optimizer.py
+++ b/test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -33,7 +33,7 @@
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD
+from torch.optim import SGD, AdamW
from torch.testing._internal import common_distributed, common_utils
from torch.testing._internal.common_utils import (
TEST_WITH_ASAN,
@@ -249,27 +249,54 @@
def test_constructor(self):
"""Check the robustness of the ZeroRedundancyOptimizer constructor by
- passing different values for `params`"""
+ passing different values for the ``params`` argument."""
self.dist_init(self.rank)
- m = torch.nn.Linear(1, 1)
- # (input, expected error)
- inputs = [
+ m = torch.nn.Sequential(
+ torch.nn.Linear(5, 10),
+ torch.nn.Linear(10, 10),
+ torch.nn.Linear(10, 10),
+ )
+
+ # Test various constructor inputs in the form: (input, expected error)
+ ctor_inputs = [
([], ValueError), # empty parameter list
(torch.randn(1), TypeError), # non-iterable: `torch.Tensor`
(1.2, TypeError), # non-iterable: `float`
- ([{"params": m.parameters()}], TypeError), # iterable of dict
- (list(m.parameters()) + [42], TypeError), # iterable containing non-`torch.Tensor`
+ ([
+ {"params": [l.weight for l in m]},
+ {"params": [l.bias for l in m]},
+ ], None), # iterable of dict
+ (list(m.parameters()) + [42], TypeError), # iterable containing invalid type
(m.parameters(), None), # `params` as a generator
(list(m.parameters()), None) # `params` as a list
]
- for input, error in inputs:
- if (error):
+ for ctor_input, error in ctor_inputs:
+ if error:
with self.assertRaises(error):
- ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=0.1)
+ ZeroRedundancyOptimizer(ctor_input, optimizer_class=SGD, lr=0.01)
else:
- ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=0.1)
+ ZeroRedundancyOptimizer(ctor_input, optimizer_class=SGD, lr=0.01)
+
+ # Test constructing with multiple parameter groups more thoroughly
+ weight_decay = 0.01
+ lr = 0.01
+ betas = (0.9, 0.999)
+ eps = 1e-8
+ params = [
+ {"params": [l.weight for l in m], "weight_decay": 0.},
+ {"params": [l.bias for l in m], "weight_decay": weight_decay},
+ ]
+ o = ZeroRedundancyOptimizer(
+ params, optimizer_class=AdamW,
+ lr=lr, betas=betas, eps=eps,
+ )
+ assert len(o.param_groups) == 2, \
+ f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
+ assert len(o.optim.param_groups) == 2, \
+ "Expected 2 local optimizer param groups, but got " \
+ f"{len(o.optim.param_groups)}"
def test_same_dense_param_type(self):
"""Check that ZeroRedundancyOptimizer raises an exception if the input
@@ -459,7 +486,76 @@
all_trainable()
some_trainable()
+ @common_distributed.skip_if_no_gpu
+ def test_multiple_param_groups(self):
+ """
+ Tests parity between constructing ZeRO with multiple parameter groups
+ upfront versus adding parameter groups to ZeRO after construction
+ versus a non-sharded optimizer.
+ """
+ self.dist_init(self.rank)
+
+ model1 = torch.nn.Sequential(
+ torch.nn.Linear(5, 10),
+ torch.nn.Linear(10, 10),
+ torch.nn.Linear(10, 5),
+ )
+ model2 = copy.deepcopy(model1)
+ model3 = copy.deepcopy(model1)
+ model1 = model1.to(self.device)
+ model2 = model2.to(self.device)
+ model3 = model3.to(self.device)
+
+ batch_size = 8
+ num_iters = 3
+ inputs = [
+ torch.randn(batch_size, 5).to(self.device) for _ in range(num_iters)
+ ]
+ wd = 0.01
+ lr = 0.01
+ # Construct `optim1` with both parameter groups upfront
+ optim1 = ZeroRedundancyOptimizer(
+ [
+ {"params": [l.weight for l in model1], "weight_decay": 0.},
+ {"params": [l.bias for l in model1], "weight_decay": wd},
+ ],
+ optimizer_class=AdamW, lr=lr,
+ )
+ # Construct `optim2` by adding the second parameter after
+ optim2 = ZeroRedundancyOptimizer(
+ [l.weight for l in model2],
+ optimizer_class=AdamW, lr=lr, weight_decay=0.,
+ )
+ optim2.add_param_group(
+ {"params": [l.bias for l in model2], "weight_decay": wd}
+ )
+ # Construct `optim3` as a non-sharded optimizer
+ optim3 = AdamW(
+ [
+ {"params": [l.weight for l in model3], "weight_decay": 0.},
+ {"params": [l.bias for l in model3], "weight_decay": wd},
+ ], lr=lr,
+ )
+
+ # Check parity over a few iterations
+ for iter in range(num_iters):
+ for model, optim in (
+ (model1, optim1), (model2, optim2), (model3, optim3),
+ ):
+ optim.zero_grad()
+ out = model(inputs[iter])
+ loss = out.sum()
+ loss.backward()
+ optim.step()
+
+ for layer1, layer2, layer3 in zip(model1, model2, model3):
+ assert torch.allclose(layer1.weight, layer2.weight)
+ assert torch.allclose(layer1.weight, layer3.weight)
+ assert torch.allclose(layer1.bias, layer2.bias)
+ assert torch.allclose(layer1.bias, layer3.bias)
+
@common_distributed.skip_if_lt_x_gpu(2)
+ @common_distributed.skip_if_rocm
def test_collect_shards(self):
""" Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer"""
self.dist_init(self.rank)
diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py
index 70779ea..a87bfda 100644
--- a/torch/distributed/optim/zero_redundancy_optimizer.py
+++ b/torch/distributed/optim/zero_redundancy_optimizer.py
@@ -10,7 +10,16 @@
import io
import logging
from itertools import chain
-from typing import Any, Callable, Dict, List, Optional, Set, Type
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Type,
+ Union,
+)
import torch
import torch.distributed as dist
@@ -287,7 +296,8 @@
Arguments:
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
- giving all parameters, which will be sharded across ranks.
+ or :class:`dict` s giving all parameters, which will be sharded
+ across ranks.
Keyword Args:
optimizer_class (:class:`torch.nn.Optimizer`): the class of the local
@@ -364,7 +374,7 @@
**defaults: Any,
):
# Perform type and assumption checks on the input parameters
- self._verify_and_init_params(params)
+ params = self._verify_and_init_params(params)
self._verify_same_dense_param_type()
# NOTE: The parent constructor uses `add_param_group()` which is
@@ -373,7 +383,7 @@
# between the parent and child.
self.initialized = False
- Optimizer.__init__(self, self._all_params, defaults)
+ Optimizer.__init__(self, params, defaults)
Joinable.__init__(self)
# Now, all parameters are held in both `self._all_params` and
# `self.param_groups`
@@ -1289,36 +1299,60 @@
offset = offset_next
bucket_assignment.tensor = tensor
- def _verify_and_init_params(self, params: Any) -> None:
+ def _verify_and_init_params(
+ self, params: Any,
+ ) -> Union[List[torch.Tensor], List[dict]]:
r"""
Verifies the type of ``params`` and initializes ``self._all_params``
- if ``params`` is valid.
+ as a :class:`list` of all parameters if ``params`` is valid.
- While :class:`optim.Optimizer <torch.optim.Optimizer>` allows
- ``params`` to be an iterable of :class:`dict` s, currently
- ``ZeroRedundancyOptimizer`` strictly requires ``params`` to be an
- iterable of :class:`torch.Tensor` s.
+ Arguments:
+ params (Any): Candidate parameter list or parameter groups to
+ verify.
Raises:
TypeError: ``params`` has an invalid type.
ValueError: ``params`` is empty.
+
+ Returns:
+ The persistent form of ``params`` to be passed into the parent
+ :class:`Optimizer` constructor -- i.e. returns ``params`` as a
+ :class:`list` to ensure that it can be iterated over again.
"""
if isinstance(params, torch.Tensor):
- raise TypeError("params argument should be an iterable of "
+ raise TypeError("`params` argument should be an iterable of "
f"Tensors, but got {torch.typename(params)}")
try:
- self._all_params = list(params)
+ all_params = list(params)
except TypeError:
- raise TypeError("params argument should be an iterable of "
+ raise TypeError("`params` argument should be an iterable of "
f"Tensors, but got {torch.typename(params)}")
- if len(self._all_params) == 0:
+ if len(all_params) == 0:
raise ValueError("ZeroRedundancyOptimizer got an empty parameter "
"list")
- for param in self._all_params:
- if not isinstance(param, torch.Tensor):
- raise TypeError("params argument should be an iterable of "
- "Tensors, but got an iterable containing "
- f"{torch.typename(param)}")
+ all_tensors = True
+ all_dicts = True
+ for param in all_params:
+ all_tensors &= isinstance(param, torch.Tensor)
+ all_dicts &= isinstance(param, dict)
+ if not all_tensors and not all_dicts:
+ raise TypeError("`params` argument should be an iterable of "
+ "Tensors or dicts")
+ # Ensure that `self._all_params` contains a list of all parameters
+ if all_tensors:
+ self._all_params = all_params
+ elif all_dicts:
+ self._all_params = []
+ # `all_params` contains parameter groups (not parameters)
+ for param_group in all_params:
+ if "params" not in param_group:
+ raise ValueError(
+ "Each parameter group passed-in via `params` must "
+ "have a 'params' key mapping to the parameters in "
+ "the group"
+ )
+ self._all_params.extend(param_group["params"])
+ return all_params
def _verify_same_dense_param_type(self) -> None:
r"""