blob: 435c3e15f86c0a17cfab3c29b2e3ff2875c48560 [file] [log] [blame]
class Optimizer(object):
    """Base class for optimizers that drive a single model.

    Subclasses implement ``step``. They can use ``_forward_backward`` to
    zero the model's gradients, run a user-supplied forward closure and
    backpropagate the resulting loss.
    """

    def __init__(self, model):
        """Bind the optimizer to ``model``.

        Args:
            model: object exposing ``parameters()`` and ``zero_grad()``
                (presumably a ``torch.nn.Module`` -- confirm with callers).
        """
        self.model = model
        # Snapshot of model.parameters() materialized into a list, so it can
        # be iterated repeatedly (presumably by subclass step() logic).
        self.parameters = list(self.model.parameters())

    def _forward_backward(self, forward_closure):
        """Zero gradients, evaluate the loss and backpropagate it.

        Args:
            forward_closure: zero-argument callable returning a loss object
                that supports ``backward()`` and exposes ``.data``.

        Returns:
            The loss value read from ``loss.data``.
        """
        # Clear gradients left over from a previous step before computing
        # new ones.
        self.model.zero_grad()
        loss = forward_closure()
        loss.backward()
        data = loss.data
        # Newer PyTorch returns 0-dim tensors for reduced losses, and
        # indexing them with [0] raises IndexError; .item() is the supported
        # accessor there. Every other case keeps the original data[0] path.
        if hasattr(data, 'item') and data.dim() == 0:
            return data.item()
        return data[0]

    def step(self, forward_closure):
        """Perform one optimization step; must be overridden by subclasses.

        Args:
            forward_closure: zero-argument callable returning the loss.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError