# blob: 068acf93a2e2144a4b39e0ee8921ca0a486e605a [file] [log] [blame]
import torch
from .Criterion import Criterion
class HingeEmbeddingCriterion(Criterion):
    """Hinge embedding loss between an input tensor and +1 / -1 labels.

    For each element the loss is ``input`` where ``y == 1`` and
    ``max(0, margin - input)`` where ``y == -1``. The per-element losses
    are summed and, when ``sizeAverage`` is true, divided by the number of
    elements in ``input``.

    Labels are compared for exact equality with ``1`` and ``-1``; an
    element carrying any other label is masked out of neither term and so
    contributes both.

    Args:
        margin: Hinge threshold applied to elements labelled ``-1``.
        sizeAverage: If true, the loss and its gradient are divided by
            ``input.nelement()``.
    """

    def __init__(self, margin=1, sizeAverage=True):
        super(HingeEmbeddingCriterion, self).__init__()
        self.margin = margin
        self.sizeAverage = sizeAverage
        # Scratch tensor, allocated lazily on the first updateOutput() call
        # so that it matches the type of the input it is given.
        self.buffer = None

    def updateOutput(self, input, y):
        """Compute the loss, store it in ``self.output`` and return it.

        Args:
            input: Tensor of values to score.
            y: Tensor of labels, indexable against ``input`` element-wise.

        Returns:
            The loss as a Python number.
        """
        if self.buffer is None:
            self.buffer = input.new()

        # Term for y == 1: the input itself (entries labelled -1 zeroed).
        self.buffer.resize_as_(input).copy_(input)
        self.buffer[torch.eq(y, -1.)] = 0
        self.output = self.buffer.sum().item()

        # Term for y == -1: max(0, margin - input) (entries labelled 1
        # zeroed). `fill_(margin).sub_(input)` replaces the deprecated
        # positional-scalar form `add_(-1, input)` with identical results.
        self.buffer.fill_(self.margin).sub_(input)
        self.buffer.clamp_(min=0)
        self.buffer[torch.eq(y, 1.)] = 0
        self.output = self.output + self.buffer.sum().item()

        if self.sizeAverage:
            self.output = self.output / input.nelement()

        return self.output

    def updateGradInput(self, input, y):
        """Compute the gradient of the loss with respect to ``input``.

        The gradient equals ``y`` element-wise, except that it is zero
        where ``y == -1`` and ``input > margin`` (the hinge is inactive
        there). It is scaled by ``1 / input.nelement()`` when
        ``sizeAverage`` is true.

        ``self.gradInput`` is not created in this class; it is expected to
        be provided by the ``Criterion`` base class.

        Args:
            input: Tensor of values that was scored.
            y: Tensor of labels matching ``input``.

        Returns:
            ``self.gradInput``, resized to match ``input``.
        """
        self.gradInput.resize_as_(input).copy_(y)

        # Combine the two comparison masks logically rather than by
        # arithmetic multiplication of mask tensors.
        inactive_hinge = torch.eq(y, -1) & torch.gt(input, self.margin)
        self.gradInput[inactive_hinge] = 0

        if self.sizeAverage:
            self.gradInput.mul_(1. / input.nelement())

        return self.gradInput