Implement "RAdamW" optimizer (#107507)

Fixes #107282

## Overview

- The basic design decisions follow those made in #103881 (tensor operations, test cases, order and position of arguments, etc.).
- For the decoupled weight decay algorithm, I referred to [1, 2]; see the sketch below.
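
The two modes differ only in where the decay term enters the update. A minimal sketch of that branch (mirroring the single-tensor path in the diff below; `param`, `grad`, `lr`, `weight_decay`, and `decoupled_weight_decay` are assumed to be in scope):

```python
if weight_decay != 0:
    if decoupled_weight_decay:
        # AdamW-style: shrink the weights directly; the decay term never enters the gradient.
        param.mul_(1 - lr * weight_decay)
    else:
        # Adam-style L2 penalty: fold the decay term into the gradient.
        grad = grad.add(param, alpha=weight_decay)
```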

## Backwards-incompatible changes

- The positional argument `decoupled_weight_decay` is added to:
    - `torch.optim.radam`

Existing code that calls this API positionally may be affected.

Note: The argument `decoupled_weight_decay` is also added to `torch.optim.RAdam`. However, since it was added in the last position and with a default value, existing calls are not affected. A typical construction of the RAdamW variant is shown below.
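
A minimal usage example (the hyperparameter values are illustrative, matching the new test cases):

```python
import torch

model = torch.nn.Linear(4, 2)

# RAdam with decoupled (AdamW-style) weight decay, i.e. "RAdamW".
optimizer = torch.optim.RAdam(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.1,
    decoupled_weight_decay=True,
)
```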

## References

- [1] [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)
- [2] [RAdam author's implementation](https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam/radam.py#L5-L94)

## TODO

- [x] implement tensor operations
- [x] implement test cases
- [x] modify docstring
- [x] pass unit tests locally: `python test/test_optim.py -k test_radam`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107507
Approved by: https://github.com/janeyx99
diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py
index 812d52b..1677869 100644
--- a/test/optim/test_optim.py
+++ b/test/optim/test_optim.py
@@ -884,6 +884,8 @@
             (optim.RAdam, dict(weight_decay=0)),
             (optim.RAdam, dict(weight_decay=1, eps=1e-6)),
             (optim.RAdam, dict(weight_decay=1)),
+            (optim.RAdam, dict(weight_decay=0, decoupled_weight_decay=True)),
+            (optim.RAdam, dict(weight_decay=1, decoupled_weight_decay=True)),
             (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=True)),
             (optim.RMSprop, dict(weight_decay=1, momentum=0, centered=True)),
             (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=False)),
@@ -1500,6 +1502,23 @@
             ],
             constructor_accepts_foreach=True,
         )
+        # RAdamW tests
+        self._test_basic_cases(
+            lambda weight, bias, foreach: optim.RAdam(
+                [weight, bias], lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True, foreach=foreach
+            ),
+            constructor_accepts_foreach=True,
+        )
+        self._test_basic_cases(
+            lambda weight, bias, foreach: optim.RAdam(
+                [weight, bias], lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True, foreach=foreach
+            ),
+            [
+                lambda opt: ExponentialLR(opt, gamma=0.9),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+            constructor_accepts_foreach=True,
+        )
         with self.assertRaisesRegex(
             ValueError, "Invalid beta parameter at index 0: 1.0"
         ):
@@ -2359,6 +2378,17 @@
                 *state.values(),
             ),
         )
+        gradcheck(
+            _diff_fn,
+            (
+                p,
+                grad,
+                state,
+                torch.optim.RAdam,
+                {"lr": 0.9, "weight_decay": 0.1, "decoupled_weight_decay": True, "differentiable": True},
+                *state.values(),
+            ),
+        )
 
 
     @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
diff --git a/torch/optim/radam.py b/torch/optim/radam.py
index 4e57426..9753ab3 100644
--- a/torch/optim/radam.py
+++ b/torch/optim/radam.py
@@ -1,10 +1,19 @@
 import math
+from typing import List, Optional
+
 import torch
 from torch import Tensor
 
-from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
-                        _default_to_fused_or_foreach, _differentiable_doc, _foreach_doc)
-from typing import List, Optional
+from .optimizer import (
+    Optimizer,
+    _default_to_fused_or_foreach,
+    _differentiable_doc,
+    _dispatch_sqrt,
+    _foreach_doc,
+    _get_value,
+    _stack_if_compiling,
+    _use_grad_for_differentiable,
+)
 
 __all__ = ["RAdam", "radam"]
 
@@ -17,6 +26,7 @@
         betas=(0.9, 0.999),
         eps=1e-8,
         weight_decay=0,
+        decoupled_weight_decay: bool = False,
         *,
         foreach: Optional[bool] = None,
         differentiable: bool = False,
@@ -37,6 +47,7 @@
             eps=eps,
             weight_decay=weight_decay,
             foreach=foreach,
+            decoupled_weight_decay=decoupled_weight_decay,
             differentiable=differentiable,
         )
         super().__init__(params, defaults)
@@ -46,6 +57,7 @@
         for group in self.param_groups:
             group.setdefault("foreach", None)
             group.setdefault("differentiable", False)
+            group.setdefault("decoupled_weight_decay", False)
         state_values = list(self.state.values())
         step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
             state_values[0]["step"]
@@ -115,6 +127,7 @@
                 eps=group["eps"],
                 foreach=group["foreach"],
                 differentiable=group["differentiable"],
+                decoupled_weight_decay=group["decoupled_weight_decay"],
             )
 
         return loss
@@ -128,15 +141,19 @@
             &\textbf{input}      : \gamma \text{ (lr)}, \: \beta_1, \beta_2
                 \text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \:
                 \lambda \text{ (weightdecay)},                                                   \\
-            &\hspace{13mm} \epsilon \text{ (epsilon)}                                            \\
+            &\hspace{13mm} \epsilon \text{ (epsilon)}, \textit{decoupled\_weight\_decay}         \\
             &\textbf{initialize} :  m_0 \leftarrow 0 \text{ ( first moment)},
                 v_0 \leftarrow 0 \text{ ( second moment)},                                       \\
             &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1                      \\[-1.ex]
             &\rule{110mm}{0.4pt}  \\
             &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
-            &\hspace{6mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
-            &\hspace{5mm} \textbf{if} \: \lambda \neq 0                                          \\
-            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
+            &\hspace{6mm} g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1})                      \\
+            &\hspace{6mm} \theta_t \leftarrow \theta_{t-1}                                       \\
+            &\hspace{6mm} \textbf{if} \: \lambda \neq 0                                          \\
+            &\hspace{12mm}\textbf{if} \: \textit{decoupled\_weight\_decay}                       \\
+            &\hspace{18mm} \theta_t \leftarrow \theta_{t} - \gamma \lambda \theta_{t}            \\
+            &\hspace{12mm}\textbf{else}                                                          \\
+            &\hspace{18mm} g_t \leftarrow g_t + \lambda \theta_{t}                               \\
             &\hspace{6mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
             &\hspace{6mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
             &\hspace{6mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
@@ -146,9 +163,9 @@
             &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon  } \\
             &\hspace{12mm} r_t \leftarrow
       \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
-            &\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} r_t l_t        \\
+            &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} r_t l_t        \\
             &\hspace{6mm}\textbf{else}                                                           \\
-            &\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}                \\
+            &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}                \\
             &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
             &\bf{return} \:  \theta_t                                                     \\[-1.ex]
             &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
@@ -156,9 +173,13 @@
 
     For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_.
 
-    This implementation uses the same weight_decay implementation as Adam (were the weight_decay is applied
-    to the gradient) and not the one from AdamW (were weight_decay is applied to the update). This
-    is different from the `author's implementation`_.
+    This implementation provides an option to use either the original weight_decay implementation as in Adam
+    (where the weight_decay is applied to the gradient) or the one from AdamW (where weight_decay is applied
+    to the weight) through the decoupled_weight_decay option. When decoupled_weight_decay is set to False
+    (default), it uses the original Adam style weight decay, otherwise, it uses the AdamW style which
+    corresponds more closely to the `author's implementation`_ in the RAdam paper. Further information
+    about decoupled weight decay can be found in `Decoupled Weight Decay Regularization`_.
+
     """ + fr"""
     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
@@ -169,6 +190,8 @@
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        decoupled_weight_decay (bool, optional): whether to use decoupled weight
+            decay as in AdamW to obtain RAdamW (default: False)
         {_foreach_doc}
         {_differentiable_doc}
 
@@ -176,6 +199,8 @@
         https://arxiv.org/abs/1908.03265
     .. _author's implementation:
         https://github.com/LiyuanLucasLiu/RAdam
+    .. _Decoupled Weight Decay Regularization:
+        https://arxiv.org/abs/1711.05101
 
     """
 
@@ -188,6 +213,7 @@
     state_steps: List[Tensor],
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
+    decoupled_weight_decay: bool = False,
     foreach: Optional[bool] = None,
     differentiable: bool = False,
     *,
@@ -229,6 +255,7 @@
         lr=lr,
         weight_decay=weight_decay,
         eps=eps,
+        decoupled_weight_decay=decoupled_weight_decay,
         differentiable=differentiable,
     )
 
@@ -246,6 +273,7 @@
     weight_decay: float,
     eps: float,
     differentiable: bool,
+    decoupled_weight_decay: bool,
 ):
 
     for i, param in enumerate(params):
@@ -261,7 +289,10 @@
         bias_correction2 = 1 - beta2 ** step
 
         if weight_decay != 0:
-            grad = grad.add(param, alpha=weight_decay)
+            if decoupled_weight_decay:
+                param.mul_(1 - lr * weight_decay)
+            else:
+                grad = grad.add(param, alpha=weight_decay)
 
         # Decay the first and second moment running average coefficient
         exp_avg.lerp_(grad, 1 - beta1)
@@ -306,6 +337,7 @@
     lr: float,
     weight_decay: float,
     eps: float,
+    decoupled_weight_decay: bool,
     differentiable: bool,
 ):
 
@@ -332,7 +364,10 @@
                       (1 - beta2 ** _get_value(step)) for step in grouped_state_steps]
 
         if weight_decay != 0:
-            grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)
+            if decoupled_weight_decay:
+                torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
+            else:
+                grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)
 
         # Decay the first and second moment running average coefficient
         torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)
diff --git a/torch/optim/radam.pyi b/torch/optim/radam.pyi
index e0bbd53..19750ea 100644
--- a/torch/optim/radam.pyi
+++ b/torch/optim/radam.pyi
@@ -10,4 +10,5 @@
         betas: Tuple[float, float] = ...,
         eps: float = ...,
         weight_decay: float = ...,
+        decoupled_weight_decay: bool = ...,
     ) -> None: ...