Optim foreach cleanup for Rprop (#70483)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70483

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D33767866

Pulled By: mikaylagawarecki

fbshipit-source-id: ffc5ae68eeea8fa09385862b853b731554b77bcb
(cherry picked from commit 3a0fe295807bb4519884a1838edeea1a9d222e41)
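
As a usage sketch for context (the model, loss, and data below are illustrative and not part of this change): after this cleanup the multi-tensor path is selected via the new foreach flag on torch.optim.Rprop, and torch.optim._multi_tensor.Rprop becomes a thin partial(optim.Rprop, foreach=True) alias, so opting in looks roughly like this:

    import torch

    # Toy module and data, purely for illustration.
    model = torch.nn.Linear(4, 2)
    loss_fn = torch.nn.MSELoss()

    # foreach=True opts into the multi-tensor implementation consolidated here;
    # the default (foreach=None) currently falls back to the single-tensor loop.
    opt = torch.optim.Rprop(model.parameters(), lr=1e-2, etas=(0.5, 1.2),
                            step_sizes=(1e-6, 50), foreach=True)

    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss_fn(model(x), y).backward()
    opt.step()
    opt.zero_grad()
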
diff --git a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py
index 3302822..ed6ebdd 100644
--- a/torch/distributed/optim/functional_rprop.py
+++ b/torch/distributed/optim/functional_rprop.py
@@ -21,6 +21,7 @@
         lr: float = 1e-2,
         etas: Tuple[float, float] = (0.5, 1.2),
         step_sizes: Tuple[float, float] = (1e-6, 50),
+        foreach: bool = False,
         _allow_empty_param_list: bool = False,
     ):
         self.defaults = {
@@ -28,6 +29,7 @@
         }
         self.etas = etas
         self.step_sizes = step_sizes
+        self.foreach = foreach
 
         if len(params) == 0 and not _allow_empty_param_list:
             raise ValueError("optimizer got an empty parameter list")
@@ -81,4 +83,5 @@
                     step_size_min=step_size_min,
                     step_size_max=step_size_max,
                     etaminus=etaminus,
-                    etaplus=etaplus)
+                    etaplus=etaplus,
+                    foreach=self.foreach)
diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py
index 64fd7a4..da13716 100644
--- a/torch/optim/_functional.py
+++ b/torch/optim/_functional.py
@@ -12,6 +12,7 @@
 from .nadam import nadam  # type: ignore[attr-defined] # noqa: F401
 from .radam import radam  # type: ignore[attr-defined] # noqa: F401
 from .rmsprop import rmsprop  # type: ignore[attr-defined] # noqa: F401
+from .rprop import rprop  # type: ignore[attr-defined] # noqa: F401
 from .sgd import sgd  # type: ignore[attr-defined] # noqa: F401
 
 
@@ -70,44 +71,6 @@
         param.addcdiv_(exp_avg, denom, value=-step_size)
 
 
-def rprop(params: List[Tensor],
-          grads: List[Tensor],
-          prevs: List[Tensor],
-          step_sizes: List[Tensor],
-          *,
-          step_size_min: float,
-          step_size_max: float,
-          etaminus: float,
-          etaplus: float):
-    r"""Functional API that performs rprop algorithm computation.
-
-    See :class:`~torch.optim.Rprop` for details.
-    """
-
-    for i, param in enumerate(params):
-        grad = grads[i]
-        prev = prevs[i]
-        step_size = step_sizes[i]
-
-        sign = grad.mul(prev).sign()
-        sign[sign.gt(0)] = etaplus
-        sign[sign.lt(0)] = etaminus
-        sign[sign.eq(0)] = 1
-
-        # update stepsizes with step size updates
-        step_size.mul_(sign).clamp_(step_size_min, step_size_max)
-
-        # for dir<0, dfdx=0
-        # for dir>=0 dfdx=dfdx
-        grad = grad.clone(memory_format=torch.preserve_format)
-        grad[sign.eq(etaminus)] = 0
-
-        # update parameters
-        param.addcmul_(grad.sign(), step_size, value=-1)
-
-        prev.copy_(grad)
-
-
 def sparse_adam(params: List[Tensor],
                 grads: List[Tensor],
                 exp_avgs: List[Tensor],
diff --git a/torch/optim/_multi_tensor/__init__.py b/torch/optim/_multi_tensor/__init__.py
index 16300a9..e9e6b13 100644
--- a/torch/optim/_multi_tensor/__init__.py
+++ b/torch/optim/_multi_tensor/__init__.py
@@ -13,11 +13,10 @@
 SGD = partial(optim.SGD, foreach=True)
 RAdam = partial(optim.RAdam, foreach=True)
 RMSprop = partial(optim.RMSprop, foreach=True)
-from .rprop import Rprop
+Rprop = partial(optim.Rprop, foreach=True)
 ASGD = partial(optim.ASGD, foreach=True)
 Adamax = partial(optim.Adamax, foreach=True)
 Adadelta = partial(optim.Adadelta, foreach=True)
 Adagrad = partial(optim.Adagrad, foreach=True)
 
 del adamw
-del rprop
diff --git a/torch/optim/_multi_tensor/__init__.pyi b/torch/optim/_multi_tensor/__init__.pyi
index 3d3c602..812d9fc 100644
--- a/torch/optim/_multi_tensor/__init__.pyi
+++ b/torch/optim/_multi_tensor/__init__.pyi
@@ -7,7 +7,7 @@
 SGD = partial(optim.SGD, foreach=True)
 RAdam = partial(optim.RAdam, foreach=True)
 RMSprop = partial(optim.RMSprop, foreach=True)
-from .rprop import Rprop as Rprop
+Rprop = partial(optim.Rprop, foreach=True)
 ASGD = partial(optim.ASGD, foreach=True)
 Adamax = partial(optim.Adamax, foreach=True)
 Adadelta = partial(optim.Adadelta, foreach=True)
diff --git a/torch/optim/_multi_tensor/rprop.py b/torch/optim/_multi_tensor/rprop.py
deleted file mode 100644
index 67baf1e..0000000
--- a/torch/optim/_multi_tensor/rprop.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import torch
-from ..optimizer import Optimizer
-
-class Rprop(Optimizer):
-    """Implements the resilient backpropagation algorithm.
-
-    Args:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): learning rate (default: 1e-2)
-        etas (Tuple[float, float], optional): pair of (etaminus, etaplis), that
-            are multiplicative increase and decrease factors
-            (default: (0.5, 1.2))
-        step_sizes (Tuple[float, float], optional): a pair of minimal and
-            maximal allowed step sizes (default: (1e-6, 50))
-    """
-
-    def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)):
-        if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
-        if not 0.0 < etas[0] < 1.0 < etas[1]:
-            raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))
-
-        defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes, foreach=True)
-        super(Rprop, self).__init__(params, defaults)
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        """Performs a single optimization step.
-
-        Args:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        grads = []
-        states = []
-        params_with_grad = []
-        step_sizes = []
-
-        for group in self.param_groups:
-            for p in group['params']:
-                etaminus, etaplus = group['etas']
-                step_size_min, step_size_max = group['step_sizes']
-
-                if p.grad is not None:
-                    if p.grad.is_sparse:
-                        raise RuntimeError('RMSprop does not support sparse gradients')
-
-                    grads.append(p.grad)
-                    params_with_grad.append(p)
-
-                    state = self.state[p]
-                    # State initialization
-                    if len(state) == 0:
-                        state['step'] = 0
-                        state['prev'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                        state['step_size'] = p.grad.new().resize_as_(p.grad).fill_(group['lr'])
-
-                        state['step'] += 1
-
-                    states.append(state)
-                    step_sizes.append(state['step_size'])
-
-            signs = torch._foreach_mul(grads, [s['prev'] for s in states])
-            signs = [s.sign() for s in signs]
-            for sign in signs:
-                sign[sign.gt(0)] = etaplus
-                sign[sign.lt(0)] = etaminus
-                sign[sign.eq(0)] = 1
-
-            # update stepsizes with step size updates
-            torch._foreach_mul_(step_sizes, signs)
-            for step_size in step_sizes:
-                step_size.clamp_(step_size_min, step_size_max)
-
-            # for dir<0, dfdx=0
-            # for dir>=0 dfdx=dfdx
-            for i in range(len(grads)):
-                grads[i] = grads[i].clone(memory_format=torch.preserve_format)
-                grads[i][signs[i].eq(etaminus)] = 0
-
-            # update parameters
-            grad_signs = [grad.sign() for grad in grads]
-            torch._foreach_addcmul_(params_with_grad, grad_signs, step_sizes, value=-1)
-
-            for i in range(len(states)):
-                states[i]['prev'].copy_(grads[i])
-
-        return loss
diff --git a/torch/optim/_multi_tensor/rprop.pyi b/torch/optim/_multi_tensor/rprop.pyi
deleted file mode 100644
index 0ea64c6..0000000
--- a/torch/optim/_multi_tensor/rprop.pyi
+++ /dev/null
@@ -1,5 +0,0 @@
-from typing import Tuple
-from ..optimizer import _params_t, Optimizer
-
-class Rprop(Optimizer):
-    def __init__(self, params: _params_t, lr: float=..., etas: Tuple[float, float]=..., step_sizes: Tuple[float, float]=...) -> None: ...
diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py
index 741f6de..f580e35 100644
--- a/torch/optim/rprop.py
+++ b/torch/optim/rprop.py
@@ -1,6 +1,7 @@
 import torch
-from . import _functional as F
+from torch import Tensor
 from .optimizer import Optimizer
+from typing import List, Optional
 
 
 class Rprop(Optimizer):
@@ -47,17 +48,25 @@
             (default: (0.5, 1.2))
         step_sizes (Tuple[float, float], optional): a pair of minimal and
             maximal allowed step sizes (default: (1e-6, 50))
+        foreach (bool, optional): whether a foreach (multi-tensor) implementation
+            of the optimizer is used (default: None)
     """
 
-    def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)):
+    def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50),
+                 foreach: Optional[bool] = None):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 < etas[0] < 1.0 < etas[1]:
             raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))
 
-        defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes)
+        defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes, foreach=foreach)
         super(Rprop, self).__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('foreach', None)
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -76,6 +85,9 @@
             grads = []
             prevs = []
             step_sizes = []
+            etaminus, etaplus = group['etas']
+            step_size_min, step_size_max = group['step_sizes']
+            foreach = group['foreach']
 
             for p in group['params']:
                 if p.grad is None:
@@ -97,18 +109,128 @@
                 prevs.append(state['prev'])
                 step_sizes.append(state['step_size'])
 
-                etaminus, etaplus = group['etas']
-                step_size_min, step_size_max = group['step_sizes']
-
                 state['step'] += 1
 
-            F.rprop(params,
-                    grads,
-                    prevs,
-                    step_sizes,
-                    step_size_min=step_size_min,
-                    step_size_max=step_size_max,
-                    etaminus=etaminus,
-                    etaplus=etaplus)
+            rprop(params,
+                  grads,
+                  prevs,
+                  step_sizes,
+                  step_size_min=step_size_min,
+                  step_size_max=step_size_max,
+                  etaminus=etaminus,
+                  etaplus=etaplus,
+                  foreach=foreach)
 
         return loss
+
+
+def rprop(params: List[Tensor],
+          grads: List[Tensor],
+          prevs: List[Tensor],
+          step_sizes: List[Tensor],
+          # kwonly args with defaults are not supported by functions compiled with torchscript (issue #70627);
+          # keep foreach as a regular keyword arg for now, since the functional API is compiled by torch/distributed/optim
+          foreach: Optional[bool] = None,
+          *,
+          step_size_min: float,
+          step_size_max: float,
+          etaminus: float,
+          etaplus: float):
+    r"""Functional API that performs rprop algorithm computation.
+
+    See :class:`~torch.optim.Rprop` for details.
+    """
+
+    if foreach is None:
+        # Placeholder for smarter foreach selection logic, to be added later for the case where the user does not set the flag
+        foreach = False
+
+    if foreach and torch.jit.is_scripting():
+        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+
+    if foreach and not torch.jit.is_scripting():
+        func = _multi_tensor_rprop
+    else:
+        func = _single_tensor_rprop
+
+    func(params,
+         grads,
+         prevs,
+         step_sizes,
+         step_size_min=step_size_min,
+         step_size_max=step_size_max,
+         etaminus=etaminus,
+         etaplus=etaplus)
+
+
+def _single_tensor_rprop(params: List[Tensor],
+                         grads: List[Tensor],
+                         prevs: List[Tensor],
+                         step_sizes: List[Tensor],
+                         *,
+                         step_size_min: float,
+                         step_size_max: float,
+                         etaminus: float,
+                         etaplus: float):
+
+    for i, param in enumerate(params):
+        grad = grads[i]
+        prev = prevs[i]
+        step_size = step_sizes[i]
+
+        sign = grad.mul(prev).sign()
+        sign[sign.gt(0)] = etaplus
+        sign[sign.lt(0)] = etaminus
+        sign[sign.eq(0)] = 1
+
+        # update stepsizes with step size updates
+        step_size.mul_(sign).clamp_(step_size_min, step_size_max)
+
+        # for dir<0, dfdx=0
+        # for dir>=0 dfdx=dfdx
+        grad = grad.clone(memory_format=torch.preserve_format)
+        grad[sign.eq(etaminus)] = 0
+
+        # update parameters
+        param.addcmul_(grad.sign(), step_size, value=-1)
+
+        prev.copy_(grad)
+
+
+def _multi_tensor_rprop(params: List[Tensor],
+                        grads: List[Tensor],
+                        prevs: List[Tensor],
+                        step_sizes: List[Tensor],
+                        *,
+                        step_size_min: float,
+                        step_size_max: float,
+                        etaminus: float,
+                        etaplus: float):
+
+    if len(params) == 0:
+        return
+
+    signs = torch._foreach_mul(grads, prevs)
+    signs = [s.sign() for s in signs]
+    for sign in signs:
+        sign[sign.gt(0)] = etaplus
+        sign[sign.lt(0)] = etaminus
+        sign[sign.eq(0)] = 1
+
+    # update stepsizes with step size updates
+    torch._foreach_mul_(step_sizes, signs)
+    for step_size in step_sizes:
+        step_size.clamp_(step_size_min, step_size_max)
+
+    # for dir<0, dfdx=0
+    # for dir>=0 dfdx=dfdx
+    for i in range(len(grads)):
+        grads[i] = grads[i].clone(memory_format=torch.preserve_format)
+        grads[i][signs[i].eq(etaminus)] = 0
+
+    # update parameters
+    grad_signs = [grad.sign() for grad in grads]
+    torch._foreach_addcmul_(params, grad_signs, step_sizes, value=-1)
+
+    for i in range(len(prevs)):
+        prevs[i].copy_(grads[i])
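
A possible sanity check for the refactor above, sketched under the assumption that exact numerical parity between the single- and multi-tensor paths is expected: call the functional rprop entry point added in torch/optim/rprop.py directly on cloned state with foreach=False and foreach=True and compare the updated parameters (tensor shapes and hyperparameters below are arbitrary).

    import torch
    from torch.optim.rprop import rprop  # functional entry point introduced above

    torch.manual_seed(0)
    params = [torch.randn(3, 3) for _ in range(2)]
    grads = [torch.randn_like(p) for p in params]
    prevs = [torch.zeros_like(p) for p in params]
    step_sizes = [torch.full_like(p, 1e-2) for p in params]

    def run(foreach):
        # Clone every tensor so both code paths start from identical state.
        p, g = [t.clone() for t in params], [t.clone() for t in grads]
        pr, ss = [t.clone() for t in prevs], [t.clone() for t in step_sizes]
        rprop(p, g, pr, ss, foreach=foreach,
              step_size_min=1e-6, step_size_max=50,
              etaminus=0.5, etaplus=1.2)
        return p

    for single, multi in zip(run(False), run(True)):
        assert torch.allclose(single, multi)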