|  | import torch | 
|  | from .Criterion import Criterion | 
|  | from .LogSoftMax import LogSoftMax | 
|  | from .ClassNLLCriterion import ClassNLLCriterion | 
|  |  | 
|  |  | 
class CrossEntropyCriterion(Criterion):
    """Cross-entropy loss composed of a LogSoftMax followed by ClassNLLCriterion.

    Forward: the squeezed input goes through ``LogSoftMax``. Its output and the
    squeezed target are then passed to ``ClassNLLCriterion``, whose output
    becomes this criterion's ``output``.

    Backward: the same two sub-modules are run in reverse order. The resulting
    gradient is reshaped back to the caller's original input size.
    """

    def __init__(self, weights=None):
        """Build the LogSoftMax / ClassNLLCriterion pair.

        Args:
            weights: passed through unchanged to ``ClassNLLCriterion``.
                Presumably optional per-class weights (``None`` by default);
                TODO confirm against ClassNLLCriterion.
        """
        super(CrossEntropyCriterion, self).__init__()
        self.lsm = LogSoftMax()
        self.nll = ClassNLLCriterion(weights)

    def updateOutput(self, input, target):
        """Compute the loss, store it in ``self.output`` and return it.

        ``squeeze()`` removes every size-1 dimension from both ``input`` and
        ``target`` before they reach the sub-modules.
        NOTE(review): this also drops a size-1 batch or class dimension.
        Confirm that ClassNLLCriterion accepts the resulting shapes.
        """
        flat_input = input.squeeze()
        flat_target = target.squeeze()
        log_softmax, nll = self.lsm, self.nll

        log_softmax.updateOutput(flat_input)
        nll.updateOutput(log_softmax.output, flat_target)

        self.output = nll.output
        return self.output

    def updateGradInput(self, input, target):
        """Back-propagate the loss, store the gradient in ``self.gradInput`` and return it.

        This reads ``self.lsm.output`` rather than recomputing it. It therefore
        assumes ``updateOutput`` was called with the same input beforehand.
        The returned gradient is viewed back to ``input.size()``, undoing the
        squeeze applied to the input.
        """
        original_size = input.size()
        flat_input = input.squeeze()
        flat_target = target.squeeze()
        log_softmax, nll = self.lsm, self.nll

        # Gradient of the NLL term with respect to the cached log-probabilities.
        nll.updateGradInput(log_softmax.output, flat_target)
        # Chain that gradient back through the LogSoftMax.
        log_softmax.updateGradInput(flat_input, nll.gradInput)

        # Restore the caller's shape, including any squeezed size-1 dims.
        self.gradInput = log_softmax.gradInput.view(original_size)
        return self.gradInput