| from torch.testing._internal.common_utils import TestCase, run_tests |
| import torch |
| from torch import vmap |
| |
| class TestVmapAPI(TestCase): |
| def test_non_tensor_output_raises(self): |
| with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"): |
| output = vmap(lambda x: 3.14)(torch.ones(3)) |
| |
| def multiple_outputs(x): |
| return x, 3 |
| |
| with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"): |
| vmap(multiple_outputs)(torch.ones(3)) |
| |
| def test_different_map_dim_size_raises(self): |
| x = torch.randn(2) |
| y = torch.randn(3) |
| expected_msg = 'Expected all tensors to have the same size in the mapped dimension' |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| vmap(torch.mul)(x, y) |
| |
| def test_func_with_no_inputs(self): |
| expected_msg = 'got no inputs' |
| |
| def foo(): |
| return torch.randn(3) |
| |
| def bar(x): |
| return torch.randn(3) |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| vmap(foo)() |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| vmap(bar)() |
| |
| def test_constant_function(self): |
| output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3)) |
| self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14])) |
| |
| def test_single_input(self): |
| x = torch.randn(2, 3) |
| |
| def square(x): |
| return x * x |
| |
| output = vmap(square)(x) |
| self.assertEqual(output, x * x) |
| |
| def test_multiple_inputs(self): |
| x = torch.randn(2, 3) |
| y = torch.randn(2, 3) |
| output = vmap(torch.mul)(x, y) |
| self.assertEqual(output, x * y) |
| |
| def test_multiple_outputs(self): |
| def foo(x): |
| return x * x, x * x * x |
| |
| x = torch.randn(3) |
| outputs = vmap(foo)(x) |
| self.assertEqual(outputs[0], x * x) |
| self.assertEqual(outputs[1], x * x * x) |
| |
| def test_multiple_outputs_error_cases(self): |
| # This is the same thing as |
| # def returns_tuple_of_tensors(x): |
| # return x, x |
| def returns_tuple_of_tensors(x): |
| return (x, x) |
| |
| def returns_list_of_two_tensors(x): |
| return [x, x] |
| |
| def returns_list_of_one_tensor(x): |
| return [x] |
| |
| x = torch.randn(3) |
| |
| # should not throw |
| vmap(returns_tuple_of_tensors)(x) |
| |
| # jax supports these, but we don't yet |
| msg = "must only return Tensors, got type <class 'list'>" |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(returns_list_of_two_tensors)(x) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(returns_list_of_one_tensor)(x) |
| |
| def test_nested_with_same_map_dim(self): |
| x = torch.randn(2, 3, 5) |
| y = torch.randn(2, 3, 5) |
| output = vmap(vmap(torch.mul))(x, y) |
| self.assertEqual(output, x * y) |
| |
| output = vmap(vmap(vmap(torch.mul)))(x, y) |
| self.assertEqual(output, x * y) |
| |
| def test_nested_with_different_map_dim(self): |
| x = torch.randn(2, 3) |
| y = torch.randn(5, 3) |
| output = vmap(lambda x: vmap(lambda y: x * y)(y))(x) |
| self.assertEqual(output.shape, (2, 5, 3)) |
| self.assertEqual(output, x.view(2, 1, 3) * y) |
| |
| z = torch.randn(7, 3) |
| output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x) |
| self.assertEqual(output.shape, (2, 5, 7, 3)) |
| self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z) |
| |
| def test_noop_in_inner_vmap(self): |
| x = torch.randn(3) |
| y = torch.randn(5) |
| output = vmap(lambda x: vmap(lambda y: x)(y))(x) |
| self.assertEqual(output, x.view(3, 1).expand(3, 5)) |
| |
| def test_unsupported_op_err_msg(self): |
| def foo(x): |
| return torch.cos(x) |
| |
| x = torch.randn(3) |
| with self.assertRaisesRegex(RuntimeError, 'NYI: Calling aten::cos inside of vmap'): |
| vmap(foo)(x) |
| |
| def test_unsupported_inplace_op_err_msg(self): |
| def foo(x): |
| return x.cos_() |
| |
| x = torch.randn(3) |
| # TODO(rzou): Yeah, this error message is pretty bad because the |
| # dispatcher's fallback mechanism doesn't work for ops that don't support |
| # boxing. Fix the error message at some point. |
| with self.assertRaisesRegex( |
| RuntimeError, 'Tried to call KernelFunction::call'): |
| vmap(foo)(x) |
| |
| def test_nonzero_out_dims(self): |
| # Basic test |
| tensor = torch.randn(2, 3) |
| result = vmap(lambda x: x, out_dims=1)(tensor) |
| self.assertEqual(result, tensor.permute(1, 0)) |
| self.assertEqual(result.data_ptr(), tensor.data_ptr()) |
| |
| # Test that the batch dimension gets permuted to dim 2 |
| tensor = torch.randn(2, 3, 5, 7) |
| result = vmap(lambda x: x, out_dims=2)(tensor) |
| self.assertEqual(result, tensor.permute(1, 2, 0, 3)) |
| self.assertEqual(result.data_ptr(), tensor.data_ptr()) |
| |
| # negative out_dim |
| tensor = torch.randn(2, 3, 5, 7) |
| result = vmap(lambda x: x, out_dims=-1)(tensor) |
| self.assertEqual(result, tensor.permute(1, 2, 3, 0)) |
| self.assertEqual(result.data_ptr(), tensor.data_ptr()) |
| |
| # check that out_dims works on ALL outputs |
| tensor = torch.randn(2, 3, 5, 7) |
| other = torch.randn(2, 3, 5, 7) |
| result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other) |
| self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3))) |
| |
| # use out_dims with the maximum vmap-able tensor dims (64 dims) |
| ndims = 64 |
| shape = [2] + [1] * (ndims - 1) |
| expected_shape = [1, 1, 2] + [1] * (ndims - 3) |
| tensor = torch.randn(shape) |
| result = vmap(lambda x: x, out_dims=2)(tensor) |
| self.assertEqual(result.shape, expected_shape) |
| |
| # test something that is not the identity function |
| def foo(x, y): |
| return x, x * y, x * y * y |
| x = torch.randn(2, 3, 5) |
| y = torch.randn(2, 3, 5) |
| result = vmap(foo, out_dims=1)(x, y) |
| self.assertEqual( |
| result, |
| (x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2))) |
| |
| def test_multiple_out_dims(self): |
| def foo(x): |
| return x, x |
| |
| def bar(x, y): |
| return x, x, x, x * y |
| |
| x = torch.randn(2, 3, 5) |
| y = torch.randn(2, 3, 5) |
| result = vmap(foo, out_dims=(0, 1))(x) |
| self.assertEqual(result, (x, x.permute(1, 0, 2))) |
| |
| result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y) |
| expected = ( |
| x.permute(1, 2, 0), |
| x, |
| x.permute(1, 0, 2), |
| (x * y).permute(1, 2, 0), |
| ) |
| self.assertEqual(result, expected) |
| |
| def test_nested_out_dims(self): |
| y = torch.randn(2, 3, 5, 7) |
| |
| # Inner vmap has non-zero out_dim |
| result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y) |
| self.assertEqual(result.shape, (2, 5, 3, 7)) |
| self.assertEqual(result, y.permute(0, 2, 1, 3)) |
| |
| # all vmaps have non-zero out_dim |
| result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y) |
| self.assertEqual(result.shape, (5, 2, 3, 7)) |
| self.assertEqual(result, y.permute(2, 0, 1, 3)) |
| |
| # throwing in some negative out_dims |
| result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y) |
| self.assertEqual(result.shape, (5, 7, 3, 2)) |
| self.assertEqual(result, y.permute(2, 3, 1, 0)) |
| |
| # testing fn that isn't the identity |
| x = torch.randn(2, 3) |
| y = torch.randn(5, 3) |
| result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y) |
| self.assertEqual(result.shape, (3, 2, 5)) |
| self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0)) |
| |
| def test_out_dims_edge_case(self): |
| def foo(x): |
| return x |
| |
| # Test that we accept out_dims=(1,) for a function with one output. |
| tensor = torch.randn(2, 3) |
| expected = vmap(foo, out_dims=1)(tensor) |
| result = vmap(foo, out_dims=(1,))(tensor) |
| self.assertEqual(result, expected) |
| |
| def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self): |
| msg = '`out_dims` must be an int or a tuple of int' |
| tensor = torch.randn(2, 3) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(lambda x: x, out_dims='lol')(tensor) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(lambda x: x, out_dims=('lol',))(tensor) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(lambda x: x, out_dims=None)(tensor) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(lambda x: x, out_dims=(None,))(tensor) |
| |
| def test_out_dims_and_num_outputs_mismatch_err_msg(self): |
| msg = '`out_dims` must have one dim per output' |
| x = torch.randn(2, 3, 5) |
| |
| # Too many out_dims |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(lambda x: x, out_dims=(0, 0))(x) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x) |
| |
| # Too few out_dims |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(lambda x: (x, x), out_dims=(0,))(x) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(lambda x: (x, x, x), out_dims=(0, 0))(x) |
| |
| def test_out_dim_out_of_bounds_err_msg(self): |
| # TODO(rzou): This error message isn't that great. It comes straight |
| # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to |
| # the error message in the future in C++ |
| msg = 'Dimension out of range' |
| x = torch.randn(2, 3, 5) |
| with self.assertRaisesRegex(IndexError, msg): |
| vmap(lambda x: x, out_dims=3)(x) |
| with self.assertRaisesRegex(IndexError, msg): |
| vmap(lambda x: x, out_dims=-4)(x) |
| |
| def test_non_zero_in_dims(self): |
| tensor = torch.randn(2, 3, 5) |
| |
| # Implicit out_dims = 0; vmap will move the batch dim to the front. |
| output = vmap(lambda x: x, (1,))(tensor) |
| self.assertEqual(output, tensor.permute(1, 0, 2)) |
| self.assertEqual(output.data_ptr(), tensor.data_ptr()) |
| |
| x = torch.randn(2, 3) |
| y = torch.randn(3, 2) |
| output = vmap(torch.mul, (0, 1))(x, y) |
| self.assertEqual(output, x * y.t()) |
| output = vmap(torch.mul, (1, 0))(x, y) |
| self.assertEqual(output, x.t() * y) |
| |
| def test_none_in_dims(self): |
| x = torch.randn(2, 3) |
| y = torch.randn(2, 3) |
| |
| # None in_dim for a Tensor means we don't map over it |
| output = vmap(torch.mul, (0, None))(x, y) |
| self.assertEqual(output.shape, (2, 2, 3)) |
| self.assertEqual(output, x.view(2, 1, 3) * y) |
| |
| # None in_dim for non-tensor arguments |
| output = vmap(torch.mul, (0, None))(x, 2) |
| self.assertEqual(output, x * 2) |
| |
| def test_nested_non_default_in_dims(self): |
| x = torch.rand(5, 2, 3) |
| y = torch.rand(3, 5, 2) |
| result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y) |
| self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1)) |
| |
| def test_non_default_in_dims_out_dims(self): |
| x = torch.randn(2, 3, 5) |
| |
| # Same in_dim as out_dim, vmap over identity |
| result = vmap(lambda x: x, in_dims=1, out_dims=1)(x) |
| self.assertEqual(result, x) |
| self.assertEqual(result.data_ptr(), x.data_ptr()) |
| |
| # Different in_dim from out_dim, vmap over identity |
| result = vmap(lambda x: x, in_dims=2, out_dims=1)(x) |
| self.assertEqual(result.shape, (2, 5, 3)) |
| self.assertEqual(result, x.transpose(1, 2)) |
| self.assertEqual(result.data_ptr(), x.data_ptr()) |
| |
| def foo(x): |
| return x * 2 |
| |
| # Same in_dim as out_dim, vmap over operation |
| result = vmap(foo, in_dims=1, out_dims=1)(x) |
| self.assertEqual(result, x * 2) |
| |
| # Different in_dim as out_dim, vmap over operation |
| result = vmap(foo, in_dims=2, out_dims=1)(x) |
| self.assertEqual(result.shape, (2, 5, 3)) |
| self.assertEqual(result, (x * 2).transpose(1, 2)) |
| |
| # Basic nested test. |
| result = vmap(vmap(foo, 1, 1), 1, 1)(x) |
| self.assertEqual(result, x * 2) |
| |
| def test_in_dims_wrong_type_err_msg(self): |
| x = torch.randn(3) |
| y = torch.randn(3) |
| msg = 'expected `in_dims` to be int or tuple' |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(torch.mul, [0, 0])(x, y) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(torch.mul, set({0, 0}))(x, y) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(torch.mul, 'lol')(x, y) |
| # The following should not throw |
| vmap(torch.mul, (0, 0))(x, y) |
| |
| def test_not_enough_in_dims_err_msg(self): |
| x = torch.randn(3) |
| y = torch.randn(3) |
| msg = r'expected one `in_dim` per input \(got \w+ inputs\)' |
| |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(torch.mul, (0,))(x, y) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(torch.mul, (0, 0, 0))(x, y) |
| # The following should not throw |
| vmap(torch.mul, (0, 0))(x, y) |
| |
| def test_in_dims_must_be_flat_tuple_err_msg(self): |
| msg = 'in_dims must be a flat tuple containing ints and/or Nones' |
| |
| x = torch.randn(3) |
| y = torch.randn(3) |
| z = torch.randn(3) |
| |
| def foo(xy): |
| return xy[0] * xy[1] |
| |
| def bar(x, yz): |
| return x * yz[0] * yz[1] |
| |
| # NB: jax supports all of the following, we don't yet. |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(foo, ((0, 0),))((x, y)) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(bar, (0, (0, 0)))(x, (y, z)) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(foo, ({0: 0, 1: 0},))({0: x, 1: y}) |
| |
| def test_integer_in_dim_but_not_tensor_input_err_msg(self): |
| def foo(xy): |
| return xy[0] * xy[1] |
| |
| def bar(x, yz): |
| return x * yz[0] * yz[1] |
| |
| x = torch.randn(2, 3) |
| y = torch.randn(2, 3) |
| |
| # jax supports these, we too can in the future. |
| msg = 'Got in_dim=0 for input 0, but input 0 is not a Tensor' |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(foo)((x, y)) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(foo, (0,))((x, y)) |
| |
| # jax supports these as well, we too can in the future. |
| msg = 'Got in_dim=0 for input 1, but input 1 is not a Tensor' |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(foo)(x, (x, y)) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(foo, (0, 0))(x, (x, y)) |
| |
| # the following are errors in jax (and will always be errors) |
| msg = 'Got in_dim=0 for input 1, but input 1 is not a Tensor' |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(torch.sum)(x, 0) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(torch.sum, (0, 0))(x, 0) |
| # The following should not throw |
| vmap(torch.sum, (0, None))(x, 0) |
| |
| def test_in_dim_not_in_tensor_err_msg(self): |
| def foo(x): |
| return x * x |
| |
| msg = r'Got in_dim=-?\w for input 0, but input 0 is a Tensor of dimensionality \w' |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(foo)(torch.randn([])) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(foo, in_dims=(0,))(torch.randn([])) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(foo, in_dims=(-1,))(torch.randn(2, 3)) |
| with self.assertRaisesRegex(ValueError, msg): |
| vmap(foo, in_dims=(2,))(torch.randn(2, 3)) |
| # the following should not throw |
| vmap(foo, in_dims=(0,))(torch.randn(2, 3)) |
| vmap(foo, in_dims=(1,))(torch.randn(2, 3)) |
| |
# Script entry point: hand control to the PyTorch test runner when this
# module is executed directly rather than imported.
if __name__ == '__main__':
    run_tests()