| r"""Functional interface""" |
| import math |
| import torch |
| from torch import Tensor |
| from typing import List, Optional |
| |
| # TODO: use foreach API in optim._functional to do all the computation |
| |
| def _make_sparse(grad, grad_indices, values): |
| size = grad.size() |
| if grad_indices.numel() == 0 or values.numel() == 0: |
| return torch.empty_like(grad) |
| return torch.sparse_coo_tensor(grad_indices, values, size) |
| |
| |
def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[Tensor],
            *,
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.

    Tensors in ``params``, ``state_sums`` and ``state_steps`` are updated in
    place; nothing is returned.

    Args:
        params: parameters to update.
        grads: gradients matching ``params``; may be sparse.
        state_sums: running sums of squared gradients, one per parameter.
        state_steps: per-parameter step counters as singleton tensors,
            incremented by one on every call.
        lr: base learning rate.
        weight_decay: coefficient of ``param`` added to the gradient.
        lr_decay: the effective rate is ``lr / (1 + (step - 1) * lr_decay)``.
        eps: added to the square-root denominator for numerical stability.

    Raises:
        RuntimeError: if ``state_steps`` contains non-tensors, or if
            ``weight_decay`` is non-zero for a sparse gradient.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    for (param, grad, state_sum, step_t) in zip(params, grads, state_sums, state_steps):
        # update step (in place, so the caller's counter advances)
        step_t += 1
        step = step_t.item()

        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            # Only the accumulated sums at the gradient's indices are needed.
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
        else:
            if torch.is_complex(param):
                # Real views share storage with the originals, so the in-place
                # updates below write straight into the complex tensors; real
                # and imaginary parts are treated as independent coordinates.
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
| |
| |
| |
| |
def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[Tensor],
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float,
         maximize: bool):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.

    Tensors in ``params``, ``exp_avgs``, ``exp_avg_sqs`` and ``state_steps``
    (plus ``max_exp_avg_sqs`` when ``amsgrad`` is set) are updated in place;
    nothing is returned.

    Raises:
        RuntimeError: if any entry of ``state_steps`` is not a tensor.
    """
    if any(not isinstance(counter, torch.Tensor) for counter in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    for idx, param in enumerate(params):
        # Negating the gradient turns descent into ascent when maximizing.
        g = -grads[idx] if maximize else grads[idx]
        first_moment = exp_avgs[idx]
        second_moment = exp_avg_sqs[idx]
        counter = state_steps[idx]

        counter += 1
        step = counter.item()

        if weight_decay != 0:
            g = g.add(param, alpha=weight_decay)

        # Exponential moving averages of the gradient and of |gradient|^2.
        first_moment.mul_(beta1).add_(g, alpha=1 - beta1)
        second_moment.mul_(beta2).addcmul_(g, g.conj(), value=1 - beta2)

        if amsgrad:
            # Keep the running element-wise maximum and normalize by it.
            peak = max_exp_avg_sqs[idx]
            torch.maximum(peak, second_moment, out=peak)
            normalizer = peak
        else:
            normalizer = second_moment

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        denom = (normalizer.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        param.addcdiv_(first_moment, denom, value=-(lr / bias_correction1))
| |
def adamw(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          max_exp_avg_sqs: List[Tensor],
          state_steps: List[Tensor],
          *,
          amsgrad: bool,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float,
          maximize: bool):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.

    Weight decay is decoupled: ``param`` is scaled by ``1 - lr * weight_decay``
    directly instead of adding a penalty to the gradient. Tensors in
    ``params``, ``exp_avgs``, ``exp_avg_sqs`` and ``state_steps`` (plus
    ``max_exp_avg_sqs`` when ``amsgrad`` is set) are updated in place;
    nothing is returned.

    Raises:
        RuntimeError: if any entry of ``state_steps`` is not a tensor.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]
        # update step
        step_t += 1
        step = step_t.item()

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Decay the first and second moment running average coefficient.
        # grad * conj(grad) = |grad|^2 keeps the second moment real and
        # non-negative for complex gradients; conj() is a no-op for real ones.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)
| |
| |
def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.

    ``params`` are updated in place. When ``momentum`` is non-zero, a ``None``
    entry in ``momentum_buffer_list`` is replaced (in the list itself) by a
    freshly created buffer; existing buffers are updated in place.
    """
    # Positive step for gradient ascent, negative for descent.
    step_alpha = lr if maximize else -lr

    for idx, param in enumerate(params):
        update = d_p_list[idx]
        if weight_decay != 0:
            update = update.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[idx]
            if buf is not None:
                buf.mul_(momentum).add_(update, alpha=1 - dampening)
            else:
                # First step: seed the buffer with a detached copy of the update.
                buf = torch.clone(update).detach()
                momentum_buffer_list[idx] = buf
            update = update.add(buf, alpha=momentum) if nesterov else buf

        param.add_(update, alpha=step_alpha)
| |
| |
def adadelta(params: List[Tensor],
             grads: List[Tensor],
             square_avgs: List[Tensor],
             acc_deltas: List[Tensor],
             *,
             lr: float,
             rho: float,
             eps: float,
             weight_decay: float):
    r"""Functional API that performs Adadelta algorithm computation.

    See :class:`~torch.optim.Adadelta` for details.

    ``params``, ``square_avgs`` and ``acc_deltas`` are updated in place;
    nothing is returned. Complex parameters are processed through real views,
    so real and imaginary parts are handled as independent coordinates.
    """
    for param, grad, square_avg, acc_delta in zip(params, grads, square_avgs, acc_deltas):
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        is_complex_param = torch.is_complex(param)
        if is_complex_param:
            # The views share storage, so in-place updates reach the state.
            square_avg, acc_delta, grad = map(torch.view_as_real, (square_avg, acc_delta, grad))

        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
        rms_grad = square_avg.add(eps).sqrt_()
        update = acc_delta.add(eps).sqrt_()
        update.div_(rms_grad).mul_(grad)
        acc_delta.mul_(rho).addcmul_(update, update, value=1 - rho)

        if is_complex_param:
            update = torch.view_as_complex(update)
        param.add_(update, alpha=-lr)
| |
def rmsprop(params: List[Tensor],
            grads: List[Tensor],
            square_avgs: List[Tensor],
            grad_avgs: List[Tensor],
            momentum_buffer_list: List[Tensor],
            *,
            lr: float,
            alpha: float,
            eps: float,
            weight_decay: float,
            momentum: float,
            centered: bool):
    r"""Functional API that performs rmsprop algorithm computation.

    See :class:`~torch.optim.RMSprop` for details.

    ``params`` and ``square_avgs`` are updated in place, as are ``grad_avgs``
    when ``centered`` is set and ``momentum_buffer_list`` when ``momentum > 0``
    (those lists are not read otherwise). Nothing is returned.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        square_avg = square_avgs[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Moving average of the squared gradient.
        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

        if centered:
            # Subtract the squared mean gradient to estimate the variance.
            grad_avg = grad_avgs[i]
            grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(eps)
        else:
            avg = square_avg.sqrt().add_(eps)

        if momentum > 0:
            buf = momentum_buffer_list[i]
            buf.mul_(momentum).addcdiv_(grad, avg)
            param.add_(buf, alpha=-lr)
        else:
            param.addcdiv_(grad, avg, value=-lr)
| |
| |
def rprop(params: List[Tensor],
          grads: List[Tensor],
          prevs: List[Tensor],
          step_sizes: List[Tensor],
          *,
          step_size_min: float,
          step_size_max: float,
          etaminus: float,
          etaplus: float):
    r"""Functional API that performs rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.

    ``params``, ``prevs`` and ``step_sizes`` are updated in place; nothing is
    returned.
    """
    # Applied in this order, each rule to the tensor produced by the previous
    # one: agreeing signs -> etaplus, disagreeing -> etaminus, zero -> 1.
    sign_rules = ((Tensor.gt, etaplus), (Tensor.lt, etaminus), (Tensor.eq, 1))

    for idx, param in enumerate(params):
        grad, prev, step_size = grads[idx], prevs[idx], step_sizes[idx]

        factor = grad.mul(prev).sign()
        for compare, replacement in sign_rules:
            factor[compare(factor, 0)] = replacement

        # Scale each per-element step size and keep it within bounds.
        step_size.mul_(factor).clamp_(step_size_min, step_size_max)

        # Where the sign flipped, zero the gradient: the parameter does not
        # move there, and the stored `prev` becomes 0 for the next comparison.
        effective_grad = grad.clone(memory_format=torch.preserve_format)
        effective_grad[factor.eq(etaminus)] = 0

        param.addcmul_(effective_grad.sign(), step_size, value=-1)
        prev.copy_(effective_grad)
| |
| |
def adamax(params: List[Tensor],
           grads: List[Tensor],
           exp_avgs: List[Tensor],
           exp_infs: List[Tensor],
           state_steps: List[Tensor],
           *,
           eps: float,
           beta1: float,
           beta2: float,
           lr: float,
           weight_decay: float):
    r"""Functional API that performs adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.

    Tensors in ``params``, ``exp_avgs``, ``exp_infs`` and ``state_steps`` are
    updated in place; nothing is returned.

    Raises:
        RuntimeError: if any entry of ``state_steps`` is not a tensor.
    """
    if any(not isinstance(counter, torch.Tensor) for counter in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    for idx, param in enumerate(params):
        g = grads[idx]
        first_moment = exp_avgs[idx]
        inf_norm = exp_infs[idx]
        counter = state_steps[idx]

        counter += 1
        step = counter.item()

        if weight_decay != 0:
            g = g.add(param, alpha=weight_decay)

        # Biased first-moment estimate.
        first_moment.mul_(beta1).add_(g, alpha=1 - beta1)

        # Exponentially weighted infinity norm: element-wise max of the decayed
        # previous norm (decayed in place) and |g| + eps, written back into it.
        candidates = torch.stack([inf_norm.mul_(beta2), g.abs().add_(eps)], 0)
        torch.amax(candidates, 0, keepdim=False, out=inf_norm)

        step_scale = lr / (1 - beta1 ** step)
        param.addcdiv_(first_moment, inf_norm, value=-step_scale)
| |
| |
def asgd(params: List[Tensor],
         grads: List[Tensor],
         axs: List[Tensor],
         mus: List[Tensor],
         etas: List[Tensor],
         state_steps: List[Tensor],
         *,
         lambd: float,
         lr: float,
         t0: float,
         alpha: float,
         weight_decay: float):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.

    ``params``, ``axs``, ``mus``, ``etas`` and ``state_steps`` are updated in
    place; nothing is returned.

    Raises:
        RuntimeError: if any entry of ``state_steps`` is not a tensor.
    """
    # Step counters are incremented in place and read with `.item()`, which
    # only works for tensors; fail early with an explicit message.
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    for i, param in enumerate(params):
        grad = grads[i]
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]
        step_t = state_steps[i]

        # update step
        step_t += 1
        step = step_t.item()

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        eta_value = eta.item()

        # decay term
        param.mul_(1 - lambd * eta_value)

        # update parameter
        param.add_(grad, alpha=-eta_value)

        # averaging: with mu == 1 the average simply tracks the parameter
        if mu.item() != 1:
            ax.add_(param.sub(ax).mul(mu))
        else:
            ax.copy_(param)

        # Schedules for the next step.
        new_eta = torch.tensor(lr / math.pow((1 + lambd * lr * step), alpha))
        eta.copy_(new_eta)
        new_mu = torch.tensor(1 / max(1, step - t0))
        mu.copy_(new_mu)
| |
| |
def nadam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          mu_products: List[Tensor],
          state_steps: List[Tensor],
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          momentum_decay: float,
          eps: float):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.

    Tensors in ``params``, ``exp_avgs``, ``exp_avg_sqs``, ``mu_products`` and
    ``state_steps`` are updated in place; nothing is returned. After a call,
    each ``mu_products`` entry holds the running product of the momentum
    coefficients up to and including the current step.

    Raises:
        RuntimeError: if any entry of ``state_steps`` or ``mu_products`` is not
            a tensor.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    if not all(isinstance(t, torch.Tensor) for t in mu_products):
        raise RuntimeError("API has changed, `mu_products` argument must contain a list of singleton tensors")

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        mu_product = mu_products[i]
        step_t = state_steps[i]
        # update step
        step_t += 1
        step = step_t.item()

        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # calculate the momentum cache \mu^{t} and \mu^{t+1}
        mu = beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))
        mu_next = beta1 * (1. - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))

        # update mu_product in place: it now holds \prod_{i<=t} \mu^{i}
        mu_product *= mu
        # \prod_{i<=t+1} \mu^{i}: mu_product already contains \mu^{t}, so only
        # \mu^{t+1} is multiplied in (multiplying by mu again double-counts it).
        mu_product_next = mu_product * mu_next

        # decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        denom = exp_avg_sq.div(bias_correction2).sqrt().add_(eps)
        param.addcdiv_(grad, denom, value=-lr * (1. - mu) / (1. - mu_product.item()))
        param.addcdiv_(exp_avg, denom, value=-lr * mu_next / (1. - mu_product_next.item()))
| |
| |
def radam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          state_steps: List[Tensor],
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.

    Tensors in ``params``, ``exp_avgs``, ``exp_avg_sqs`` and ``state_steps``
    are updated in place; nothing is returned. The adaptive (rectified) update
    is only used once the approximated SMA length exceeds 5; before that the
    parameter moves by the bias-corrected first moment scaled by ``lr``.

    Raises:
        RuntimeError: if any entry of ``state_steps`` is not a tensor.
    """
    if any(not isinstance(counter, torch.Tensor) for counter in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    for idx, param in enumerate(params):
        g = grads[idx]
        first_moment = exp_avgs[idx]
        second_moment = exp_avg_sqs[idx]
        counter = state_steps[idx]

        counter += 1
        step = counter.item()

        bc1 = 1 - beta1 ** step
        bc2 = 1 - beta2 ** step

        if weight_decay != 0:
            g = g.add(param, alpha=weight_decay)

        first_moment.mul_(beta1).add_(g, alpha=1 - beta1)
        second_moment.mul_(beta2).addcmul_(g, g, value=1 - beta2)

        corrected_first = first_moment / bc1

        # Maximum and current length of the approximated simple moving average.
        sma_max = 2 / (1 - beta2) - 1
        sma_len = sma_max - 2 * step * (beta2 ** step) / bc2

        if sma_len <= 5.:
            # Variance is not yet tractable: plain momentum-style step.
            param.add_(corrected_first * lr, alpha=-1.0)
            continue

        rectifier = math.sqrt(
            (sma_len - 4) * (sma_len - 2) * sma_max
            / ((sma_max - 4) * (sma_max - 2) * sma_len)
        )
        adaptive_lr = math.sqrt(bc2) / second_moment.sqrt().add_(eps)
        param.add_(corrected_first * lr * adaptive_lr * rectifier, alpha=-1.0)
| |
| |
def sparse_adam(params: List[Tensor],
                grads: List[Tensor],
                exp_avgs: List[Tensor],
                exp_avg_sqs: List[Tensor],
                state_steps: List[int],
                *,
                eps: float,
                beta1: float,
                beta2: float,
                lr: float):
    r"""Functional API that performs Sparse Adam algorithm computation.

    See :class:`~torch.optim.SparseAdam` for details.

    Only the entries present in each (coalesced) sparse gradient are updated
    in ``params``, ``exp_avgs`` and ``exp_avg_sqs``, in place. Unlike the other
    functions here, ``state_steps`` holds plain ints that are read for bias
    correction but not incremented.
    """
    def _build_sparse(template, indices, values, shape):
        # Legacy `Tensor.new` constructor on the gradient; an "empty" input
        # yields an empty tensor resized like the template.
        if indices.dim() == 0 or values.dim() == 0:
            return template.new().resize_as_(template)
        return template.new(indices, values, shape)

    for idx, param in enumerate(params):
        grad = grads[idx].coalesce()  # the update is non-linear so indices must be unique
        indices = grad._indices()
        values = grad._values()
        shape = grad.size()

        first_moment = exp_avgs[idx]
        second_moment = exp_avg_sqs[idx]
        step = state_steps[idx]

        # Decay the running averages at the touched entries only, using
        #   old <- b * old + (1 - b) * new  <==>  old += (1 - b) * (new - old)
        prev_m = first_moment.sparse_mask(grad)._values()
        delta_m = values.sub(prev_m).mul_(1 - beta1)
        first_moment.add_(_build_sparse(grad, indices, delta_m, shape))

        prev_v = second_moment.sparse_mask(grad)._values()
        delta_v = values.pow(2).sub_(prev_v).mul_(1 - beta2)
        second_moment.add_(_build_sparse(grad, indices, delta_v, shape))

        # Recover the new averages by dense addition on the value vectors,
        # avoiding another sparse_mask.
        numer = delta_m.add_(prev_m)
        delta_v.add_(prev_v)
        denom = delta_v.sqrt_().add_(eps)
        del delta_m, delta_v

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        param.add_(_build_sparse(grad, indices, -step_size * numer.div_(denom), shape))