from torch.testing._internal.common_utils import TestCase, run_tests
import torch
from torch import vmap

class TestVmapAPI(TestCase):
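    """Tests for the vmap API: output type validation, batch-size checks,
    nesting behavior, and the `out_dims` argument.

    `vmap(fn)` maps `fn` over dim 0 of each input, so that, for example,
    `vmap(torch.mul)(x, y)[i]` equals `torch.mul(x[i], y[i])`.
    """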
    def test_non_tensor_output_raises(self):
        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
            vmap(lambda x: 3.14)(torch.ones(3))

        def multiple_outputs(x):
            return x, 3

        with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
            vmap(multiple_outputs)(torch.ones(3))

    def test_different_map_dim_size_raises(self):
        x = torch.randn(2)
        y = torch.randn(3)
        expected_msg = 'Expected all tensors to have the same size in the mapped dimension'
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(torch.mul)(x, y)

    def test_func_with_no_inputs(self):
        expected_msg = 'got no inputs'

        def foo():
            return torch.randn(3)

        def bar(x):
            return torch.randn(3)

        # `bar` does take an argument, but calling the vmapped function with
        # no inputs should still raise: vmap has no tensor from which to
        # infer the batch size.
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(foo)()

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(bar)()

    def test_constant_function(self):
        output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
        self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))

    def test_single_input(self):
        x = torch.randn(2, 3)

        def square(x):
            return x * x

        output = vmap(square)(x)
        self.assertEqual(output, x * x)

    def test_multiple_inputs(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        output = vmap(torch.mul)(x, y)
        self.assertEqual(output, x * y)

    def test_multiple_outputs(self):
        def foo(x):
            return x * x, x * x * x

        x = torch.randn(3)
        outputs = vmap(foo)(x)
        self.assertEqual(outputs[0], x * x)
        self.assertEqual(outputs[1], x * x * x)

    def test_multiple_outputs_error_cases(self):
        # This is the same thing as
        # def returns_tuple_of_tensors(x):
        #     return x, x
        def returns_tuple_of_tensors(x):
            return (x, x)

        def returns_list_of_two_tensors(x):
            return [x, x]

        def returns_list_of_one_tensor(x):
            return [x]

        x = torch.randn(3)

        # Should not throw
        vmap(returns_tuple_of_tensors)(x)

        # JAX supports these, but we don't yet
        msg = "must only return Tensors, got type <class 'list'>"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_two_tensors)(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_one_tensor)(x)

    def test_nested_with_same_map_dim(self):
        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
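        # Each vmap level peels off one batch dimension, so the doubly and
        # triply nested calls below reduce to plain elementwise multiplication.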
        output = vmap(vmap(torch.mul))(x, y)
        self.assertEqual(output, x * y)

        output = vmap(vmap(vmap(torch.mul)))(x, y)
        self.assertEqual(output, x * y)

    def test_nested_with_different_map_dim(self):
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
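        # The outer vmap maps over x's dim 0 (size 2), the inner vmap maps
        # over y's dim 0 (size 5), and each x[i] * y[j] has shape (3,),
        # so the output has shape (2, 5, 3).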
        output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
        self.assertEqual(output.shape, (2, 5, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        z = torch.randn(7, 3)
        output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
        self.assertEqual(output.shape, (2, 5, 7, 3))
        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)

    def test_noop_in_inner_vmap(self):
        x = torch.randn(3)
        y = torch.randn(5)
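        # The inner function ignores y, so each element of x is simply
        # repeated across y's batch dimension of size 5.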
        output = vmap(lambda x: vmap(lambda y: x)(y))(x)
        self.assertEqual(output, x.view(3, 1).expand(3, 5))

    def test_unsupported_op_err_msg(self):
        def foo(x):
            return torch.cos(x)

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError, 'NYI: Calling aten::cos inside of vmap'):
            vmap(foo)(x)

    def test_unsupported_inplace_op_err_msg(self):
        def foo(x):
            return x.cos_()

        x = torch.randn(3)
        # TODO(rzou): This error message is pretty bad because the
        # dispatcher's fallback mechanism doesn't work for ops that don't
        # support boxing. Fix the error message at some point.
        with self.assertRaisesRegex(
                RuntimeError, 'Tried to call KernelFunction::call'):
            vmap(foo)(x)

    def test_nonzero_out_dims(self):
        # Basic test
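        # out_dims=1 moves the batch dimension to index 1 of the output.
        # The data_ptr checks below verify that the result shares storage
        # with the input (i.e. it is a view, not a copy).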
        tensor = torch.randn(2, 3)
        result = vmap(lambda x: x, out_dims=1)(tensor)
        self.assertEqual(result, tensor.permute(1, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # Test that the batch dimension gets permuted to dim 2
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 0, 3))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # Negative out_dim
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=-1)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 3, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # Check that out_dims works on ALL outputs
        tensor = torch.randn(2, 3, 5, 7)
        other = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
        self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3)))

        # Use out_dims with the maximum vmap-able tensor dims (64 dims)
        ndims = 64
        shape = [2] + [1] * (ndims - 1)
        expected_shape = [1, 1, 2] + [1] * (ndims - 3)
        tensor = torch.randn(shape)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result.shape, expected_shape)

        # Test something that is not the identity function
        def foo(x, y):
            return x, x * y, x * y * y
        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=1)(x, y)
        self.assertEqual(
            result,
            (x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2)))

    def test_multiple_out_dims(self):
        def foo(x):
            return x, x

        def bar(x, y):
            return x, x, x, x * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
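        # A tuple out_dims gives each output its own destination for the
        # batch dimension.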
        result = vmap(foo, out_dims=(0, 1))(x)
        self.assertEqual(result, (x, x.permute(1, 0, 2)))

        result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
        expected = (
            x.permute(1, 2, 0),
            x,
            x.permute(1, 0, 2),
            (x * y).permute(1, 2, 0),
        )
        self.assertEqual(result, expected)

    def test_nested_out_dims(self):
        y = torch.randn(2, 3, 5, 7)

        # Inner vmap has non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
        self.assertEqual(result.shape, (2, 5, 3, 7))
        self.assertEqual(result, y.permute(0, 2, 1, 3))

        # All vmaps have non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
        self.assertEqual(result.shape, (5, 2, 3, 7))
        self.assertEqual(result, y.permute(2, 0, 1, 3))

        # Throwing in some negative out_dims
        result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
        self.assertEqual(result.shape, (5, 7, 3, 2))
        self.assertEqual(result, y.permute(2, 3, 1, 0))

        # Testing fn that isn't the identity
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
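        # The inner vmap maps over x's dim 0 (size 2) and places it at
        # output dim 1; the outer vmap maps over y's dim 0 (size 5) and
        # places it last, yielding shape (3, 2, 5).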
        result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
        self.assertEqual(result.shape, (3, 2, 5))
        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))

    def test_out_dims_edge_case(self):
        def foo(x):
            return x

        # Test that we accept out_dims=(1,) for a function with one output.
        tensor = torch.randn(2, 3)
        expected = vmap(foo, out_dims=1)(tensor)
        result = vmap(foo, out_dims=(1,))(tensor)
        self.assertEqual(result, expected)

    def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
        msg = '`out_dims` must be an int or a tuple of int'
        tensor = torch.randn(2, 3)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims='lol')(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=('lol',))(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=None)(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(None,))(tensor)

    def test_out_dims_and_num_outputs_mismatch_err_msg(self):
        msg = '`out_dims` must have one dim per output'
        x = torch.randn(2, 3, 5)

        # Too many out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(0, 0))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)

        # Too few out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x), out_dims=(0,))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)

    def test_out_dim_out_of_bounds_err_msg(self):
        # TODO(rzou): This error message isn't that great. It comes straight
        # from maybe_wrap_dim. Consider catching and rethrowing it in C++ to
        # add some context to the error message.
        msg = 'Dimension out of range'
        x = torch.randn(2, 3, 5)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=3)(x)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=-4)(x)

if __name__ == '__main__':
    run_tests()