parameterized test_graph_optims and test_graph_scaling_fused_optimizers (#133749)
Fixes #123451
This is a rework of a reverted pull request, https://github.com/pytorch/pytorch/pull/125127.
The test failure that caused the revert has been fixed.
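
For reference, the sketch below (not part of this change) illustrates the OptimizerInfo-based parameterization the PR switches to: filter optim_db by the new has_capturable_arg flag, expand configurations with _get_optim_inputs_including_global_cliquey_kwargs, and instantiate the test per device. The class and test names are made up, and it relies on the private torch.testing._internal helpers, so treat it as an illustrative sketch only.

```python
# Hypothetical sketch of the parameterization pattern; assumes a PyTorch source
# checkout where the private torch.testing._internal helpers are importable.
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
)
from torch.testing._internal.common_optimizers import (
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
    optims,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class DemoCudaOptims(TestCase):
    # Restrict the parameterization to optimizers whose constructor accepts
    # `capturable`, mirroring the new test_graph_optims below.
    @onlyCUDA
    @optims(
        [optim for optim in optim_db if optim.has_capturable_arg],
        dtypes=[torch.float32],
    )
    def test_capturable_configs_construct(self, device, dtype, optim_info):
        # Expand every supported kwargs combination; `differentiable` is skipped
        # because it cannot be combined with CUDA graph capture.
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        self.assertGreater(len(all_optim_inputs), 0)
        for optim_input in all_optim_inputs:
            kwargs = dict(optim_input.kwargs)
            # Use a plain float lr, as the PR does, to sidestep tensor-lr
            # restrictions for some foreach/capturable combinations.
            kwargs["lr"] = 0.1
            params = [torch.randn(3, 3, device=device, requires_grad=True)]
            opt = optim_info.optim_cls(params, **kwargs)
            # Optimizers selected via has_capturable_arg expose `capturable`.
            self.assertIn("capturable", opt.defaults)


instantiate_device_type_tests(DemoCudaOptims, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()
```

The real tests below follow the same shape, but additionally run warmup steps, capture opt.step() (or scaler.step(opt)) in a torch.cuda.CUDAGraph, replay it, and compare the results against a capturable=False control.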
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133749
Approved by: https://github.com/janeyx99
diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py
index a1a38a9..0fcba0d 100644
--- a/test/inductor/test_compiled_optimizers.py
+++ b/test/inductor/test_compiled_optimizers.py
@@ -146,6 +146,8 @@
"test_sgd_weight_decay_maximize_cuda": 4,
"test_sgd_weight_decay_maximize_xpu": 4,
"test_sgd_weight_decay_maximize_cpu": 4,
+ "test_sgd_weight_decay_cpu": 4,
+ "test_sgd_weight_decay_cuda": 4,
"test_sgd_momentum_weight_decay_foreach_cuda": 2,
"test_sgd_momentum_weight_decay_foreach_xpu": 2,
"test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
diff --git a/test/test_cuda.py b/test/test_cuda.py
index e5e25a6..bb52e8c 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -40,7 +40,12 @@
onlyCUDA,
onlyNativeDeviceTypes,
)
-from torch.testing._internal.common_optimizers import optim_db, optims, TensorTracker
+from torch.testing._internal.common_optimizers import (
+ _get_optim_inputs_including_global_cliquey_kwargs,
+ optim_db,
+ optims,
+ TensorTracker,
+)
from torch.testing._internal.common_utils import (
EXPANDABLE_SEGMENTS,
freeze_rng_state,
@@ -3457,93 +3462,6 @@
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
- def test_graph_optims(self):
- # Needs generalization if we want to extend this test to non-Adam-like optimizers.
- cases = (
- [
- (
- optimizer_ctor,
- {
- "lr": 0.1,
- "betas": (0.8, 0.7),
- "foreach": foreach,
- "decoupled_weight_decay": decoupled_weight_decay,
- "weight_decay": weight_decay,
- },
- )
- for optimizer_ctor, foreach, decoupled_weight_decay, weight_decay in product(
- (torch.optim.NAdam, torch.optim.RAdam),
- (False, True),
- (False, True),
- (0.0, 0.1),
- )
- ]
- + [
- (
- torch.optim.Rprop,
- {"lr": 0.1, "foreach": foreach, "maximize": maximize},
- )
- for foreach, maximize in product(
- (False, True),
- (False, True),
- )
- ]
- + [
- (
- optimizer_ctor,
- {
- "lr": 0.1,
- "betas": (0.8, 0.7),
- "foreach": foreach,
- "amsgrad": amsgrad,
- },
- )
- for optimizer_ctor, foreach, amsgrad in product(
- (torch.optim.Adam, torch.optim.AdamW),
- (False, True),
- (False, True),
- )
- ]
- + [
- (
- optimizer_ctor,
- {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad},
- )
- for optimizer_ctor, amsgrad in product(
- (torch.optim.Adam, torch.optim.AdamW), (False, True)
- )
- ]
- + [
- (
- optimizer_ctor,
- {
- "lr": 0.1,
- "foreach": foreach,
- "maximize": maximize,
- "weight_decay": weight_decay,
- },
- )
- for optimizer_ctor, foreach, maximize, weight_decay in product(
- (
- torch.optim.Adamax,
- torch.optim.ASGD,
- torch.optim.Adadelta,
- torch.optim.RMSprop,
- ),
- (False, True),
- (False, True),
- (0, 0.1),
- )
- ]
- )
-
- for optimizer_ctor, kwargs in cases:
- with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs):
- self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs)
-
- @unittest.skipIf(
- not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
- )
def test_graph_optims_with_explicitly_capturable_param_groups(self):
# mimicking `_test_graphed_optimizer` maladroitly to pass two param_groups to optimizer.__init__
n_warmup, n_replay = 3, 2
@@ -3615,125 +3533,6 @@
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
- def test_graph_scaling_fused_optimizers(self):
- cases = [
- (
- optimizer_ctor,
- {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad},
- )
- for optimizer_ctor, amsgrad in product(
- (torch.optim.Adam, torch.optim.AdamW), (False, True)
- )
- ] + list(
- product(
- (torch.optim.SGD,),
- [
- {
- "lr": 0.1,
- "momentum": 0.0,
- "dampening": d,
- "weight_decay": w,
- "nesterov": n,
- "fused": True,
- }
- for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
- ]
- + [
- {
- "lr": 0.1,
- "momentum": 0.5,
- "dampening": d,
- "weight_decay": w,
- "nesterov": n,
- "fused": True,
- }
- for d, w, n in product((0.0,), (0.0, 0.5), (True, False))
- ],
- )
- )
-
- steps_warmup = 3
- steps_train = 2
-
- for OptClass, kwargs in cases:
- has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW)
- for actually_do_graphs in (True, False) if has_capturable_arg else (True,):
- params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)]
- params_control = [p.clone().requires_grad_() for p in params]
- params_graphed = [p.clone().requires_grad_() for p in params]
-
- # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients.
- grads = [
- [torch.randn_like(p) for p in params]
- for _ in range(steps_warmup + steps_train)
- ]
- with torch.no_grad():
- grads_control = [[g.clone() for g in gs] for gs in grads]
- grads_graphed = [[g.clone() for g in gs] for gs in grads]
-
- # Gradient Scaler
- scaler_for_control = torch.amp.GradScaler(
- device="cuda", init_scale=128.0
- )
- with torch.no_grad():
- scaler_for_control._lazy_init_scale_growth_tracker(
- torch.device("cuda")
- )
-
- scaler_for_graphed = torch.amp.GradScaler(device="cuda")
- scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
- with torch.no_grad():
- scaler_for_graphed._lazy_init_scale_growth_tracker(
- torch.device("cuda")
- )
-
- # Control (capturable=False)
- if has_capturable_arg:
- kwargs["capturable"] = False
- opt = OptClass(params_control, **kwargs)
-
- for i in range(steps_warmup + steps_train):
- for j, p in enumerate(params_control):
- p.grad = grads_control[i][j]
- scaler_for_control.step(opt)
- scaler_for_control.update()
-
- # capturable=True
- if has_capturable_arg:
- kwargs["capturable"] = True
- opt = OptClass(params_graphed, **kwargs)
-
- for i in range(steps_warmup):
- for j, p in enumerate(params_graphed):
- p.grad = grads_graphed[i][j]
- scaler_for_graphed.step(opt)
- scaler_for_graphed.update()
-
- if actually_do_graphs:
- g = torch.cuda.CUDAGraph()
- with torch.cuda.graph(g):
- scaler_for_graphed.step(opt)
- scaler_for_graphed.update()
-
- for i in range(steps_train):
- if actually_do_graphs:
- for j, p in enumerate(params_graphed):
- p.grad.copy_(grads_graphed[i + steps_warmup][j])
- g.replay()
- else:
- # Passing capturable=True to the constructor and running without graphs should still be
- # numerically correct, even if it's not ideal for performance.
- for j, p in enumerate(params_graphed):
- p.grad = grads_graphed[i + steps_warmup][j]
- scaler_for_graphed.step(opt)
- scaler_for_graphed.update()
-
- for p_control, p_graphed in zip(params_control, params_graphed):
- self.assertEqual(p_control, p_graphed)
-
- @unittest.skipIf(
- not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
- )
def test_cuda_graph_error_options(self):
def fn():
x = torch.zeros([2000], device="cuda")
@@ -5082,10 +4881,179 @@
self.assertEqual(len(set(active_pool_ids)), 4)
+@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaOptims(TestCase):
# These tests will be instantiate with instantiate_device_type_tests
# to apply the new OptimizerInfo structure.
+ @onlyCUDA
+ @unittest.skipIf(
+        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
+ )
+ @optims(
+ [optim for optim in optim_db if optim.has_capturable_arg],
+ dtypes=[torch.float32],
+ )
+ def test_graph_optims(self, device, dtype, optim_info):
+ optim_cls = optim_info.optim_cls
+ all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
+ device, dtype, optim_info, skip=("differentiable",)
+ )
+
+ steps_warmup = 3
+ steps_train = 2
+
+ for optim_input in all_optim_inputs:
+ kwargs = optim_input.kwargs
+
+            # lr as a Tensor is not supported when capturable=False and foreach=True
+            # for torch.optim.adam and torch.optim.adamw, so use a plain float here.
+            kwargs["lr"] = 0.1
+
+ for actually_do_graphs in (True, False):
+ params = [
+ torch.randn((i + 5, i + 5), device=device) for i in range(2)
+ ] + [torch.randn((), device=device)]
+ params_control = [p.clone().requires_grad_() for p in params]
+ params_graphed = [p.clone().requires_grad_() for p in params]
+
+ grads = [
+ [torch.randn_like(p) for p in params]
+ for _ in range(steps_warmup + steps_train)
+ ]
+
+ # Control (capturable=False)
+ kwargs["capturable"] = False
+
+ opt = optim_cls(params_control, **kwargs)
+ for i in range(steps_warmup + steps_train):
+ for j, p in enumerate(params_control):
+ p.grad = grads[i][j]
+ opt.step()
+
+ # capturable=True
+ kwargs["capturable"] = True
+ opt = optim_cls(params_graphed, **kwargs)
+
+ for i in range(steps_warmup):
+ for j, p in enumerate(params_graphed):
+ p.grad = grads[i][j]
+ opt.step()
+
+ if actually_do_graphs:
+ g = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(g):
+ opt.step()
+
+ for i in range(steps_train):
+ if actually_do_graphs:
+ for j, p in enumerate(params_graphed):
+ p.grad.copy_(grads[i + steps_warmup][j])
+ g.replay()
+ else:
+ # Passing capturable=True to the constructor and running without graphs should still be
+ # numerically correct, even if it's not ideal for performance.
+ for j, p in enumerate(params_graphed):
+ p.grad = grads[i + steps_warmup][j]
+ opt.step()
+
+ for p_control, p_graphed in zip(params_control, params_graphed):
+ self.assertEqual(p_control, p_graphed)
+
+ @onlyCUDA
+ @unittest.skipIf(
+ not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
+ )
+ @optims(
+ [
+ optim
+ for optim in optim_db
+ if "fused" in optim.supported_impls and "cuda" in optim.supports_fused_on
+ ],
+ dtypes=[torch.float32],
+ )
+ def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
+ optim_cls = optim_info.optim_cls
+
+ steps_warmup = 3
+ steps_train = 2
+
+ optim_inputs = optim_info.optim_inputs_func(device=device)
+
+ for optim_input in optim_inputs:
+ kwargs = optim_input.kwargs
+ kwargs["fused"] = True
+
+ for actually_do_graphs in (
+ (True, False) if optim_info.has_capturable_arg else (True,)
+ ):
+ params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)]
+ params_control = [p.clone().requires_grad_() for p in params]
+ params_graphed = [p.clone().requires_grad_() for p in params]
+
+                # `GradScaler` updates gradients in place, so the control and graphed runs need separate copies.
+ grads = [
+ [torch.randn_like(p) for p in params]
+ for _ in range(steps_warmup + steps_train)
+ ]
+ with torch.no_grad():
+ grads_control = [[g.clone() for g in gs] for gs in grads]
+ grads_graphed = [[g.clone() for g in gs] for gs in grads]
+
+ # Gradient Scaler
+ scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
+ with torch.no_grad():
+ scaler_for_control._lazy_init_scale_growth_tracker(device)
+
+ scaler_for_graphed = torch.cuda.amp.GradScaler()
+ scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
+ with torch.no_grad():
+ scaler_for_graphed._lazy_init_scale_growth_tracker(device)
+
+ # Control (capturable=False)
+ if optim_info.has_capturable_arg:
+ kwargs["capturable"] = False
+ opt = optim_cls(params_control, **kwargs)
+
+ for i in range(steps_warmup + steps_train):
+ for j, p in enumerate(params_control):
+ p.grad = grads_control[i][j]
+ scaler_for_control.step(opt)
+ scaler_for_control.update()
+
+ # capturable=True
+ if optim_info.has_capturable_arg:
+ kwargs["capturable"] = True
+ opt = optim_cls(params_graphed, **kwargs)
+
+ for i in range(steps_warmup):
+ for j, p in enumerate(params_graphed):
+ p.grad = grads_graphed[i][j]
+ scaler_for_graphed.step(opt)
+ scaler_for_graphed.update()
+
+ if actually_do_graphs:
+ g = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(g):
+ scaler_for_graphed.step(opt)
+ scaler_for_graphed.update()
+
+ for i in range(steps_train):
+ if actually_do_graphs:
+ for j, p in enumerate(params_graphed):
+ p.grad.copy_(grads_graphed[i + steps_warmup][j])
+ g.replay()
+ else:
+ # Passing capturable=True to the constructor and running without graphs should still be
+ # numerically correct, even if it's not ideal for performance.
+ for j, p in enumerate(params_graphed):
+ p.grad = grads_graphed[i + steps_warmup][j]
+ scaler_for_graphed.step(opt)
+ scaler_for_graphed.update()
+
+ for p_control, p_graphed in zip(params_control, params_graphed):
+ self.assertEqual(p_control, p_graphed)
+
@onlyNativeDeviceTypes
@optims(
[optim for optim in optim_db if "fused" in optim.supported_impls],
diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py
index 7d4d225..bd8f1f2 100644
--- a/torch/testing/_internal/common_optimizers.py
+++ b/torch/testing/_internal/common_optimizers.py
@@ -132,6 +132,8 @@
),
# the optim supports passing in sparse gradients as well as dense grads
supports_sparse: bool = False,
+ # the optimizer constructor supports passing in capturable as a kwarg
+ has_capturable_arg: bool = False,
# the optim only supports one config: sparse grads w/ dense params, see SparseAdam
only_supports_sparse_grads: bool = False,
# Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
@@ -157,6 +159,7 @@
self.supported_impls = supported_impls
self.not_og_supported_flags = not_og_supported_flags
self.supports_sparse = supports_sparse
+ self.has_capturable_arg = has_capturable_arg
self.metadata_for_sparse = metadata_for_sparse
self.only_supports_sparse_grads = only_supports_sparse_grads
self.supports_complex = supports_complex
@@ -330,10 +333,11 @@
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
+ OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
- desc="maximize",
+ desc="maximize, weight_decay",
),
OptimizerInput(
params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
@@ -631,9 +635,14 @@
),
OptimizerInput(
params=None,
- kwargs={"weight_decay": 0.1, "maximize": True},
+ kwargs={"maximize": True},
desc="maximize",
),
+ OptimizerInput(
+ params=None,
+ kwargs={"weight_decay": 0.1, "maximize": True},
+ desc="maximize, weight_decay",
+ ),
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
@@ -788,11 +797,18 @@
),
OptimizerInput(
params=None,
- kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
+ kwargs={
+ "weight_decay": 0.1,
+ },
desc="weight_decay",
),
OptimizerInput(
params=None,
+ kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
+ desc="weight_decay, momentum_decay",
+ ),
+ OptimizerInput(
+ params=None,
kwargs={
"weight_decay": 0.1,
"momentum_decay": 6e-3,
@@ -935,11 +951,26 @@
),
OptimizerInput(
params=None,
+ kwargs={
+ "maximize": True,
+ },
+ desc="maximize",
+ ),
+ OptimizerInput(
+ params=None,
kwargs={"weight_decay": 0.1, "centered": True},
desc="centered",
),
OptimizerInput(
params=None,
+ kwargs={
+ "maximize": True,
+ "weight_decay": 0.1,
+ },
+ desc="maximize, weight_decay",
+ ),
+ OptimizerInput(
+ params=None,
kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
desc="momentum",
),
@@ -951,7 +982,7 @@
"momentum": 0.1,
"maximize": True,
},
- desc="maximize",
+ desc="maximize, centered, weight_decay, w/ momentum",
),
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
@@ -1022,27 +1053,30 @@
OptimizerInput(
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
),
+ OptimizerInput(
+ params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
+ ),
OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
OptimizerInput(
params=None,
+ kwargs={"weight_decay": 0.1, "maximize": True},
+ desc="maximize",
+ ),
+ OptimizerInput(
+ params=None,
kwargs={"momentum": 0.9, "dampening": 0.5},
desc="dampening",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "weight_decay": 0.1},
- desc="non-zero weight_decay",
+ desc="weight_decay w/ momentum",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
desc="nesterov",
),
- OptimizerInput(
- params=None,
- kwargs={"weight_decay": 0.1, "maximize": True},
- desc="maximize",
- ),
]
@@ -1208,6 +1242,7 @@
optim_inputs_func=optim_inputs_func_adadelta,
optim_error_inputs_func=optim_error_inputs_func_adadelta,
supported_impls=("foreach", "differentiable"),
+ has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1493,6 +1528,7 @@
),
optim_error_inputs_func=optim_error_inputs_func_adam,
supported_impls=("foreach", "differentiable", "fused"),
+ has_capturable_arg=True,
not_og_supported_flags=(
"foreach",
"differentiable",
@@ -1578,6 +1614,7 @@
optim_inputs_func=optim_inputs_func_adamax,
optim_error_inputs_func=optim_error_inputs_func_adamax,
supported_impls=("foreach", "differentiable"),
+ has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@@ -1630,6 +1667,7 @@
"capturable",
),
supports_fused_on=("cpu", "cuda", "mps"),
+ has_capturable_arg=True,
decorators=(
# Expected error between compiled forloop and fused optimizers
DecorateInfo(
@@ -1710,6 +1748,7 @@
optim_inputs_func=optim_inputs_func_asgd,
optim_error_inputs_func=optim_error_inputs_func_asgd,
supported_impls=("foreach", "differentiable"),
+ has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1822,6 +1861,7 @@
optim_inputs_func=optim_inputs_func_nadam,
optim_error_inputs_func=optim_error_inputs_func_nadam,
supported_impls=("foreach", "differentiable"),
+ has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@@ -1870,6 +1910,7 @@
optim_inputs_func=optim_inputs_func_radam,
optim_error_inputs_func=optim_error_inputs_func_radam,
supported_impls=("foreach", "differentiable"),
+ has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1915,6 +1956,7 @@
optim_inputs_func=optim_inputs_func_rmsprop,
optim_error_inputs_func=optim_error_inputs_func_rmsprop,
supported_impls=("foreach", "differentiable"),
+ has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@@ -1964,6 +2006,7 @@
optim_inputs_func=optim_inputs_func_rprop,
optim_error_inputs_func=optim_error_inputs_func_rprop,
supported_impls=("foreach", "differentiable"),
+ has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # Rprop doesn't update for non-contiguous, see #118117