blob: 7ef67ce16f326acefd0b6089f39d96401ad3d5a9 [file] [log] [blame]
import torch
from .Criterion import Criterion
class DistKLDivCriterion(Criterion):
    """KL-divergence criterion (per the class and backend names) computed by a native kernel.

    All numerics are delegated to ``self._backend.DistKLDivCriterion_updateOutput``
    and ``self._backend.DistKLDivCriterion_updateGradInput``. This class only
    checks that ``input`` and ``target`` have the same size and manages the
    output/gradient buffers handed to those kernels.

    NOTE(review): by analogy with Lua torch ``nn``, ``input`` is presumably
    expected to hold log-probabilities and ``target`` probabilities. This is
    not verifiable from this file; confirm against the backend implementation.

    Args:
        sizeAverage: forwarded unchanged to both backend calls. By its name it
            presumably selects averaging (True) vs. summing (False) of the
            per-element loss; confirm against the backend.
    """

    def __init__(self, sizeAverage=True):
        super(DistKLDivCriterion, self).__init__()
        self.sizeAverage = sizeAverage
        # One-element buffer the backend writes the scalar loss into
        # (updateOutput reads back element 0).
        self.output_tensor = torch.Tensor(1)

    def updateOutput(self, input, target):
        """Compute the loss of ``input`` against ``target``.

        Args:
            input: tensor passed to the backend as the prediction.
            target: tensor of the same size as ``input``.

        Returns:
            Element 0 of ``self.output_tensor`` after the backend call. The
            same value is also stored in ``self.output``.

        Raises:
            AssertionError: if ``input`` and ``target`` differ in size.
        """
        assert input.isSameSizeAs(target), \
            "DistKLDivCriterion: input and target must have the same size"
        # Explicit None check rather than ``self.output_tensor or ...``.
        # A tensor's truth value is not a reliable "is allocated" test:
        # depending on the torch version it may raise, or depend on the
        # value currently stored in the buffer.
        if self.output_tensor is None:
            self.output_tensor = input.new(1)
        self._backend.DistKLDivCriterion_updateOutput(
            self._backend.library_state,
            input,
            target,
            self.output_tensor,
            self.sizeAverage
        )
        self.output = self.output_tensor[0]
        return self.output

    def updateGradInput(self, input, target):
        """Compute the gradient of the loss with respect to ``input``.

        ``self.gradInput`` (presumably set up by the ``Criterion`` base class,
        which is not visible here) is passed to the backend as the output
        buffer and then returned.

        Args:
            input: tensor passed to the backend as the prediction.
            target: tensor of the same size as ``input``.

        Returns:
            ``self.gradInput`` after the backend call.

        Raises:
            AssertionError: if ``input`` and ``target`` differ in size.
        """
        assert input.isSameSizeAs(target), \
            "DistKLDivCriterion: input and target must have the same size"
        self._backend.DistKLDivCriterion_updateGradInput(
            self._backend.library_state,
            input,
            target,
            self.gradInput,
            self.sizeAverage
        )
        return self.gradInput