import torch
from functools import reduce
from .optimizer import Optimizer


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
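    """Find the minimizer of a cubic fit between two step sizes.

    Given points ``x1`` and ``x2`` with function values ``f1``, ``f2`` and
    directional derivatives ``g1``, ``g2``, fits the interpolating cubic and
    returns its minimizer clamped to ``bounds`` (or to ``[min(x1, x2),
    max(x1, x2)]`` when no bounds are given). Falls back to bisection when
    the cubic has no real minimizer.
    """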
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1**2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.


def _strong_wolfe(obj_func,
                  x,
                  t,
                  d,
                  f,
                  g,
                  gtd,
                  c1=1e-4,
                  c2=0.9,
                  tolerance_change=1e-9,
                  max_ls=25):
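    """Line search satisfying the strong Wolfe conditions.

    Starting from step size ``t`` along direction ``d`` (with initial loss
    ``f``, flat gradient ``g`` and directional derivative ``gtd``), first
    brackets an interval containing an acceptable point, then zooms in on it
    with cubic interpolation. ``c1`` and ``c2`` are the sufficient-decrease
    and curvature constants. Returns ``(f_new, g_new, t, ls_func_evals)``.
    """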
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev,
            f_prev,
            gtd_prev,
            t,
            f_new,
            gtd_new,
            bounds=(min_step, max_step))

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]
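        # NOTE: `bracket_gtd` is deliberately left unset here; the zoom loop
        # below is guarded by `ls_iter < max_ls`, so it never runs (and never
        # reads `bracket_gtd`) once the evaluation budget is exhausted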

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
    while not done and ls_iter < max_ls:
        # compute new trial value
        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                               bracket[1], bracket_f[1], bracket_gtd[1])

        # test that we are making sufficient progress:
        # if `t` is too close to a bracket boundary, we mark the step as
        # making insufficient progress, and if
        #   + we already made insufficient progress in the last step, or
        #   + `t` lies at one of the boundaries,
        # we move `t` to a position that is `0.1 * (bracket width)` away
        # from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[low_pos] = gtd_new
        # bracket has shrunk below the minimum allowed change; stop refining
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
            break

    # return the best point found
    t = bracket[low_pos]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals


class LBFGS(Optimizer):
    """Implements L-BFGS algorithm, heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.
    Arguments:
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
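
    Example:
        >>> # `model`, `loss_fn`, `input`, and `target` below stand in for a
        >>> # typical closure-based training step
        >>> optimizer = torch.optim.LBFGS(model.parameters(), lr=1,
        ...                               line_search_fn='strong_wolfe')
        >>> def closure():
        ...     optimizer.zero_grad()
        ...     loss = loss_fn(model(input), target)
        ...     loss.backward()
        ...     return loss
        >>> optimizer.step(closure)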
|  | """ | 

    def __init__(self,
                 params,
                 lr=1,
                 max_iter=20,
                 max_eval=None,
                 tolerance_grad=1e-7,
                 tolerance_change=1e-9,
                 history_size=100,
                 line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn)
        super(LBFGS, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None
    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
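        """Concatenate the gradients of all parameters into one flat vector.

        Parameters without a gradient contribute zeros; sparse gradients
        are densified first.
        """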
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
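        """Update every parameter in place with its slice of the flat
        ``update`` vector: ``p += step_size * update[slice]``."""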
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
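        """Evaluate the closure at ``x + t * d`` (the parameters are assumed
        to currently hold the values in ``x``), returning ``(loss,
        flat_grad)`` and restoring the parameters to ``x`` afterwards."""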
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    @torch.no_grad()
    def step(self, closure):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        ro = state.get('ro')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
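                # only update the history when the curvature condition
                # y^T s > 0 holds (up to a small threshold); this keeps the
                # inverse Hessian approximation positive definite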
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1. / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'al' not in state:
                    state['al'] = [None] * history_size
                al = state['al']

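                # two-loop recursion: the backward pass computes
                # al[i] = ro[i] * s_i^T q and subtracts al[i] * y_i from q;
                # after scaling by the initial Hessian approximation H_diag,
                # the forward pass adds (al[i] - be_i) * s_i back in
                # (old_stps holds the s_i, old_dirs the y_i)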
                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd)
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed step size
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # (in a stochastic setting there is no point
                    # re-evaluating the function here)
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['ro'] = ro
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss