blob: d01835c33d22793d75527ea7ff3e4d43ffa60fa8 [file] [log] [blame]
import torch
from .Module import Module
class CSubTable(Module):
    """Element-wise difference of a pair of tensors: ``input[0] - input[1]``.

    Forward copies ``input[0]`` into ``self.output`` and subtracts
    ``input[1]`` in place. Backward routes ``gradOutput`` to the first input
    unchanged and its negation to the second input.

    NOTE(review): assumes ``input`` is an indexable pair of tensors whose
    shapes are compatible for in-place subtraction, and that ``gradOutput``
    has the same number of elements as each input -- confirm against callers.
    """

    def __init__(self):
        super(CSubTable, self).__init__()
        # One reusable gradient buffer per input tensor.
        self.gradInput = [torch.Tensor(), torch.Tensor()]

    def updateOutput(self, input):
        """Compute ``input[0] - input[1]`` into ``self.output`` and return it."""
        self.output.resizeAs_(input[0]).copy_(input[0])
        # sub_ replaces the deprecated add_(-1, tensor) scalar-first overload.
        self.output.sub_(input[1])
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Compute gradients w.r.t. both inputs.

        d(a - b)/da = +1 and d(a - b)/db = -1, so ``gradInput[0]`` receives a
        copy of ``gradOutput`` and ``gradInput[1]`` receives its negation.

        Returns:
            The two-element list ``self.gradInput``.
        """
        for i in range(2):
            buf = self.gradInput[i]
            # The original `buf or input[i].new()` truth-tested a tensor, which
            # raises once the buffer is non-empty. Only (re)allocate when the
            # buffer is missing or empty, so an empty buffer still picks up a
            # tensor of the same type as its input.
            if buf is None or buf.numel() == 0:
                self.gradInput[i] = input[i].new()
        self.gradInput[0].resizeAs_(input[0]).copy_(gradOutput)
        self.gradInput[1].resizeAs_(input[1]).copy_(gradOutput).mul_(-1)
        # Drop any extra entries so exactly one gradient per input is returned.
        self.gradInput = self.gradInput[:2]
        return self.gradInput