import torch
def lbfgs(opfunc, x, config=None, state=None):
"""
An implementation of L-BFGS, heavily inspired by minFunc (Mark Schmidt)
This implementation of L-BFGS relies on a user-provided line
search function (state.lineSearch). If this function is not
provided, then a simple learningRate is used to produce fixed
size steps. Fixed size steps are much less costly than line
searches, and can be useful for stochastic problems.
The learning rate is used even when a line search is provided.
This is also useful for large-scale stochastic problems, where
opfunc is a noisy approximation of f(x). In that case, the learning
rate allows a reduction of confidence in the step size.
Args:
- `opfunc` : a function that takes a single input (X), the point of
evaluation, and returns f(X) and df/dX
- `x` : the initial point
- `state` : a table describing the state of the optimizer; after each
call the state is modified
- `state.maxIter` : Maximum number of iterations allowed
- `state.maxEval` : Maximum number of function evaluations
- `state.tolFun` : Termination tolerance on the first-order optimality
- `state.tolX` : Termination tol on progress in terms of func/param changes
- `state.lineSearch` : A line search function
- `state.learningRate` : If no line search provided, then a fixed step size is used
Returns:
- `x*` : the new `x` vector, at the optimal point
- `f` : a table of all function values:
`f[1]` is the value of the function before any optimization and
`f[#f]` is the final fully optimized value, at `x*`
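
    Example:
        A minimal usage sketch, assuming a flat 1-D parameter tensor and a
        hypothetical quadratic objective (illustrative only, not part of
        this module):

        >>> import torch
        >>> def opfunc(x):
        ...     return (x * x).sum(), 2 * x  # f(x) = ||x||^2, df/dx = 2x
        >>> x0 = torch.ones(5)
        >>> x_star, f_hist, n_evals = lbfgs(opfunc, x0, {'maxIter': 10})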
(Clement Farabet, 2012)
"""
# (0) get/update state
    if config is None and state is None:
        raise ValueError("lbfgs requires a dictionary to retain state between iterations")
    config = config if config is not None else {}
    state = state if state is not None else config
maxIter = config.get('maxIter', 20)
maxEval = config.get('maxEval', maxIter * 1.25)
tolFun = config.get('tolFun', 1e-5)
tolX = config.get('tolX', 1e-9)
nCorrection = config.get('nCorrection', 100)
lineSearch = config.get('lineSearch')
lineSearchOptions = config.get('lineSearchOptions')
learningRate = config.get('learningRate', 1)
isverbose = config.get('verbose', False)
state.setdefault('funcEval', 0)
state.setdefault('nIter', 0)
# verbose function
if isverbose:
        def verbose(*args):
            print('<optim.lbfgs>', *args)
else:
def verbose(*args):
pass
# evaluate initial f(x) and df/dx
f, g = opfunc(x)
f_hist = [f]
currentFuncEval = 1
state['funcEval'] += 1
p = g.size(0)
# check optimality of initial point
if 'tmp1' not in state:
        state['tmp1'] = torch.zeros_like(g)
tmp1 = state['tmp1']
tmp1.copy_(g).abs_()
if tmp1.sum() <= tolFun:
verbose('optimality condition below tolFun')
        return x, f_hist, currentFuncEval
if 'dir_bufs' not in state:
# reusable buffers for y's and s's, and their histories
verbose('creating recyclable direction/step/history buffers')
        state['dir_bufs'] = [g.new_empty(p) for _ in range(nCorrection + 1)]
        state['stp_bufs'] = [g.new_empty(p) for _ in range(nCorrection + 1)]
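        # the pool holds nCorrection + 1 pairs: one extra, so a fresh (y, s)
        # pair can be popped before the oldest pair is recycled back in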
# variables cached in state (for tracing)
d = state.get('d')
t = state.get('t')
old_dirs = state.get('old_dirs')
old_stps = state.get('old_stps')
Hdiag = state.get('Hdiag')
g_old = state.get('g_old')
f_old = state.get('f_old')
# optimize for a max of maxIter iterations
nIter = 0
while nIter < maxIter:
        # keep track of the number of iterations
nIter += 1
state['nIter'] += 1
############################################################
# compute gradient descent direction
############################################################
if state['nIter'] == 1:
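            # no curvature information yet: fall back to steepest descent,
            # i.e. the two-loop recursion with an empty history and H0 = I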
d = g.neg()
old_dirs = []
old_stps = []
Hdiag = 1
else:
# do lbfgs update (update memory)
y = state['dir_bufs'].pop()
s = state['stp_bufs'].pop()
            torch.add(g, g_old, alpha=-1, out=y)  # y = g - g_old
            torch.mul(d, t, out=s)  # s = t * d, the step just taken
ys = y.dot(s) # y*s
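            # curvature condition: the (y, s) pair is only stored when
            # y^T s is sufficiently positive, which keeps the implicit
            # inverse Hessian approximation positive definite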
if ys > 1e-10:
# updating memory
if len(old_dirs) == nCorrection:
# shift history by one (limited-memory)
state['dir_bufs'].append(old_dirs.pop(0))
state['stp_bufs'].append(old_stps.pop(0))
# store new direction/step
old_dirs.append(s)
old_stps.append(y)
# update scale of initial Hessian approximation
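                # Hdiag = (y^T s) / (y^T y) is the standard gamma_k scaling
                # of the initial inverse Hessian approximation H0 = Hdiag * I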
Hdiag = ys / y.dot(y) # (y*y)
else:
# put y and s back into the buffer pool
state['dir_bufs'].append(y)
state['stp_bufs'].append(s)
# compute the approximate (L-BFGS) inverse Hessian
# multiplied by the gradient
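            # classic two-loop recursion: the backward loop computes
            # al[i] = ro[i] * s_i^T q and updates q -= al[i] * y_i; the
            # forward loop then scales by H0 = Hdiag * I and adds
            # (al[i] - be_i) * s_i, with be_i = ro[i] * y_i^T r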
k = len(old_dirs)
# need to be accessed element-by-element, so don't re-type tensor:
if 'ro' not in state:
                state['ro'] = torch.empty(nCorrection)
ro = state['ro']
for i in range(k):
ro[i] = 1 / old_stps[i].dot(old_dirs[i])
# iteration in L-BFGS loop collapsed to use just one buffer
q = tmp1 # reuse tmp1 for the q buffer
# need to be accessed element-by-element, so don't re-type tensor:
if 'al' not in state:
state['al'] = torch.zeros(nCorrection)
al = state['al']
            torch.mul(g, -1, out=q)  # q = -g
for i in range(k - 1, -1, -1):
al[i] = old_dirs[i].dot(q) * ro[i]
                q.add_(old_stps[i], alpha=-al[i])
# multiply by initial Hessian
r = d # share the same buffer, since we don't need the old d
torch.mul(q, Hdiag, out=r)
for i in range(k):
be_i = old_stps[i].dot(r) * ro[i]
                r.add_(old_dirs[i], alpha=al[i] - be_i)
# final direction is in r/d (same object)
if g_old is None:
g_old = g.clone()
else:
g_old.copy_(g)
f_old = f
############################################################
# compute step length
############################################################
# directional derivative
gtd = g.dot(d) # g * d
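        # gtd serves two purposes: it seeds the line search, and the
        # progress check below requires g^T d < -tolX (a descent direction)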
# reset initial guess for step size
if state['nIter'] == 1:
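            # first-step heuristic: bound the initial step by 1 / ||g||_1
            # so a huge gradient does not produce a huge first step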
tmp1.copy_(g).abs_()
t = min(1, 1 / tmp1.sum()) * learningRate
else:
t = learningRate
# optional line search: user function
lsFuncEval = 0
if lineSearch is not None:
# perform line search, using user function
            f, g, x, t, lsFuncEval = lineSearch(opfunc, x, t, d, f, g, gtd, lineSearchOptions)
f_hist.append(f)
else:
            # no line search: take a fixed-size step
            x.add_(d, alpha=t)
            if nIter != maxIter:
                # re-evaluate the function only if this is not the last
                # iteration; in a stochastic setting there is no point in
                # re-evaluating a noisy objective on the way out
                f, g = opfunc(x)
lsFuncEval = 1
f_hist.append(f)
# update func eval
currentFuncEval += lsFuncEval
state['funcEval'] += lsFuncEval
############################################################
# check conditions
############################################################
if nIter == maxIter:
# no use to run tests
verbose('reached max number of iterations')
break
if currentFuncEval >= maxEval:
            # reached the maximum number of function evaluations
            verbose('max number of function evaluations reached')
break
tmp1.copy_(g).abs_()
if tmp1.sum() <= tolFun:
# check optimality
verbose('optimality condition below tolFun')
break
        # check that progress can be made along the chosen direction
        if gtd > -tolX:
            verbose('directional derivative below tolX')
            break
tmp1.copy_(d).mul_(t).abs_()
if tmp1.sum() <= tolX:
# step size below tolX
verbose('step size below tolX')
break
if abs(f - f_old) < tolX:
# function value changing less than tolX
verbose('function value changing less than tolX')
break
# save state
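    # cache loop variables so a subsequent call with the same state dict
    # resumes the L-BFGS memory instead of restarting from scratch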
state['old_dirs'] = old_dirs
state['old_stps'] = old_stps
state['Hdiag'] = Hdiag
state['g_old'] = g_old
state['f_old'] = f_old
state['t'] = t
state['d'] = d
# return optimal x, and history of f(x)
return x, f_hist, currentFuncEval