blob: 21aaff890d97a3ab9f4b11c4c95ff1e11dfdda2e [file] [log] [blame]
import torch
from .Module import Module
from .utils import addSingletondimension
class Unsqueeze(Module):
    """Legacy-nn module that adds a singleton dimension to its input.

    The forward pass delegates to ``addSingletondimension`` (from ``.utils``)
    at position ``self.dim``. The backward pass reshapes the incoming gradient
    back to the size of the original input.

    Args:
        dim: position handed to ``addSingletondimension``.
            NOTE(review): whether negative positions are accepted depends on
            that helper, which is not visible here. Confirm before relying
            on it.
    """

    def __init__(self, dim):
        super(Unsqueeze, self).__init__()
        # Where the singleton dimension goes; also echoed by __repr__.
        self.dim = dim

    def updateOutput(self, input):
        """Forward pass: fill ``self.output`` from ``input`` and return it.

        ``self.output`` is handed to ``addSingletondimension`` as the
        destination argument and then returned. This method never rebinds it.
        Presumably ``self.output`` is a tensor set up by ``Module.__init__``;
        TODO confirm.
        """
        addSingletondimension(self.output, input, self.dim)
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Backward pass: reshape ``gradOutput`` to ``input.size()``.

        Returns:
            ``self.gradInput``, rebound to a view of the (contiguous)
            ``gradOutput`` with the input's size.

        Raises:
            AssertionError: if ``input`` and ``gradOutput`` hold different
                numbers of elements. This check is skipped under
                ``python -O``.
        """
        assert input.nelement() == gradOutput.nelement()
        # .view() needs a compatible memory layout, so make the gradient
        # contiguous first. contiguous() returns the tensor itself if it
        # already is contiguous.
        packed_grad = gradOutput.contiguous()
        # Rebinds self.gradInput to the view instead of copying into a buffer.
        self.gradInput = packed_grad.view(input.size())
        return self.gradInput

    def __repr__(self):
        """Return the base ``Module`` repr followed by ``(dim)``."""
        base = super(Unsqueeze, self).__repr__()
        return base + '({})'.format(self.dim)