blob: 67e8b0d9ab9d4825b3fc4229638efde6d49e6fea [file] [log] [blame]
import torch
from .Criterion import Criterion
from .LogSoftMax import LogSoftMax
from .ClassNLLCriterion import ClassNLLCriterion
class CrossEntropyCriterion(Criterion):
def __init__(self, weights=None):
super(CrossEntropyCriterion, self).__init__()
self.lsm = LogSoftMax()
self.nll = ClassNLLCriterion(weights)
def updateOutput(self, input, target):
input = input.squeeze()
target = target.squeeze()
self.lsm.updateOutput(input)
self.nll.updateOutput(self.lsm.output, target)
self.output = self.nll.output
return self.output
def updateGradInput(self, input, target):
size = input.size()
input = input.squeeze()
target = target.squeeze()
self.nll.updateGradInput(self.lsm.output, target)
self.lsm.updateGradInput(input, self.nll.gradInput)
self.gradInput = self.lsm.gradInput.view(size)
return self.gradInput