[optim] Fix: wrong ASGD implementation (#126375)
This PR is based on #125440; it additionally merges the latest main branch and fixes the lint failures from #126361.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126375
Approved by: https://github.com/janeyx99
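
For context, the incorrect piece was the eta schedule in the foreach path: the old code raised only `step` to the power `alpha`, whereas the reference ASGD recipe raises the whole `(1 + lambd * lr * step)` term. A minimal sketch of the difference, using illustrative hyperparameter values that are not taken from the patch:

lr, lambd, alpha, step = 1e-2, 1e-4, 0.75, 1000

eta_old = lr / (1 + lambd * lr * step ** alpha)    # exponent applied to step only (previous foreach behavior)
eta_new = lr / ((1 + lambd * lr * step) ** alpha)  # exponent applied to the whole term (fixed behavior)

print(eta_old, eta_new)  # the two schedules increasingly diverge as step grows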
diff --git a/test/test_optim.py b/test/test_optim.py
index 717e892..7fa612e 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -604,8 +604,16 @@
for input, model, optimizer in zip(inputs, models, optimizers):
optimizer.zero_grad()
+ if i == 3:
+ # Freeze a layer to test whether this layer's step under 'fused' or 'foreach'
+ # is the same as its step under 'forloop'.
+ model[2].requires_grad_(False)
+ if i == 5:
+ # Unfreeze the layer after two iterations.
+ model[2].requires_grad_(True)
+
# Test that step behaves as expected (a no-op) when grads are set to None
- if i != 3:
+ if i != 2:
output = model(input)
loss = output.sum()
loss.backward()
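
The test change above freezes `model[2]` at iteration 3 and unfreezes it at iteration 5, so the 'fused'/'foreach' paths are compared against 'forloop' while some parameters temporarily receive no gradients. A standalone sketch of that freeze/unfreeze pattern, with a hypothetical model and optimizer that are not part of the test file:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
opt = torch.optim.ASGD(model.parameters(), lr=0.01, foreach=True)

for i in range(7):
    opt.zero_grad()
    if i == 3:
        model[2].requires_grad_(False)  # frozen layer: its grads stay None from here on
    if i == 5:
        model[2].requires_grad_(True)   # gradients flow into the layer again
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    opt.step()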
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index a4edd83..93e45bf 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -19,6 +19,7 @@
corresponding_real_dtype,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
+ FloatLike,
IntLike,
make_contiguous_strides_for,
Number,
@@ -3286,6 +3287,15 @@
return
+@register_meta([aten._foreach_pow_.Scalar])
+def meta__foreach_pow__scalar(self, exponent):
+ torch._check(
+ isinstance(exponent, FloatLike),
+ lambda: f"exponent must be a float but got {type(exponent)}",
+ )
+ return
+
+
@register_meta([aten._foreach_pow.ScalarAndTensor])
def meta__foreach_pow_scalar_and_tensor(self, exponent):
# Only foreach_pow has a ScalarAndTensor method and needs special
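
The new meta function above covers the in-place `aten._foreach_pow_.Scalar` overload, which the rewritten eta computation in asgd.py now calls; my reading is that the registration is added so the op can also be traced with meta/fake tensors. Purely as an illustration of what the overload does at runtime:

import torch

xs = [torch.full((3,), 4.0), torch.full((3,), 9.0)]
torch._foreach_pow_(xs, 0.5)  # each tensor is raised to the scalar power in place
print(xs)                     # [tensor([2., 2., 2.]), tensor([3., 3., 3.])]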
diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py
index a87aadc..f53f8b4 100644
--- a/torch/optim/asgd.py
+++ b/torch/optim/asgd.py
@@ -22,13 +22,6 @@
__all__ = ["ASGD", "asgd"]
-def _to_tensor(x, device=None):
- if not isinstance(x, torch.Tensor):
- return torch.tensor(x, device=device)
-
- return x
-
-
class ASGD(Optimizer):
def __init__(
self,
@@ -264,9 +257,9 @@
mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
else:
step = _get_value(step_t)
- new_eta = _to_tensor(lr / ((1 + lambd * lr * step) ** alpha))
+ new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
eta.copy_(new_eta)
- new_mu = _to_tensor(1 / max(1, step - t0))
+ new_mu = torch.as_tensor(1 / max(1, step - t0))
mu.copy_(new_mu)
@@ -381,27 +374,23 @@
torch._foreach_copy_(grouped_mus, new_mus)
del new_mus
- # update eta = lr / (1 + lambd * lr * step^alpha)
- new_etas = torch._foreach_pow(grouped_state_steps, alpha)
- torch._foreach_mul_(new_etas, lambd)
+ # update eta = lr / ((1 + lambd * lr * step)^alpha)
+ new_etas = torch._foreach_mul(grouped_state_steps, lambd)
torch._foreach_mul_(new_etas, lr)
torch._foreach_add_(new_etas, 1)
+ torch._foreach_pow_(new_etas, alpha)
torch._foreach_reciprocal_(new_etas)
torch._foreach_mul_(new_etas, lr)
torch._foreach_copy_(grouped_etas, new_etas)
else:
- step = grouped_state_steps[0].item()
- new_etas = []
- new_mus = []
-
- for i in range(len(grouped_mus)):
- new_eta = _to_tensor(
- lr / (1 + lambd * lr * step**alpha), device=device
- )
- new_etas.append(new_eta)
- new_mu = _to_tensor(1 / max(1, step - t0), device=device)
- new_mus.append(new_mu)
-
+ new_etas = [
+ torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
+ for step in grouped_state_steps
+ ]
+ new_mus = [
+ torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device)
+ for step in grouped_state_steps
+ ]
torch._foreach_copy_(grouped_etas, new_etas)
torch._foreach_copy_(grouped_mus, new_mus)
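
The corrected foreach sequence above builds eta in place: multiply step by lambd and lr, add 1, raise the whole term to alpha, take the reciprocal, and scale by lr. A small self-check of that sequence against the closed-form expression, with made-up hyperparameters and step counts:

import torch

lr, lambd, alpha = 1e-2, 1e-4, 0.75
steps = [torch.tensor(10.0), torch.tensor(250.0)]

etas = torch._foreach_mul(steps, lambd)   # lambd * step
torch._foreach_mul_(etas, lr)             # lambd * lr * step
torch._foreach_add_(etas, 1)              # 1 + lambd * lr * step
torch._foreach_pow_(etas, alpha)          # (1 + lambd * lr * step) ** alpha
torch._foreach_reciprocal_(etas)          # 1 / (1 + lambd * lr * step) ** alpha
torch._foreach_mul_(etas, lr)             # lr / ((1 + lambd * lr * step) ** alpha)

for eta, step in zip(etas, steps):
    expected = lr / ((1 + lambd * lr * step.item()) ** alpha)
    assert torch.allclose(eta, torch.tensor(expected))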
diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py
index 5a66923..c81efb0 100644
--- a/torch/testing/_internal/common_optimizers.py
+++ b/torch/testing/_internal/common_optimizers.py
@@ -590,6 +590,7 @@
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
+ OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
@@ -1450,6 +1451,13 @@
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
+ DecorateInfo(
+ unittest.skip(
+ "ASGD internally changes the weights even with zero grad"
+ ),
+ "TestOptimRenewed",
+ "test_step_is_noop_for_zero_grads",
+ ),
),
),
OptimizerInfo(
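
The skip added for `test_step_is_noop_for_zero_grads` reflects that ASGD's update contains a multiplicative decay of the parameters, so even an all-zero gradient moves the weights. A small reproduction under that assumption (parameter and hyperparameter values chosen only for illustration):

import torch

p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.ASGD([p], lr=0.1, lambd=1e-2)
p.grad = torch.zeros_like(p)

before = p.detach().clone()
opt.step()
print(torch.equal(before, p.detach()))  # expected: False, the weights are decayed despite the zero grad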