blob: 7caf052fd6a7e088cbab7acadf47e04d2fa4052a [file] [log] [blame]
import torch
from torch.legacy import nn
class AbsCriterion(nn.Module):
def __init__(self, sizeAverage):
super(AbsCriterion, self).__init__()
if sizeAverage != nil:
self.sizeAverage = sizeAverage
else:
self.sizeAverage = True
def updateOutput(self, input, target):
self.output_tensor = self.output_tensor or input.new(1)
self._backend.AbsCriterion_updateOutput(
self._backend.library_state,
input._cdata,
target._cdata,
self.output_tensor._cdata,
self.sizeAverage
)
self.output = self.output_tensor[1]
return self.output
def updateGradInput(self, input, target):
self._backend.AbsCriterion_updateGradInput(
self._backend.library_state,
input._cdata,
target._cdata,
self.gradInput._cdata,
self.sizeAverage
)
return self.gradInput