[ZeRO] (Reland) Add ctor support for multiple param groups (#72932)

Summary:
Reland of https://github.com/pytorch/pytorch/pull/72578.

**Overview**
Windows CI was failing in the multi-rank, single-GPU case (see [here](https://github.com/pytorch/pytorch/runs/5204906995?check_suite_focus=true)).

To address this, I
- added `common_distributed.skip_if_no_gpu` to `test_multiple_param_groups()` to ensure that each rank can safely call `to(self.device)` -- this targets the expected single-process single-device (SPSD) use case, where each rank has its own GPU;
- moved `test_constructor()` back to `TestZeroRedundancyOptimizerSingleRank` to check that constructing with multiple parameter groups works even on a single rank.

**Test Plan**
- I ran both tests on CPU and with 1, 2, 4, and 8 GPUs.
- I added the `ciflow/win` label to run the failing Windows CI test.
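
**Usage Sketch**
A minimal sketch of the newly supported constructor form, mirroring the updated `test_constructor()`; it assumes the default process group is already initialized and that the model and hyperparameter values are placeholders:

```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

model = torch.nn.Sequential(
    torch.nn.Linear(5, 10),
    torch.nn.Linear(10, 5),
)

# `params` may now be an iterable of dicts (parameter groups), not only an
# iterable of tensors; per-group options such as `weight_decay` are honored.
param_groups = [
    {"params": [layer.weight for layer in model], "weight_decay": 0.0},
    {"params": [layer.bias for layer in model], "weight_decay": 0.01},
]
optim = ZeroRedundancyOptimizer(
    param_groups,
    optimizer_class=torch.optim.AdamW,
    lr=0.01,
)
```

`test_multiple_param_groups()` additionally checks this construction for parity against adding the second group after construction via `add_param_group()` and against a non-sharded `AdamW`.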

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72932

Reviewed By: rohan-varma

Differential Revision: D34281482

Pulled By: awgu

fbshipit-source-id: c4fe604ddd9d2c123c3071249741e6b8a6454b6e
(cherry picked from commit 6bea9bcc6349ff1aad403563206fb170a3af0c70)
diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py
index de8ea51..67c2745 100644
--- a/test/distributed/optim/test_zero_redundancy_optimizer.py
+++ b/test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -33,7 +33,7 @@
 from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD
+from torch.optim import SGD, AdamW
 from torch.testing._internal import common_distributed, common_utils
 from torch.testing._internal.common_utils import (
     TEST_WITH_ASAN,
@@ -249,27 +249,54 @@
 
     def test_constructor(self):
         """Check the robustness of the ZeroRedundancyOptimizer constructor by
-        passing different values for `params`"""
+        passing different values for the ``params`` argument."""
         self.dist_init(self.rank)
 
-        m = torch.nn.Linear(1, 1)
-        # (input, expected error)
-        inputs = [
+        m = torch.nn.Sequential(
+            torch.nn.Linear(5, 10),
+            torch.nn.Linear(10, 10),
+            torch.nn.Linear(10, 10),
+        )
+
+        # Test various constructor inputs in the form: (input, expected error)
+        ctor_inputs = [
             ([], ValueError),                           # empty parameter list
             (torch.randn(1), TypeError),                # non-iterable: `torch.Tensor`
             (1.2, TypeError),                           # non-iterable: `float`
-            ([{"params": m.parameters()}], TypeError),  # iterable of dict
-            (list(m.parameters()) + [42], TypeError),   # iterable containing non-`torch.Tensor`
+            ([
+                {"params": [l.weight for l in m]},
+                {"params": [l.bias for l in m]},
+            ], None),                                   # iterable of dict
+            (list(m.parameters()) + [42], TypeError),   # iterable containing invalid type
             (m.parameters(), None),                     # `params` as a generator
             (list(m.parameters()), None)                # `params` as a list
         ]
 
-        for input, error in inputs:
-            if (error):
+        for ctor_input, error in ctor_inputs:
+            if error:
                 with self.assertRaises(error):
-                    ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=0.1)
+                    ZeroRedundancyOptimizer(ctor_input, optimizer_class=SGD, lr=0.01)
             else:
-                ZeroRedundancyOptimizer(input, optimizer_class=SGD, lr=0.1)
+                ZeroRedundancyOptimizer(ctor_input, optimizer_class=SGD, lr=0.01)
+
+        # Test constructing with multiple parameter groups more thoroughly
+        weight_decay = 0.01
+        lr = 0.01
+        betas = (0.9, 0.999)
+        eps = 1e-8
+        params = [
+            {"params": [l.weight for l in m], "weight_decay": 0.},
+            {"params": [l.bias for l in m], "weight_decay": weight_decay},
+        ]
+        o = ZeroRedundancyOptimizer(
+            params, optimizer_class=AdamW,
+            lr=lr, betas=betas, eps=eps,
+        )
+        assert len(o.param_groups) == 2, \
+            f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
+        assert len(o.optim.param_groups) == 2, \
+            "Expected 2 local optimizer param groups, but got " \
+            f"{len(o.optim.param_groups)}"
 
     def test_same_dense_param_type(self):
         """Check that ZeroRedundancyOptimizer raises an exception if the input
@@ -459,7 +486,76 @@
         all_trainable()
         some_trainable()
 
+    @common_distributed.skip_if_no_gpu
+    def test_multiple_param_groups(self):
+        """
+        Tests parity between constructing ZeRO with multiple parameter groups
+        upfront versus adding parameter groups to ZeRO after construction
+        versus a non-sharded optimizer.
+        """
+        self.dist_init(self.rank)
+
+        model1 = torch.nn.Sequential(
+            torch.nn.Linear(5, 10),
+            torch.nn.Linear(10, 10),
+            torch.nn.Linear(10, 5),
+        )
+        model2 = copy.deepcopy(model1)
+        model3 = copy.deepcopy(model1)
+        model1 = model1.to(self.device)
+        model2 = model2.to(self.device)
+        model3 = model3.to(self.device)
+
+        batch_size = 8
+        num_iters = 3
+        inputs = [
+            torch.randn(batch_size, 5).to(self.device) for _ in range(num_iters)
+        ]
+        wd = 0.01
+        lr = 0.01
+        # Construct `optim1` with both parameter groups upfront
+        optim1 = ZeroRedundancyOptimizer(
+            [
+                {"params": [l.weight for l in model1], "weight_decay": 0.},
+                {"params": [l.bias for l in model1], "weight_decay": wd},
+            ],
+            optimizer_class=AdamW, lr=lr,
+        )
+        # Construct `optim2` by adding the second parameter after
+        optim2 = ZeroRedundancyOptimizer(
+            [l.weight for l in model2],
+            optimizer_class=AdamW, lr=lr, weight_decay=0.,
+        )
+        optim2.add_param_group(
+            {"params": [l.bias for l in model2], "weight_decay": wd}
+        )
+        # Construct `optim3` as a non-sharded optimizer
+        optim3 = AdamW(
+            [
+                {"params": [l.weight for l in model3], "weight_decay": 0.},
+                {"params": [l.bias for l in model3], "weight_decay": wd},
+            ], lr=lr,
+        )
+
+        # Check parity over a few iterations
+        for iter in range(num_iters):
+            for model, optim in (
+                (model1, optim1), (model2, optim2), (model3, optim3),
+            ):
+                optim.zero_grad()
+                out = model(inputs[iter])
+                loss = out.sum()
+                loss.backward()
+                optim.step()
+
+            for layer1, layer2, layer3 in zip(model1, model2, model3):
+                assert torch.allclose(layer1.weight, layer2.weight)
+                assert torch.allclose(layer1.weight, layer3.weight)
+                assert torch.allclose(layer1.bias, layer2.bias)
+                assert torch.allclose(layer1.bias, layer3.bias)
+
     @common_distributed.skip_if_lt_x_gpu(2)
+    @common_distributed.skip_if_rocm
     def test_collect_shards(self):
         """ Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer"""
         self.dist_init(self.rank)
diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py
index 70779ea..a87bfda 100644
--- a/torch/distributed/optim/zero_redundancy_optimizer.py
+++ b/torch/distributed/optim/zero_redundancy_optimizer.py
@@ -10,7 +10,16 @@
 import io
 import logging
 from itertools import chain
-from typing import Any, Callable, Dict, List, Optional, Set, Type
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Type,
+    Union,
+)
 
 import torch
 import torch.distributed as dist
@@ -287,7 +296,8 @@
 
     Arguments:
         params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
-            giving all parameters, which will be sharded across ranks.
+            or :class:`dict` s giving all parameters, which will be sharded
+            across ranks.
 
     Keyword Args:
         optimizer_class (:class:`torch.nn.Optimizer`): the class of the local
@@ -364,7 +374,7 @@
         **defaults: Any,
     ):
         # Perform type and assumption checks on the input parameters
-        self._verify_and_init_params(params)
+        params = self._verify_and_init_params(params)
         self._verify_same_dense_param_type()
 
         # NOTE: The parent constructor uses `add_param_group()` which is
@@ -373,7 +383,7 @@
         # between the parent and child.
         self.initialized = False
 
-        Optimizer.__init__(self, self._all_params, defaults)
+        Optimizer.__init__(self, params, defaults)
         Joinable.__init__(self)
         # Now, all parameters are held in both `self._all_params` and
         # `self.param_groups`
@@ -1289,36 +1299,60 @@
                     offset = offset_next
                 bucket_assignment.tensor = tensor
 
-    def _verify_and_init_params(self, params: Any) -> None:
+    def _verify_and_init_params(
+        self, params: Any,
+    ) -> Union[List[torch.Tensor], List[dict]]:
         r"""
         Verifies the type of ``params`` and initializes ``self._all_params``
-        if ``params`` is valid.
+        as a :class:`list` of all parameters if ``params`` is valid.
 
-        While :class:`optim.Optimizer <torch.optim.Optimizer>` allows
-        ``params`` to be an iterable of :class:`dict` s, currently
-        ``ZeroRedundancyOptimizer`` strictly requires ``params`` to be an
-        iterable of :class:`torch.Tensor` s.
+        Arguments:
+            params (Any): Candidate parameter list or parameter groups to
+                verify.
 
         Raises:
             TypeError: ``params`` has an invalid type.
             ValueError: ``params`` is empty.
+
+        Returns:
+            The persistent form of ``params`` to be passed into the parent
+            :class:`Optimizer` constructor -- i.e. returns ``params`` as a
+            :class:`list` to ensure that it can be iterated over again.
         """
         if isinstance(params, torch.Tensor):
-            raise TypeError("params argument should be an iterable of "
+            raise TypeError("`params` argument should be an iterable of "
                             f"Tensors, but got {torch.typename(params)}")
         try:
-            self._all_params = list(params)
+            all_params = list(params)
         except TypeError:
-            raise TypeError("params argument should be an iterable of "
+            raise TypeError("`params` argument should be an iterable of "
                             f"Tensors, but got {torch.typename(params)}")
-        if len(self._all_params) == 0:
+        if len(all_params) == 0:
             raise ValueError("ZeroRedundancyOptimizer got an empty parameter "
                              "list")
-        for param in self._all_params:
-            if not isinstance(param, torch.Tensor):
-                raise TypeError("params argument should be an iterable of "
-                                "Tensors, but got an iterable containing "
-                                f"{torch.typename(param)}")
+        all_tensors = True
+        all_dicts = True
+        for param in all_params:
+            all_tensors &= isinstance(param, torch.Tensor)
+            all_dicts &= isinstance(param, dict)
+        if not all_tensors and not all_dicts:
+            raise TypeError("`params` argument should be an iterable of "
+                            "Tensors or dicts")
+        # Ensure that `self._all_params` contains a list of all parameters
+        if all_tensors:
+            self._all_params = all_params
+        elif all_dicts:
+            self._all_params = []
+            # `all_params` contains parameter groups (not parameters)
+            for param_group in all_params:
+                if "params" not in param_group:
+                    raise ValueError(
+                        "Each parameter group passed-in via `params` must "
+                        "have a 'params' key mapping to the parameters in "
+                        "the group"
+                    )
+                self._all_params.extend(param_group["params"])
+        return all_params
 
     def _verify_same_dense_param_type(self) -> None:
         r"""