blob: db2402b04ff8468c3a36e69ebd1dc21ed6e49917 [file] [log] [blame]
import torch
from .Module import Module
class CriterionTable(Module):
    """Adapt a criterion so it can be used as a module over a table of inputs.

    The wrapped criterion's ``updateOutput`` / ``updateGradInput`` are called
    with the elements of ``input`` unpacked as positional arguments, so
    ``input`` must be a sequence (list, tuple, ...) whose items are the
    criterion's arguments -- presumably ``(prediction, target)``; confirm
    against the wrapped criterion.

    ``gradInput`` is a one-element list holding the criterion's
    ``gradInput``; it does not contain one entry per element of ``input``.
    """

    def __init__(self, criterion):
        """Wrap ``criterion``.

        Args:
            criterion: object exposing ``updateOutput(*args)``,
                ``updateGradInput(*args)`` and a ``gradInput`` attribute.
        """
        super(CriterionTable, self).__init__()
        self.criterion = criterion
        self.gradInput = [criterion.gradInput]

    def updateOutput(self, input):
        """Evaluate the wrapped criterion on the unpacked ``input`` table.

        Args:
            input: sequence whose items are passed positionally to
                ``criterion.updateOutput``.

        Returns:
            The value returned by ``criterion.updateOutput``; it is also
            stored in ``self.output``.
        """
        self.output = self.criterion.updateOutput(*input)
        return self.output

    def updateGradInput(self, input, gradOutput=None):
        """Run the wrapped criterion's backward pass on the unpacked ``input``.

        Args:
            input: sequence whose items are passed positionally to
                ``criterion.updateGradInput``.
            gradOutput: ignored. Accepted so the method matches the
                two-argument ``updateGradInput(input, gradOutput)`` form
                used by the module API (NOTE(review): confirm against
                ``Module.backward``); defaults to ``None`` so existing
                single-argument callers keep working.

        Returns:
            ``self.gradInput``: a one-element list holding the criterion's
            ``gradInput``.
        """
        self.criterion.updateGradInput(*input)
        # Re-read the criterion's gradInput instead of trusting the list built
        # in __init__: if anything rebinds ``criterion.gradInput`` (rather than
        # updating it in place), the cached reference would otherwise go stale.
        self.gradInput = [self.criterion.gradInput]
        return self.gradInput