Add beta1 support to CyclicLR momentum (#113548)

Fixes #73910

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113548
Approved by: https://github.com/janeyx99
diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py
index 20c5028..02199e4 100644
--- a/test/optim/test_lrscheduler.py
+++ b/test/optim/test_lrscheduler.py
@@ -8,7 +8,7 @@
 import torch
 import torch.nn.functional as F
 from torch.nn import Parameter
-from torch.optim import Adam, SGD
+from torch.optim import Adam, SGD, Rprop
 from torch.optim.lr_scheduler import (
     LambdaLR,
     MultiplicativeLR,
@@ -1510,8 +1510,12 @@
 
     def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
         with self.assertRaises(ValueError):
-            adam_opt = Adam(self.net.parameters())
-            scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
+            rprop_opt = Rprop(self.net.parameters())
+            scheduler = CyclicLR(rprop_opt, base_lr=1, max_lr=5, cycle_momentum=True)
+
+    def test_cycle_lr_cycle_momentum_with_beta1_optimizer(self):
+        adam_opt = Adam(self.net.parameters())
+        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
 
     def test_cycle_lr_removed_after_out_of_scope(self):
         import gc
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index df659b6..a681142 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -1268,15 +1268,20 @@
 
         self.cycle_momentum = cycle_momentum
         if cycle_momentum:
-            if 'momentum' not in optimizer.defaults:
-                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+            if 'momentum' not in optimizer.defaults and 'betas' not in optimizer.defaults:
+                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
 
-            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
-            if last_epoch == -1:
-                for momentum, group in zip(base_momentums, optimizer.param_groups):
-                    group['momentum'] = momentum
-            self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
+            self.use_beta1 = 'betas' in self.optimizer.defaults
+            self.base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
             self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+            if last_epoch == -1:
+                for m_momentum, b_momentum, group in zip(self.max_momentums, self.base_momentums, optimizer.param_groups):
+                    if self.use_beta1:
+                        group['betas'] = (m_momentum, *group['betas'][1:])
+                    else:
+                        group['momentum'] = m_momentum
+                    group['max_momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum
 
         super().__init__(optimizer, last_epoch, verbose)
         self.base_lrs = base_lrs
@@ -1359,7 +1364,10 @@
                     momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                 momentums.append(momentum)
             for param_group, momentum in zip(self.optimizer.param_groups, momentums):
-                param_group['momentum'] = momentum
+                if self.use_beta1:
+                    param_group['betas'] = (momentum, *param_group['betas'][1:])
+                else:
+                    param_group['momentum'] = momentum
 
         return lrs
 
@@ -1721,7 +1729,7 @@
         self.cycle_momentum = cycle_momentum
         if self.cycle_momentum:
             if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
-                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
             self.use_beta1 = 'betas' in self.optimizer.defaults
             max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
             base_momentums = self._format_param('base_momentum', optimizer, base_momentum)