# blob: 455b32cf928b1ee5ad06edea968a4d84086bad9c [file] [log] [blame]
import torch
from .Criterion import Criterion
from .utils import recursiveResizeAs, recursiveFill, recursiveAdd
class MultiCriterion(Criterion):
    """Combine several criterions into one weighted sum.

    Every registered criterion is evaluated on the same ``(input, target)``
    pair.  ``self.criterions`` and ``self.weights`` are parallel containers:
    the weight stored at index ``i`` scales the criterion stored at index
    ``i``.
    """

    def __init__(self):
        super(MultiCriterion, self).__init__()
        # Registered sub-criterions, in insertion order.
        self.criterions = []
        # One double-precision weight per registered criterion.
        self.weights = torch.DoubleStorage()

    def add(self, criterion, weight=1):
        """Register ``criterion`` scaled by ``weight``.

        Returns ``self`` so that calls can be chained.
        """
        self.criterions.append(criterion)
        count = len(self.criterions)

        # The storage is rebuilt one slot larger on every call: existing
        # weights are copied over and the new weight goes in the last slot.
        grown = torch.DoubleStorage(count)
        for slot, previous in enumerate(self.weights):
            grown[slot] = previous
        grown[count - 1] = weight

        self.weights = grown
        return self

    def updateOutput(self, input, target):
        """Compute, store in ``self.output`` and return the weighted sum of
        each sub-criterion's ``updateOutput(input, target)``.

        With no registered criterions the result is ``0``.
        """
        self.output = 0
        for idx, criterion in enumerate(self.criterions):
            weighted = self.weights[idx] * criterion.updateOutput(input, target)
            self.output = self.output + weighted
        return self.output

    def updateGradInput(self, input, target):
        """Accumulate the sub-criterions' gradients into ``self.gradInput``
        and return it.

        NOTE(review): ``self.gradInput`` is read here but never assigned in
        ``__init__``; presumably the ``Criterion`` base class creates it --
        confirm.
        """
        # Presumably reshapes the gradient buffer(s) to mirror ``input``;
        # only the first element of the helper's return value is kept.
        resized = recursiveResizeAs(self.gradInput, input)
        self.gradInput = resized[0]

        # Reset before accumulation (helper is called with a fill value of 0).
        recursiveFill(self.gradInput, 0)

        for idx, criterion in enumerate(self.criterions):
            # Presumably adds weight * gradient into self.gradInput in place;
            # the helper's return value is not used.
            recursiveAdd(
                self.gradInput,
                self.weights[idx],
                criterion.updateGradInput(input, target),
            )
        return self.gradInput

    def type(self, type):
        """Forward ``type`` to every sub-criterion, then to the base class.

        Returns whatever the base-class ``type`` call returns.
        """
        for member in self.criterions:
            member.type(type)
        return super(MultiCriterion, self).type(type)