blob: 782b8f5cd90b08d954dff992187c9363bae41050 [file] [log] [blame]
import torch
from .Module import Module
class Squeeze(Module):
def __init__(self, dim=None):
super(Squeeze, self).__init__()
self.dim = dim
def updateOutput(self, input):
dim = self.dim
self.output.set_(input.squeeze(dim) if dim is not None else input.squeeze())
return self.output
def updateGradInput(self, input, gradOutput):
assert input.nelement() == gradOutput.nelement()
self.gradInput.set_(gradOutput.contiguous().view_as(input))
return self.gradInput