Allow specifying output size in MaxUnpooling
diff --git a/test/test_nn.py b/test/test_nn.py
index 1333945..1537588 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -562,6 +562,32 @@
                 else:
                     self.assertRaises(ValueError, lambda: m(i, (h, w)))
 
+    def test_MaxUnpool2d_output_size(self):
+        m = nn.MaxPool2d(3, stride=2, return_indices=True)
+        mu = nn.MaxUnpool2d(3, stride=2)
+        big_t = torch.rand(1, 1, 6, 6)
+        big_t[0][0][4][4] = 100
+        output_big, indices_big = m(Variable(big_t))
+        self.assertRaises(RuntimeError, lambda: mu(output_big, indices_big))
+
+        small_t = torch.rand(1, 1, 5, 5)
+        for i in range(0, 4, 2):
+            for j in range(0, 4, 2):
+                small_t[:,:,i,j] = 100
+        output_small, indices_small = m(Variable(small_t))
+        for h in range(3, 10):
+            for w in range(3, 10):
+                if 4 <= h <= 6 and 4 <= w <= 6:
+                    size = (h, w)
+                    if h == 5:
+                        size = torch.LongStorage(size)
+                    elif h == 6:
+                        size = torch.LongStorage((1, 1) + size)
+                    mu(output_small, indices_small, output_size=size)
+                else:
+                    self.assertRaises(ValueError, lambda:
+                            mu(output_small, indices_small, (h, w)))
+
 
 def add_test(test):
     test_name = test.get_name()
diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py
index 2d64cb8..760295f 100644
--- a/torch/nn/modules/pooling.py
+++ b/torch/nn/modules/pooling.py
@@ -100,7 +100,7 @@
         stride: the stride of the window. Can be a single number s or a tuple (sh x sw). Default: kernel_size
         padding: implicit padding that was added to the input. Can be a single number or a tuple. Default: 0
     Input Shape: [ * , * , *, * ] : Input is minibatch x channels x iH x iW
-    Output Shape:[ * , * , *, * ]  : Output shape = minibatch x channels x padH x (iH - 1) * sH + kH x padW x (iW - 1) * sW + kW
+    Output Shape:[ * , * , *, * ]  : Output shape is minibatch x channels x padH x (iH - 1) * sH + kH x padW x (iW - 1) * sW + kW, or as specified to the call.
     Examples:
         >>> # pool of square window of size=3, stride=2
         >>> m = nn.MaxPool2d(2, stride=2, return_indices = True)
@@ -108,16 +108,47 @@
         >>> input = autograd.Variable(torch.randn(20, 16, 50, 32))
         >>> output, indices = m(input)
         >>> unpooled_output = mu.forward(output, indices)
+        >>> # exact output size can be also specified as an argument
+        >>> input = autograd.Variable(torch.randn(1, 16, 11, 11))
+        >>> downsample = nn.MaxPool2d(3, 3, return_indices=True)
+        >>> upsample = nn.MaxUnpool2d(3, 3)
+        >>> h, indices = downsample(input)
+        >>> output = upsample(h, indices, output_size=input.size())
     """
     def __init__(self, kernel_size, stride=None, padding=0):
         super(MaxUnpool2d, self).__init__()
         self.kh, self.kw = _pair(kernel_size)
         self.dh, self.dw = _pair(stride or kernel_size)
         self.padh, self.padw = _pair(padding)
+        self.output_size = None
+
+    def __call__(self, input, indices, output_size=None):
+        if output_size:
+            self.output_size = list(output_size)
+            if len(self.output_size) == 4:
+                self.output_size = self.output_size[-2:]
+            if len(self.output_size) != 2:
+                raise ValueError("output_size should be a sequence containing "
+                        "2 or 4 elements, but it has a length of {}".format(
+                            len(output_size)))
+        else:
+            self.output_size = None
+        return super(MaxUnpool2d, self).__call__(input, indices)
 
     def forward(self, input, indices):
         out_height = (input.size(2) - 1) * self.dh + self.kh - 2*self.padh
         out_width = (input.size(3) - 1) * self.dw + self.kw - 2*self.padw
+        if self.output_size is not None:
+            h, w = self.output_size
+            h_ok = out_height - self.dh < h < out_height + self.dh
+            w_ok = out_width - self.dw < w < out_width + self.dw
+            if not h_ok or not w_ok:
+                raise ValueError(("specified incorrect output size. Got {}x{}, "
+                        "but valid sizes range from {}x{} to {}x{}").format(
+                            h, w,
+                            out_height - self.dh + 1, out_width - self.dw + 1,
+                            out_height + self.dh - 1, out_width + self.dw - 1))
+            out_height, out_width = h, w
         return self._backend.MaxUnpool2d(out_width,
                 out_height)(input, indices)
 
@@ -194,6 +225,7 @@
         self.ceil_mode = ceil_mode
 
     def forward(self, input):
+        # TODO: allow to specify output size
         return self._backend.MaxPool3d(self.kt, self.kw, self.kh,
                 self.dt, self.dw, self.dh, self.padt, self.padw, self.padh,
                 self.dilt, self.dilw, self.dilh,