Allow specifying output size in MaxUnpooling
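
MaxUnpool2d cannot always infer the size of the original input: several
input sizes pool down to the same output size, so by default the result
is (i - 1) * stride + kernel_size - 2 * padding in each spatial
dimension. This change lets the caller pass an explicit output_size
(either H x W or the full N x C x H x W) that overrides the inferred
size, as long as it lies strictly within one stride of it. A minimal
usage sketch (shapes chosen for illustration only):

    >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True)
    >>> unpool = nn.MaxUnpool2d(2, stride=2)
    >>> input = autograd.Variable(torch.randn(1, 16, 5, 5))
    >>> output, indices = pool(input)   # output is 1 x 16 x 2 x 2
    >>> # without output_size this would unpool to 1 x 16 x 4 x 4
    >>> unpooled = unpool(output, indices, output_size=input.size())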
diff --git a/test/test_nn.py b/test/test_nn.py
index 1333945..1537588 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -562,6 +562,32 @@
                 else:
                     self.assertRaises(ValueError, lambda: m(i, (h, w)))
 
+    def test_MaxUnpool2d_output_size(self):
+        m = nn.MaxPool2d(3, stride=2, return_indices=True)
+        mu = nn.MaxUnpool2d(3, stride=2)
+        big_t = torch.rand(1, 1, 6, 6)
+        big_t[0][0][4][4] = 100
+        output_big, indices_big = m(Variable(big_t))
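+        # the max at (4, 4) has flat index 28 in the 6x6 input, which cannot
+        # fit into the inferred 5x5 output, so unpooling without an explicit
+        # output_size fails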
+        self.assertRaises(RuntimeError, lambda: mu(output_big, indices_big))
+
+        small_t = torch.rand(1, 1, 5, 5)
+        for i in range(0, 4, 2):
+            for j in range(0, 4, 2):
+                small_t[:, :, i, j] = 100
+        output_small, indices_small = m(Variable(small_t))
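+        # maxima at the pooling window corners have indices that stay in
+        # bounds even for the smallest output; the inferred size is 5x5, and
+        # anything within one stride of it (4..6 per dimension) is accepted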
+        for h in range(3, 10):
+            for w in range(3, 10):
+                if 4 <= h <= 6 and 4 <= w <= 6:
+                    size = (h, w)
+                    if h == 5:
+                        size = torch.LongStorage(size)
+                    elif h == 6:
+                        size = torch.LongStorage((1, 1) + size)
+                    mu(output_small, indices_small, output_size=size)
+                else:
+                    self.assertRaises(ValueError, lambda:
+                                      mu(output_small, indices_small, (h, w)))
+
 
 def add_test(test):
     test_name = test.get_name()
diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py
index 2d64cb8..760295f 100644
--- a/torch/nn/modules/pooling.py
+++ b/torch/nn/modules/pooling.py
@@ -100,7 +100,7 @@
         stride: the stride of the window. Can be a single number s or a tuple (sh x sw). Default: kernel_size
         padding: implicit padding that was added to the input. Can be a single number or a tuple. Default: 0
     Input Shape: [ * , * , *, * ] : Input is minibatch x channels x iH x iW
-    Output Shape:[ * , * , *, * ] : Output shape = minibatch x channels x padH x (iH - 1) * sH + kH x padW x (iW - 1) * sW + kW
+    Output Shape:[ * , * , *, * ] : Output shape is minibatch x channels x (iH - 1) * sH + kH - 2*padH x (iW - 1) * sW + kW - 2*padW, or as given by output_size in the call
     Examples:
         >>> # pool of square window of size=3, stride=2
         >>> m = nn.MaxPool2d(2, stride=2, return_indices = True)
@@ -108,16 +108,47 @@
         >>> input = autograd.Variable(torch.randn(20, 16, 50, 32))
         >>> output, indices = m(input)
         >>> unpooled_output = mu.forward(output, indices)
+        >>> # the exact output size can also be specified as an argument
+        >>> input = autograd.Variable(torch.randn(1, 16, 11, 11))
+        >>> downsample = nn.MaxPool2d(3, 3, return_indices=True)
+        >>> upsample = nn.MaxUnpool2d(3, 3)
+        >>> h, indices = downsample(input)
+        >>> output = upsample(h, indices, output_size=input.size())
     """
 
     def __init__(self, kernel_size, stride=None, padding=0):
         super(MaxUnpool2d, self).__init__()
         self.kh, self.kw = _pair(kernel_size)
         self.dh, self.dw = _pair(stride or kernel_size)
         self.padh, self.padw = _pair(padding)
+        self.output_size = None
+
+    def __call__(self, input, indices, output_size=None):
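+        # output_size may be the full tensor size (minibatch x channels x
+        # height x width) or just the spatial size (height x width)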
+        if output_size:
+            self.output_size = list(output_size)
+            if len(self.output_size) == 4:
+                self.output_size = self.output_size[-2:]
+            if len(self.output_size) != 2:
+                raise ValueError("output_size should be a sequence containing "
+                                 "2 or 4 elements, but it has a length of {}".format(
+                                     len(output_size)))
+        else:
+            self.output_size = None
+        return super(MaxUnpool2d, self).__call__(input, indices)
 
     def forward(self, input, indices):
         out_height = (input.size(2) - 1) * self.dh + self.kh - 2*self.padh
         out_width = (input.size(3) - 1) * self.dw + self.kw - 2*self.padw
+        if self.output_size is not None:
+            h, w = self.output_size
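+            # only sizes the pooling could have produced are valid, i.e.
+            # sizes strictly within one stride of the inferred default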
+            h_ok = out_height - self.dh < h < out_height + self.dh
+            w_ok = out_width - self.dw < w < out_width + self.dw
+            if not h_ok or not w_ok:
+                raise ValueError(("specified incorrect output size. Got {}x{}, "
+                                  "but valid sizes range from {}x{} to {}x{}").format(
+                                      h, w,
+                                      out_height - self.dh + 1, out_width - self.dw + 1,
+                                      out_height + self.dh - 1, out_width + self.dw - 1))
+            out_height, out_width = h, w
         return self._backend.MaxUnpool2d(out_width,
                                          out_height)(input, indices)
@@ -194,6 +225,7 @@
         self.ceil_mode = ceil_mode
 
     def forward(self, input):
+        # TODO: allow output size to be specified
         return self._backend.MaxPool3d(self.kt, self.kw, self.kh,
                                        self.dt, self.dw, self.dh, self.padt, self.padw, self.padh,
                                        self.dilt, self.dilw, self.dilh,