# blob: 7ecfd95c6b540a5d8ee75b59a70de0d178052a4b [file] [log] [blame]
import torch
from .Criterion import Criterion
from .utils import recursiveResizeAs, recursiveFill, recursiveAdd
class ParallelCriterion(Criterion):
    """Weighted sum of several criterions, each fed its own slice of the input.

    The criterion registered at position ``i`` is evaluated on ``input[i]``.
    Its target is ``target[i]``, unless ``repeatTarget`` is true, in which
    case every criterion receives the whole ``target`` object unchanged.
    """

    def __init__(self, repeatTarget=False):
        """Create an empty container.

        Args:
            repeatTarget: when true, pass the full ``target`` to every
                criterion instead of indexing it per criterion.
        """
        super(ParallelCriterion, self).__init__()
        self.repeatTarget = repeatTarget
        # criterions[i] is paired with weights[i]; both grow only via add().
        self.criterions = []
        self.weights = []
        # Rebuilt on each updateGradInput() call to mirror the input.
        self.gradInput = []

    def add(self, criterion, weight=1):
        """Register ``criterion`` with the given ``weight``.

        Returns:
            ``self``, so that calls can be chained.
        """
        self.criterions.append(criterion)
        self.weights.append(weight)
        return self

    def updateOutput(self, input, target):
        """Accumulate the weighted outputs of all registered criterions.

        Args:
            input: indexable; ``input[i]`` is handed to criterion ``i``.
            target: indexable per criterion, or (with ``repeatTarget``)
                a single object shared by all criterions.

        Returns:
            The accumulated value, which is also stored in ``self.output``.
        """
        self.output = 0
        for idx, crit in enumerate(self.criterions):
            if self.repeatTarget:
                crit_target = target
            else:
                crit_target = target[idx]
            weight = self.weights[idx]
            self.output += weight * crit.updateOutput(input[idx], crit_target)
        return self.output

    def updateGradInput(self, input, target):
        """Compute the weighted gradient of each criterion w.r.t. its input.

        ``self.gradInput`` is first replaced by the first element of what
        ``recursiveResizeAs(self.gradInput, input)`` returns and filled with
        0 via ``recursiveFill``; each criterion's gradient is then added
        into ``self.gradInput[i]`` with its weight via ``recursiveAdd``.

        Args:
            input: indexable; ``input[i]`` is handed to criterion ``i``.
            target: indexable per criterion, or (with ``repeatTarget``)
                a single object shared by all criterions.

        Returns:
            ``self.gradInput``.
        """
        resized = recursiveResizeAs(self.gradInput, input)
        self.gradInput = resized[0]
        recursiveFill(self.gradInput, 0)

        for idx, crit in enumerate(self.criterions):
            if self.repeatTarget:
                crit_target = target
            else:
                crit_target = target[idx]
            # Bind the destination and weight before calling the criterion
            # so the evaluation order stays destination, weight, gradient.
            grad_slot = self.gradInput[idx]
            weight = self.weights[idx]
            crit_grad = crit.updateGradInput(input[idx], crit_target)
            recursiveAdd(grad_slot, weight, crit_grad)

        return self.gradInput

    def type(self, type=None, tensorCache=None):
        """Reset ``gradInput`` to an empty list, then defer to ``Criterion.type``.

        Returns:
            Whatever the parent class's ``type`` returns.
        """
        self.gradInput = []
        parent = super(ParallelCriterion, self)
        return parent.type(type, tensorCache)