Optim foreach cleanup for Rprop (#70483)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70483
Test Plan: Imported from OSS
Reviewed By: anjali411
Differential Revision: D33767866
Pulled By: mikaylagawarecki
fbshipit-source-id: ffc5ae68eeea8fa09385862b853b731554b77bcb
(cherry picked from commit 3a0fe295807bb4519884a1838edeea1a9d222e41)
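This cleanup folds the multi-tensor Rprop from `torch/optim/_multi_tensor/rprop.py` into `torch/optim/rprop.py` behind a `foreach` flag, so `torch.optim._multi_tensor.Rprop` becomes a `partial(optim.Rprop, foreach=True)` alias and the functional `rprop` moves into `torch/optim/rprop.py`. A minimal usage sketch of the new flag (the toy model and data below are illustrative only, not part of the change):

```python
import torch
from torch import nn

# Toy model; any nn.Module works here.
model = nn.Linear(4, 2)

# foreach=True selects the multi-tensor (horizontally fused) implementation;
# the default foreach=None currently falls back to the single-tensor loop.
opt = torch.optim.Rprop(model.parameters(), lr=1e-2, etas=(0.5, 1.2),
                        step_sizes=(1e-6, 50), foreach=True)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```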
diff --git a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py
index 3302822..ed6ebdd 100644
--- a/torch/distributed/optim/functional_rprop.py
+++ b/torch/distributed/optim/functional_rprop.py
@@ -21,6 +21,7 @@
lr: float = 1e-2,
etas: Tuple[float, float] = (0.5, 1.2),
step_sizes: Tuple[float, float] = (1e-6, 50),
+ foreach: bool = False,
_allow_empty_param_list: bool = False,
):
self.defaults = {
@@ -28,6 +29,7 @@
}
self.etas = etas
self.step_sizes = step_sizes
+ self.foreach = foreach
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
@@ -81,4 +83,5 @@
step_size_min=step_size_min,
step_size_max=step_size_max,
etaminus=etaminus,
- etaplus=etaplus)
+ etaplus=etaplus,
+ foreach=self.foreach)
diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py
index 64fd7a4..da13716 100644
--- a/torch/optim/_functional.py
+++ b/torch/optim/_functional.py
@@ -12,6 +12,7 @@
from .nadam import nadam # type: ignore[attr-defined] # noqa: F401
from .radam import radam # type: ignore[attr-defined] # noqa: F401
from .rmsprop import rmsprop # type: ignore[attr-defined] # noqa: F401
+from .rprop import rprop # type: ignore[attr-defined] # noqa: F401
from .sgd import sgd # type: ignore[attr-defined] # noqa: F401
@@ -70,44 +71,6 @@
param.addcdiv_(exp_avg, denom, value=-step_size)
-def rprop(params: List[Tensor],
- grads: List[Tensor],
- prevs: List[Tensor],
- step_sizes: List[Tensor],
- *,
- step_size_min: float,
- step_size_max: float,
- etaminus: float,
- etaplus: float):
- r"""Functional API that performs rprop algorithm computation.
-
- See :class:`~torch.optim.Rprop` for details.
- """
-
- for i, param in enumerate(params):
- grad = grads[i]
- prev = prevs[i]
- step_size = step_sizes[i]
-
- sign = grad.mul(prev).sign()
- sign[sign.gt(0)] = etaplus
- sign[sign.lt(0)] = etaminus
- sign[sign.eq(0)] = 1
-
- # update stepsizes with step size updates
- step_size.mul_(sign).clamp_(step_size_min, step_size_max)
-
- # for dir<0, dfdx=0
- # for dir>=0 dfdx=dfdx
- grad = grad.clone(memory_format=torch.preserve_format)
- grad[sign.eq(etaminus)] = 0
-
- # update parameters
- param.addcmul_(grad.sign(), step_size, value=-1)
-
- prev.copy_(grad)
-
-
def sparse_adam(params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
diff --git a/torch/optim/_multi_tensor/__init__.py b/torch/optim/_multi_tensor/__init__.py
index 16300a9..e9e6b13 100644
--- a/torch/optim/_multi_tensor/__init__.py
+++ b/torch/optim/_multi_tensor/__init__.py
@@ -13,11 +13,10 @@
SGD = partial(optim.SGD, foreach=True)
RAdam = partial(optim.RAdam, foreach=True)
RMSprop = partial(optim.RMSprop, foreach=True)
-from .rprop import Rprop
+Rprop = partial(optim.Rprop, foreach=True)
ASGD = partial(optim.ASGD, foreach=True)
Adamax = partial(optim.Adamax, foreach=True)
Adadelta = partial(optim.Adadelta, foreach=True)
Adagrad = partial(optim.Adagrad, foreach=True)
del adamw
-del rprop
diff --git a/torch/optim/_multi_tensor/__init__.pyi b/torch/optim/_multi_tensor/__init__.pyi
index 3d3c602..812d9fc 100644
--- a/torch/optim/_multi_tensor/__init__.pyi
+++ b/torch/optim/_multi_tensor/__init__.pyi
@@ -7,7 +7,7 @@
SGD = partial(optim.SGD, foreach=True)
RAdam = partial(optim.RAdam, foreach=True)
RMSprop = partial(optim.RMSprop, foreach=True)
-from .rprop import Rprop as Rprop
+Rprop = partial(optim.Rprop, foreach=True)
ASGD = partial(optim.ASGD, foreach=True)
Adamax = partial(optim.Adamax, foreach=True)
Adadelta = partial(optim.Adadelta, foreach=True)
diff --git a/torch/optim/_multi_tensor/rprop.py b/torch/optim/_multi_tensor/rprop.py
deleted file mode 100644
index 67baf1e..0000000
--- a/torch/optim/_multi_tensor/rprop.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import torch
-from ..optimizer import Optimizer
-
-class Rprop(Optimizer):
- """Implements the resilient backpropagation algorithm.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-2)
- etas (Tuple[float, float], optional): pair of (etaminus, etaplis), that
- are multiplicative increase and decrease factors
- (default: (0.5, 1.2))
- step_sizes (Tuple[float, float], optional): a pair of minimal and
- maximal allowed step sizes (default: (1e-6, 50))
- """
-
- def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)):
- if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
- if not 0.0 < etas[0] < 1.0 < etas[1]:
- raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))
-
- defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes, foreach=True)
- super(Rprop, self).__init__(params, defaults)
-
- @torch.no_grad()
- def step(self, closure=None):
- """Performs a single optimization step.
-
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- grads = []
- states = []
- params_with_grad = []
- step_sizes = []
-
- for group in self.param_groups:
- for p in group['params']:
- etaminus, etaplus = group['etas']
- step_size_min, step_size_max = group['step_sizes']
-
- if p.grad is not None:
- if p.grad.is_sparse:
- raise RuntimeError('RMSprop does not support sparse gradients')
-
- grads.append(p.grad)
- params_with_grad.append(p)
-
- state = self.state[p]
- # State initialization
- if len(state) == 0:
- state['step'] = 0
- state['prev'] = torch.zeros_like(p, memory_format=torch.preserve_format)
- state['step_size'] = p.grad.new().resize_as_(p.grad).fill_(group['lr'])
-
- state['step'] += 1
-
- states.append(state)
- step_sizes.append(state['step_size'])
-
- signs = torch._foreach_mul(grads, [s['prev'] for s in states])
- signs = [s.sign() for s in signs]
- for sign in signs:
- sign[sign.gt(0)] = etaplus
- sign[sign.lt(0)] = etaminus
- sign[sign.eq(0)] = 1
-
- # update stepsizes with step size updates
- torch._foreach_mul_(step_sizes, signs)
- for step_size in step_sizes:
- step_size.clamp_(step_size_min, step_size_max)
-
- # for dir<0, dfdx=0
- # for dir>=0 dfdx=dfdx
- for i in range(len(grads)):
- grads[i] = grads[i].clone(memory_format=torch.preserve_format)
- grads[i][signs[i].eq(etaminus)] = 0
-
- # update parameters
- grad_signs = [grad.sign() for grad in grads]
- torch._foreach_addcmul_(params_with_grad, grad_signs, step_sizes, value=-1)
-
- for i in range(len(states)):
- states[i]['prev'].copy_(grads[i])
-
- return loss
diff --git a/torch/optim/_multi_tensor/rprop.pyi b/torch/optim/_multi_tensor/rprop.pyi
deleted file mode 100644
index 0ea64c6..0000000
--- a/torch/optim/_multi_tensor/rprop.pyi
+++ /dev/null
@@ -1,5 +0,0 @@
-from typing import Tuple
-from ..optimizer import _params_t, Optimizer
-
-class Rprop(Optimizer):
- def __init__(self, params: _params_t, lr: float=..., etas: Tuple[float, float]=..., step_sizes: Tuple[float, float]=...) -> None: ...
diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py
index 741f6de..f580e35 100644
--- a/torch/optim/rprop.py
+++ b/torch/optim/rprop.py
@@ -1,6 +1,7 @@
import torch
-from . import _functional as F
+from torch import Tensor
from .optimizer import Optimizer
+from typing import List, Optional
class Rprop(Optimizer):
@@ -47,17 +48,25 @@
(default: (0.5, 1.2))
step_sizes (Tuple[float, float], optional): a pair of minimal and
maximal allowed step sizes (default: (1e-6, 50))
+ foreach (bool, optional): whether the foreach (multi-tensor) implementation
+ of the optimizer is used (default: None)
"""
- def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)):
+ def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50),
+ foreach: Optional[bool] = None):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 < etas[0] < 1.0 < etas[1]:
raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))
- defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes)
+ defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes, foreach=foreach)
super(Rprop, self).__init__(params, defaults)
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('foreach', None)
+
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@@ -76,6 +85,9 @@
grads = []
prevs = []
step_sizes = []
+ etaminus, etaplus = group['etas']
+ step_size_min, step_size_max = group['step_sizes']
+ foreach = group['foreach']
for p in group['params']:
if p.grad is None:
@@ -97,18 +109,128 @@
prevs.append(state['prev'])
step_sizes.append(state['step_size'])
- etaminus, etaplus = group['etas']
- step_size_min, step_size_max = group['step_sizes']
-
state['step'] += 1
- F.rprop(params,
- grads,
- prevs,
- step_sizes,
- step_size_min=step_size_min,
- step_size_max=step_size_max,
- etaminus=etaminus,
- etaplus=etaplus)
+ rprop(params,
+ grads,
+ prevs,
+ step_sizes,
+ step_size_min=step_size_min,
+ step_size_max=step_size_max,
+ etaminus=etaminus,
+ etaplus=etaplus,
+ foreach=foreach)
return loss
+
+
+def rprop(params: List[Tensor],
+ grads: List[Tensor],
+ prevs: List[Tensor],
+ step_sizes: List[Tensor],
+ # kwonly args with defaults are not supported by functions compiled with torchscript (issue #70627),
+ # so foreach is a regular arg with a default for now, since this functional API is compiled by torch/distributed/optim
+ foreach: Optional[bool] = None,
+ *,
+ step_size_min: float,
+ step_size_max: float,
+ etaminus: float,
+ etaplus: float):
+ r"""Functional API that performs rprop algorithm computation.
+
+ See :class:`~torch.optim.Rprop` for details.
+ """
+
+ if foreach is None:
+ # Placeholder for more complex foreach logic to be added when value is not set
+ foreach = False
+
+ if foreach and torch.jit.is_scripting():
+ raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+
+ if foreach and not torch.jit.is_scripting():
+ func = _multi_tensor_rprop
+ else:
+ func = _single_tensor_rprop
+
+ func(params,
+ grads,
+ prevs,
+ step_sizes,
+ step_size_min=step_size_min,
+ step_size_max=step_size_max,
+ etaminus=etaminus,
+ etaplus=etaplus)
+
+
+def _single_tensor_rprop(params: List[Tensor],
+ grads: List[Tensor],
+ prevs: List[Tensor],
+ step_sizes: List[Tensor],
+ *,
+ step_size_min: float,
+ step_size_max: float,
+ etaminus: float,
+ etaplus: float):
+
+ for i, param in enumerate(params):
+ grad = grads[i]
+ prev = prevs[i]
+ step_size = step_sizes[i]
+
+ sign = grad.mul(prev).sign()
+ sign[sign.gt(0)] = etaplus
+ sign[sign.lt(0)] = etaminus
+ sign[sign.eq(0)] = 1
+
+ # update stepsizes with step size updates
+ step_size.mul_(sign).clamp_(step_size_min, step_size_max)
+
+ # for dir<0, dfdx=0
+ # for dir>=0 dfdx=dfdx
+ grad = grad.clone(memory_format=torch.preserve_format)
+ grad[sign.eq(etaminus)] = 0
+
+ # update parameters
+ param.addcmul_(grad.sign(), step_size, value=-1)
+
+ prev.copy_(grad)
+
+
+def _multi_tensor_rprop(params: List[Tensor],
+ grads: List[Tensor],
+ prevs: List[Tensor],
+ step_sizes: List[Tensor],
+ *,
+ step_size_min: float,
+ step_size_max: float,
+ etaminus: float,
+ etaplus: float):
+
+ if len(params) == 0:
+ return
+
+ signs = torch._foreach_mul(grads, prevs)
+ signs = [s.sign() for s in signs]
+ for sign in signs:
+ sign[sign.gt(0)] = etaplus
+ sign[sign.lt(0)] = etaminus
+ sign[sign.eq(0)] = 1
+
+ # update stepsizes with step size updates
+ torch._foreach_mul_(step_sizes, signs)
+ for step_size in step_sizes:
+ step_size.clamp_(step_size_min, step_size_max)
+
+ # for dir<0, dfdx=0
+ # for dir>=0 dfdx=dfdx
+ for i in range(len(grads)):
+ grads[i] = grads[i].clone(memory_format=torch.preserve_format)
+ grads[i][signs[i].eq(etaminus)] = 0
+
+ # update parameters
+ grad_signs = [grad.sign() for grad in grads]
+ torch._foreach_addcmul_(params, grad_signs, step_sizes, value=-1)
+
+ for i in range(len(prevs)):
+ prevs[i].copy_(grads[i])
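
Since `_single_tensor_rprop` and `_multi_tensor_rprop` implement the same update rule, a quick numerical parity check in the spirit of the existing optimizer tests looks roughly like the sketch below (the seed, shapes, quadratic loss, and use of `torch.testing.assert_close` are illustrative assumptions, not the imported test plan):

```python
import torch

torch.manual_seed(0)
# Two identical parameter sets, one per implementation.
p_loop = [torch.randn(3, 3, requires_grad=True) for _ in range(2)]
p_multi = [p.detach().clone().requires_grad_(True) for p in p_loop]

opt_loop = torch.optim.Rprop(p_loop, lr=1e-2, foreach=False)
opt_multi = torch.optim.Rprop(p_multi, lr=1e-2, foreach=True)

for _ in range(5):
    for opt, params in ((opt_loop, p_loop), (opt_multi, p_multi)):
        opt.zero_grad()
        # Simple quadratic loss so gradients are well defined at every step.
        loss = sum((p ** 2).sum() for p in params)
        loss.backward()
        opt.step()

for a, b in zip(p_loop, p_multi):
    # The single-tensor and multi-tensor paths should produce identical updates.
    torch.testing.assert_close(a, b)
```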