import torch
from functools import reduce
from .optimizer import Optimizer


class LBFGS(Optimizer):
    """Implements the L-BFGS algorithm.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires an additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Arguments:
        lr (float): learning rate (default: 1).
        max_iter (int): maximal number of iterations per optimization step
            (default: 20).
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first-order optimality
            (default: 1e-5).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (callable, optional): line search function; currently
            unsupported and must be left as ``None`` (default: None).
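
    Example:
        A minimal usage sketch; ``model``, ``input``, ``target`` and
        ``loss_fn`` are placeholder names, not part of this module. L-BFGS
        re-evaluates the objective several times per step, so :meth:`step`
        must be given a closure::

            optimizer = LBFGS(model.parameters(), lr=1, max_iter=20,
                              history_size=10)

            def closure():
                optimizer.zero_grad()
                loss = loss_fn(model(input), target)
                loss.backward()
                return loss

            optimizer.step(closure)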
| """ |
| |
| def __init__(self, params, lr=1, max_iter=20, max_eval=None, |
| tolerance_grad=1e-5, tolerance_change=1e-9, history_size=100, |
| line_search_fn=None): |
| if max_eval is None: |
| max_eval = max_iter * 5 // 4 |
| defaults = dict(lr=lr, max_iter=max_iter, max_eval=max_eval, |
| tolerance_grad=tolerance_grad, tolerance_change=tolerance_change, |
| history_size=history_size, line_search_fn=line_search_fn) |
| super(LBFGS, self).__init__(params, defaults) |
| |
| if len(self.param_groups) != 1: |
| raise ValueError("LBFGS doesn't support per-parameter options " |
| "(parameter groups)") |
| |
| self._params = self.param_groups[0]['params'] |
| self._numel_cache = None |
| |
| def _numel(self): |
| if self._numel_cache is None: |
| self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) |
| return self._numel_cache |
| |
| def _gather_flat_grad(self): |
| views = [] |
| for p in self._params: |
| if p.grad is None: |
| view = p.data.new(p.data.numel()).zero_() |
| elif p.grad.data.is_sparse: |
| view = p.grad.data.to_dense().view(-1) |
| else: |
| view = p.grad.data.view(-1) |
| views.append(view) |
| return torch.cat(views, 0) |
| |
    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view_as() gives the flat update slice the parameter's shape;
            # the alpha= keyword avoids the deprecated add_(scalar, tensor) overload
            p.data.add_(update[offset:offset + numel].view_as(p.data), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def step(self, closure):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        abs_grad_sum = flat_grad.abs().sum()

        if abs_grad_sum <= tolerance_grad:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                H_diag = 1
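                # first iteration: steepest descent, with the identity as the
                # initial inverse Hessian approximation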
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
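                # only add to the history if the curvature condition y^T s > 0
                # holds, which keeps the inverse Hessian approximation positive
                # definite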
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)

                    # update scale of initial Hessian approximation
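                    # (H_0 = gamma * I with gamma = (s^T y) / (y^T y))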
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'ro' not in state:
                    state['ro'] = [None] * history_size
                    state['al'] = [None] * history_size
                ro = state['ro']
                al = state['al']

                for i in range(num_old):
                    ro[i] = 1. / old_dirs[i].dot(old_stps[i])

                # iteration in L-BFGS loop collapsed to use just one buffer
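                # first loop (newest to oldest): al[i] = ro[i] * s_i^T q,
                # then q -= al[i] * y_i  (old_dirs holds y, old_stps holds s)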
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
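                # second loop (oldest to newest): be_i = ro[i] * y_i^T r,
                # then r += (al[i] - be_i) * s_i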
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone()
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
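                # scale the very first step by 1 / |g|_1 (capped at 1) so the
                # initial move is conservative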
                t = min(1., 1. / abs_grad_sum) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                raise RuntimeError("line search function is not supported yet")
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate the function only if this is not the last
                    # iteration: in a stochastic setting there is no point in
                    # re-evaluating it at the final iterate
                    loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    abs_grad_sum = flat_grad.abs().sum()
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            if abs_grad_sum <= tolerance_grad:
                break

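            # the directional derivative is (almost) non-negative: d is no
            # longer a useful descent direction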
            if gtd > -tolerance_change:
                break

            if d.mul(t).abs_().sum() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss