Add a keepdim test to torch_test.
diff --git a/test/test_torch.py b/test/test_torch.py
index 741c6b4..5b8a3e9 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -188,6 +188,35 @@
def test_min(self):
self._testSelection(torch.min, min)
+ def test_dim_reduction(self):
+ dim_red_fns = [
+ "mean", "median", "mode", "norm", "prod",
+ "std", "sum", "var", "max", "min"]
+
+ def normfn_attr(t, dim, keepdim=False):
+ attr = getattr(torch, "norm")
+ return attr(t, 2, dim, keepdim)
+
+ for fn_name in dim_red_fns:
+ x = torch.randn(3, 4, 5)
+ fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr
+
+ def fn(t, dim, keepdim=False):
+ ans = fn_attr(x, dim, keepdim)
+ return ans if not isinstance(ans, tuple) else ans[0]
+
+ dim = random.randint(0, 2)
+ self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, True))
+ self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
+ self.assertEqual(x.ndimension(), fn(x, dim, True).ndimension())
+
+ # check 1-d behavior
+ x = torch.randn(1)
+ dim = 0
+ self.assertEqual(fn(x, dim), fn(x, dim, True))
+ self.assertEqual(x.ndimension(), fn(x, dim).ndimension())
+ self.assertEqual(x.ndimension(), fn(x, dim, True).ndimension())
+
def _testCSelection(self, torchfn, mathfn):
# Two tensors
size = (100, 100)