blob: 59b2b29b861b17ca0cc0a1c7c3b8819dd0fa3b4b [file] [log] [blame]
import torch
from .Criterion import Criterion
from .Sigmoid import Sigmoid
from .BCECriterion import BCECriterion
class MultiLabelSoftMarginCriterion(Criterion):
"""
A MultiLabel multiclass criterion based on sigmoid:
the loss is:
l(x, y) = - sum_i y[i] * log(p[i]) + (1 - y[i]) * log (1 - p[i])
where p[i] = exp(x[i]) / (1 + exp(x[i]))
and with weights:
l(x, y) = - sum_i weights[i] (y[i] * log(p[i]) + (1 - y[i]) * log (1 - p[i]))
"""
def __init__(self, weights=None):
super(MultiLabelSoftMarginCriterion, self).__init__()
self.lsm = Sigmoid()
self.nll = BCECriterion(weights)
def updateOutput(self, input, target):
input = input if input.nelement() == 1 else input.squeeze()
target = target if target.nelement() == 1 else target.squeeze()
self.lsm.updateOutput(input)
self.nll.updateOutput(self.lsm.output, target)
self.output = self.nll.output
return self.output
def updateGradInput(self, input, target):
size = input.size()
input = input if input.nelement() == 1 else input.squeeze()
target = target if target.nelement() == 1 else target.squeeze()
self.nll.updateGradInput(self.lsm.output, target)
self.lsm.updateGradInput(input, self.nll.gradInput)
self.gradInput = self.lsm.gradInput.view(size)
return self.gradInput