# blob: 2f0da29077d50805a5da8c4fd6fb10cdc495fbc3 [file] [log] [blame]
import torch
from torch.nn.functional import _Reduction
from .Criterion import Criterion
class WeightedMSECriterion(Criterion):
    """Mean-squared-error criterion evaluated against a weighted target.

    The target is copied into a scratch buffer and multiplied in place by
    ``weight``; the backend MSE kernel is then run between the input and
    that weighted copy. The caller's ``target`` tensor is never modified.

    Args:
        weight: Tensor multiplied into the target. It is cloned, so later
            changes to the caller's tensor do not affect this criterion.
            When it has exactly one dimension fewer than the input it is
            applied to each slice along the input's first dimension;
            otherwise it is multiplied into the whole target at once.
        sizeAverage: Forwarded to ``_Reduction.legacy_get_enum`` to pick
            the reduction handed to the backend (presumably averaging when
            true and summing when false -- confirm against the backend).
    """

    def __init__(self, weight, sizeAverage=True):
        super(WeightedMSECriterion, self).__init__()
        self.weight = weight.clone()
        # Scratch tensors are allocated lazily so that they take the type of
        # the first input seen (they are created with ``input.new``).
        self.buffer = None
        self.output_tensor = None
        self.sizeAverage = sizeAverage

    def _weightTarget(self, input, target):
        """Fill ``self.buffer`` with ``target * weight`` and return it.

        The buffer is created on first use, so this works whether it is
        reached from ``updateOutput`` or directly from ``updateGradInput``.
        """
        if self.buffer is None:
            self.buffer = input.new()
        weighted = self.buffer
        weighted.resize_as_(input).copy_(target)
        if input.dim() - 1 == self.weight.dim():
            # Weight describes a single sample: apply it slice by slice
            # along the first dimension.
            for i in range(input.size(0)):
                weighted[i].mul_(self.weight)
        else:
            weighted.mul_(self.weight)
        return weighted

    def updateOutput(self, input, target):
        """Compute the loss between ``input`` and the weighted ``target``.

        Returns:
            The loss as a Python number (also stored in ``self.output``).
        """
        weighted = self._weightTarget(input, target)

        if self.output_tensor is None:
            # One-element tensor the backend writes the loss into.
            self.output_tensor = input.new(1)

        self._backend.MSECriterion_updateOutput(
            self._backend.library_state,
            input,
            weighted,
            self.output_tensor,
            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
        )
        self.output = self.output_tensor[0].item()
        return self.output

    def updateGradInput(self, input, target):
        """Compute the gradient of the loss with respect to ``input``.

        The weighted target is rebuilt here rather than reused, so this
        method does not depend on ``updateOutput`` having been called first.

        Returns:
            ``self.gradInput``, which the backend call receives as its
            result argument. The attribute is not created in this class; it
            is expected to come from the ``Criterion`` base class.
        """
        weighted = self._weightTarget(input, target)

        # A criterion is the end of the chain, so the upstream gradient is
        # the constant 1, converted to the input's tensor type.
        implicit_gradOutput = torch.Tensor([1]).type(input.type())

        self._backend.MSECriterion_updateGradInput(
            self._backend.library_state,
            input,
            weighted,
            implicit_gradOutput,
            self.gradInput,
            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
        )
        return self.gradInput