blob: b68c7444f321f723601241623e01ffe3ec7eddbb [file] [log] [blame]
import torch
from .Criterion import Criterion
class MarginRankingCriterion(Criterion):
    """Margin ranking loss over a pair of inputs ``[x1, x2]`` and a target ``y``.

    The per-element loss is ``max(0, -y * (x1 - x2) + margin)``.

    Two paths:

    * When ``input[0].size(0) == 1``, the loss is evaluated directly on the
      first elements.
    * Otherwise, the per-element losses are summed. When ``sizeAverage`` is
      true, the sum is divided by ``y.size(0)``.

    NOTE(review): ``y`` is presumably +1 / -1 (x1 should rank above / below
    x2); the code itself only multiplies by it — confirm against callers.
    """

    def __init__(self, margin=1, sizeAverage=True):
        super(MarginRankingCriterion, self).__init__()
        self.margin = margin
        self.sizeAverage = sizeAverage
        # One gradient buffer per input; both start empty and are resized
        # in updateGradInput before being written.
        self.gradInput = [torch.Tensor(), torch.Tensor()]
        # Scratch buffers, allocated lazily and reused across calls.
        self._output = None
        self.dist = None
        self.mask = None

    def updateOutput(self, input, y):
        """Compute the loss for ``input = [x1, x2]`` and target ``y``.

        Returns the loss and also stores it in ``self.output``.
        """
        if input[0].size(0) == 1:
            # Single-element path: evaluate the hinge directly.
            self.output = max(0, -y * (input[0][0] - input[1][0]) + self.margin)
        else:
            if self._output is None:
                self._output = input[0].clone()
            # hinge = max(0, -y * (x1 - x2) + margin), computed in place.
            self._output.resize_as_(input[0])
            self._output.copy_(input[0])
            # sub_ replaces the deprecated add_(-1, other) overload.
            self._output.sub_(input[1])
            self._output.mul_(-1).mul_(y)
            self._output.add_(self.margin)
            self._output.clamp_(min=0)
            self.output = self._output.sum()
            if self.sizeAverage:
                self.output = self.output / y.size(0)
        return self.output

    def updateGradInput(self, input, y):
        """Compute gradients of the loss with respect to ``x1`` and ``x2``.

        Where the hinge argument ``-y * (x1 - x2) + margin`` is >= 0, the
        gradient is ``-y`` for ``x1`` and ``y`` for ``x2``; elsewhere it is 0.
        In the batched path, both gradients are divided by ``y.size(0)`` when
        ``sizeAverage`` is true.

        Returns ``self.gradInput``, a list ``[grad_x1, grad_x2]``.
        """
        if input[0].size(0) == 1:
            dist = -y * (input[0][0] - input[1][0]) + self.margin
            # The buffers are created empty in __init__; size them before the
            # indexed writes below, otherwise indexing element 0 fails.
            self.gradInput[0].resize_(input[0].size())
            self.gradInput[1].resize_(input[1].size())
            if dist < 0:
                self.gradInput[0][0] = 0
                self.gradInput[1][0] = 0
            else:
                self.gradInput[0][0] = -y
                self.gradInput[1][0] = y
        else:
            if self.dist is None:
                self.dist = input[0].new()
            self.dist = self.dist.resize_as_(input[0]).copy_(input[0])
            dist = self.dist
            # dist = -y * (x1 - x2) + margin, computed in place.
            dist.sub_(input[1])
            dist.mul_(-1).mul_(y)
            dist.add_(self.margin)

            if self.mask is None:
                self.mask = input[0].new()
            self.mask = self.mask.resize_as_(input[0])
            mask = self.mask
            # mask is 1 where the hinge is active (dist >= 0), 0 elsewhere;
            # torch.ge overwrites every element, so no prior fill is needed.
            torch.ge(dist, 0, out=mask)

            self.gradInput[0].resize_(dist.size())
            self.gradInput[1].resize_(dist.size())

            self.gradInput[0].copy_(mask)
            self.gradInput[0].mul_(-1).mul_(y)

            self.gradInput[1].copy_(mask)
            self.gradInput[1].mul_(y)

            if self.sizeAverage:
                self.gradInput[0].div_(y.size(0))
                self.gradInput[1].div_(y.size(0))
        return self.gradInput