|  | # Owner(s): ["module: vmap"] | 
|  |  | 
|  | from torch.testing._internal.common_utils import TestCase, run_tests | 
|  | import torch | 
|  | import torch.nn.functional as F | 
|  | from torch import Tensor, vmap | 
|  | import functools | 
|  | import itertools | 
|  | import warnings | 
|  | from torch.testing._internal.common_device_type import instantiate_device_type_tests, \ | 
|  | skipCUDAIfNoMagma | 
|  | import types | 
|  |  | 
|  |  | 
|  | FALLBACK_REGEX = r'There is a performance drop' | 
|  |  | 
|  | class EnableVmapFallbackWarnings: | 
|  | def __enter__(self): | 
|  | self.prev_state = torch._C._debug_only_are_vmap_fallback_warnings_enabled() | 
|  | torch._C._debug_only_display_vmap_fallback_warnings(True) | 
|  |  | 
|  | def __exit__(self, *ignored): | 
|  | torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state) | 
|  |  | 
|  | class TestVmapAPI(TestCase): | 
|  | def test_non_tensor_output_raises(self): | 
|  | with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"): | 
|  | output = vmap(lambda x: 3.14)(torch.ones(3)) | 
|  |  | 
|  | def multiple_outputs(x): | 
|  | return x, 3 | 
|  |  | 
|  | with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"): | 
|  | vmap(multiple_outputs)(torch.ones(3)) | 
|  |  | 
|  | def test_different_map_dim_size_raises(self): | 
|  | x = torch.randn(2) | 
|  | y = torch.randn(3) | 
|  | expected_msg = 'Expected all tensors to have the same size in the mapped dimension' | 
|  | with self.assertRaisesRegex(ValueError, expected_msg): | 
|  | vmap(torch.mul)(x, y) | 
|  | with self.assertRaisesRegex(ValueError, expected_msg): | 
|  | vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y)) | 
|  | with self.assertRaisesRegex(ValueError, expected_msg): | 
|  | vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y}) | 
|  |  | 
|  | def test_func_with_no_inputs(self): | 
|  | expected_msg = 'got no inputs' | 
|  |  | 
|  | def foo(): | 
|  | return torch.randn(3) | 
|  |  | 
|  | def bar(x): | 
|  | return torch.randn(3) | 
|  |  | 
|  | with self.assertRaisesRegex(ValueError, expected_msg): | 
|  | vmap(foo)() | 
|  |  | 
|  | with self.assertRaisesRegex(ValueError, expected_msg): | 
|  | vmap(bar)() | 
|  |  | 
|  | def test_constant_function(self): | 
|  | output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3)) | 
|  | self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14])) | 
|  |  | 
|  | def test_single_input(self): | 
|  | x = torch.randn(2, 3) | 
|  |  | 
|  | def square(x): | 
|  | return x * x | 
|  |  | 
|  | output = vmap(square)(x) | 
|  | self.assertEqual(output, x * x) | 
|  |  | 
|  | def test_multiple_inputs(self): | 
|  | x = torch.randn(2, 3) | 
|  | y = torch.randn(2, 3) | 
|  | output = vmap(torch.mul)(x, y) | 
|  | self.assertEqual(output, x * y) | 
|  |  | 
|  | def test_multiple_outputs(self): | 
|  | def foo(x): | 
|  | return x * x, x * x * x | 
|  |  | 
|  | x = torch.randn(3) | 
|  | outputs = vmap(foo)(x) | 
|  | self.assertEqual(outputs[0], x * x) | 
|  | self.assertEqual(outputs[1], x * x * x) | 
|  |  | 
|  | def test_multiple_outputs_error_cases(self): | 
|  | # This is the same thing as | 
|  | # def returns_tuple_of_tensors(x): | 
|  | #     return x, x | 
|  | def returns_tuple_of_tensors(x): | 
|  | return (x, x) | 
|  |  | 
|  | def returns_list_of_two_tensors(x): | 
|  | return [x, x] | 
|  |  | 
|  | def returns_list_of_one_tensor(x): | 
|  | return [x] | 
|  |  | 
|  | x = torch.randn(3) | 
|  |  | 
|  | # should not throw | 
|  | vmap(returns_tuple_of_tensors)(x) | 
|  |  | 
|  | # jax supports these, but we don't yet | 
|  | msg = "must only return Tensors, got type <class 'list'>" | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(returns_list_of_two_tensors)(x) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(returns_list_of_one_tensor)(x) | 
|  |  | 
|  | def test_nested_with_same_map_dim(self): | 
|  | x = torch.randn(2, 3, 5) | 
|  | y = torch.randn(2, 3, 5) | 
|  | output = vmap(vmap(torch.mul))(x, y) | 
|  | self.assertEqual(output, x * y) | 
|  |  | 
|  | output = vmap(vmap(vmap(torch.mul)))(x, y) | 
|  | self.assertEqual(output, x * y) | 
|  |  | 
|  | def test_nested_with_different_map_dim(self): | 
|  | x = torch.randn(2, 3) | 
|  | y = torch.randn(5, 3) | 
|  | output = vmap(lambda x: vmap(lambda y: x * y)(y))(x) | 
|  | self.assertEqual(output.shape, (2, 5, 3)) | 
|  | self.assertEqual(output, x.view(2, 1, 3) * y) | 
|  |  | 
|  | z = torch.randn(7, 3) | 
|  | output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x) | 
|  | self.assertEqual(output.shape, (2, 5, 7, 3)) | 
|  | self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z) | 
|  |  | 
|  | def test_noop_in_inner_vmap(self): | 
|  | x = torch.randn(3) | 
|  | y = torch.randn(5) | 
|  | output = vmap(lambda x: vmap(lambda y: x)(y))(x) | 
|  | self.assertEqual(output, x.view(3, 1).expand(3, 5)) | 
|  |  | 
|  | def test_unsupported_op_err_msg(self): | 
|  | # Unsupported view op | 
|  | tensor = torch.randn(2, 3) | 
|  | msg = ( | 
|  | r"Batching rule not implemented for aten::.+; the " | 
|  | r"fallback path doesn't work on out= or view ops" | 
|  | ) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(torch.ravel)(tensor) | 
|  |  | 
|  | def out_op(x, y): | 
|  | return torch.abs(x, out=y) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(out_op)(tensor, tensor) | 
|  |  | 
|  | tensor = torch.randn(2) | 
|  | # The fallback doesn't support TensorList | 
|  | with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'): | 
|  | vmap(lambda t: torch.atleast_1d([t]))(tensor) | 
|  |  | 
|  | # Don't support non-tensor returns. This is a limitation of vmap; | 
|  | # functions that don't return tensors must be special cased | 
|  | with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'): | 
|  | vmap(torch.Tensor.item)(tensor) | 
|  |  | 
|  | def test_nonzero_out_dims(self): | 
|  | # Basic test | 
|  | tensor = torch.randn(2, 3) | 
|  | result = vmap(lambda x: x, out_dims=1)(tensor) | 
|  | self.assertEqual(result, tensor.permute(1, 0)) | 
|  | self.assertEqual(result.data_ptr(), tensor.data_ptr()) | 
|  |  | 
|  | # Test that the batch dimension gets permuted to dim 2 | 
|  | tensor = torch.randn(2, 3, 5, 7) | 
|  | result = vmap(lambda x: x, out_dims=2)(tensor) | 
|  | self.assertEqual(result, tensor.permute(1, 2, 0, 3)) | 
|  | self.assertEqual(result.data_ptr(), tensor.data_ptr()) | 
|  |  | 
|  | # negative out_dim | 
|  | tensor = torch.randn(2, 3, 5, 7) | 
|  | result = vmap(lambda x: x, out_dims=-1)(tensor) | 
|  | self.assertEqual(result, tensor.permute(1, 2, 3, 0)) | 
|  | self.assertEqual(result.data_ptr(), tensor.data_ptr()) | 
|  |  | 
|  | # check that out_dims works on ALL outputs | 
|  | tensor = torch.randn(2, 3, 5, 7) | 
|  | other = torch.randn(2, 3, 5, 7) | 
|  | result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other) | 
|  | self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3))) | 
|  |  | 
|  | # use out_dims with the maximum vmap-able tensor dims (64 dims) | 
|  | ndims = 64 | 
|  | shape = [2] + [1] * (ndims - 1) | 
|  | expected_shape = [1, 1, 2] + [1] * (ndims - 3) | 
|  | tensor = torch.randn(shape) | 
|  | result = vmap(lambda x: x, out_dims=2)(tensor) | 
|  | self.assertEqual(result.shape, expected_shape) | 
|  |  | 
|  | # test something that is not the identity function | 
|  | def foo(x, y): | 
|  | return x, x * y, x * y * y | 
|  | x = torch.randn(2, 3, 5) | 
|  | y = torch.randn(2, 3, 5) | 
|  | result = vmap(foo, out_dims=1)(x, y) | 
|  | self.assertEqual( | 
|  | result, | 
|  | (x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2))) | 
|  |  | 
|  | def test_multiple_out_dims(self): | 
|  | def foo(x): | 
|  | return x, x | 
|  |  | 
|  | def bar(x, y): | 
|  | return x, x, x, x * y | 
|  |  | 
|  | x = torch.randn(2, 3, 5) | 
|  | y = torch.randn(2, 3, 5) | 
|  | result = vmap(foo, out_dims=(0, 1))(x) | 
|  | self.assertEqual(result, (x, x.permute(1, 0, 2))) | 
|  |  | 
|  | result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y) | 
|  | expected = ( | 
|  | x.permute(1, 2, 0), | 
|  | x, | 
|  | x.permute(1, 0, 2), | 
|  | (x * y).permute(1, 2, 0), | 
|  | ) | 
|  | self.assertEqual(result, expected) | 
|  |  | 
|  | def test_nested_out_dims(self): | 
|  | y = torch.randn(2, 3, 5, 7) | 
|  |  | 
|  | # Inner vmap has non-zero out_dim | 
|  | result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y) | 
|  | self.assertEqual(result.shape, (2, 5, 3, 7)) | 
|  | self.assertEqual(result, y.permute(0, 2, 1, 3)) | 
|  |  | 
|  | # all vmaps have non-zero out_dim | 
|  | result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y) | 
|  | self.assertEqual(result.shape, (5, 2, 3, 7)) | 
|  | self.assertEqual(result, y.permute(2, 0, 1, 3)) | 
|  |  | 
|  | # throwing in some negative out_dims | 
|  | result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y) | 
|  | self.assertEqual(result.shape, (5, 7, 3, 2)) | 
|  | self.assertEqual(result, y.permute(2, 3, 1, 0)) | 
|  |  | 
|  | # testing fn that isn't the identity | 
|  | x = torch.randn(2, 3) | 
|  | y = torch.randn(5, 3) | 
|  | result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y) | 
|  | self.assertEqual(result.shape, (3, 2, 5)) | 
|  | self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0)) | 
|  |  | 
|  | def test_out_dims_edge_case(self): | 
|  | def foo(x): | 
|  | return x | 
|  |  | 
|  | # Test that we accept out_dims=(1,) for a function with one output. | 
|  | tensor = torch.randn(2, 3) | 
|  | expected = vmap(foo, out_dims=1)(tensor) | 
|  | result = vmap(foo, out_dims=(1,))(tensor) | 
|  | self.assertEqual(result, expected) | 
|  |  | 
|  | def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self): | 
|  | msg = '`out_dims` must be an int or a tuple of int' | 
|  | tensor = torch.randn(2, 3) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda x: x, out_dims='lol')(tensor) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda x: x, out_dims=('lol',))(tensor) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda x: x, out_dims=None)(tensor) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda x: x, out_dims=(None,))(tensor) | 
|  |  | 
|  | def test_out_dims_and_num_outputs_mismatch_err_msg(self): | 
|  | msg = '`out_dims` must have one dim per output' | 
|  | x = torch.randn(2, 3, 5) | 
|  |  | 
|  | # Too many out_dims | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda x: x, out_dims=(0, 0))(x) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x) | 
|  |  | 
|  | # Too few out_dims | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda x: (x, x), out_dims=(0,))(x) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda x: (x, x, x), out_dims=(0, 0))(x) | 
|  |  | 
|  | def test_out_dim_out_of_bounds_err_msg(self): | 
|  | # TODO(rzou): This error message isn't that great. It comes straight | 
|  | # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to | 
|  | # the error message in the future in C++ | 
|  | msg = 'Dimension out of range' | 
|  | x = torch.randn(2, 3, 5) | 
|  | with self.assertRaisesRegex(IndexError, msg): | 
|  | vmap(lambda x: x, out_dims=3)(x) | 
|  | with self.assertRaisesRegex(IndexError, msg): | 
|  | vmap(lambda x: x, out_dims=-4)(x) | 
|  |  | 
|  | def test_non_zero_in_dims(self): | 
|  | tensor = torch.randn(2, 3, 5) | 
|  |  | 
|  | # Implicit out_dims = 0; vmap will move the batch dim to the front. | 
|  | output = vmap(lambda x: x, (1,))(tensor) | 
|  | self.assertEqual(output, tensor.permute(1, 0, 2)) | 
|  | self.assertEqual(output.data_ptr(), tensor.data_ptr()) | 
|  |  | 
|  | x = torch.randn(2, 3) | 
|  | y = torch.randn(3, 2) | 
|  | output = vmap(torch.mul, (0, 1))(x, y) | 
|  | self.assertEqual(output, x * y.t()) | 
|  | output = vmap(torch.mul, (1, 0))(x, y) | 
|  | self.assertEqual(output, x.t() * y) | 
|  |  | 
|  | def test_none_in_dims(self): | 
|  | x = torch.randn(2, 3) | 
|  | y = torch.randn(2, 3) | 
|  |  | 
|  | # None in_dim for a Tensor means we don't map over it | 
|  | output = vmap(torch.mul, (0, None))(x, y) | 
|  | self.assertEqual(output.shape, (2, 2, 3)) | 
|  | self.assertEqual(output, x.view(2, 1, 3) * y) | 
|  |  | 
|  | # None in_dim for non-tensor arguments | 
|  | output = vmap(torch.mul, (0, None))(x, 2) | 
|  | self.assertEqual(output, x * 2) | 
|  |  | 
|  | def test_nested_non_default_in_dims(self): | 
|  | x = torch.rand(5, 2, 3) | 
|  | y = torch.rand(3, 5, 2) | 
|  | result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y) | 
|  | self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1)) | 
|  |  | 
|  | def test_non_default_in_dims_out_dims(self): | 
|  | x = torch.randn(2, 3, 5) | 
|  |  | 
|  | # Same in_dim as out_dim, vmap over identity | 
|  | result = vmap(lambda x: x, in_dims=1, out_dims=1)(x) | 
|  | self.assertEqual(result, x) | 
|  | self.assertEqual(result.data_ptr(), x.data_ptr()) | 
|  |  | 
|  | # Different in_dim from out_dim, vmap over identity | 
|  | result = vmap(lambda x: x, in_dims=2, out_dims=1)(x) | 
|  | self.assertEqual(result.shape, (2, 5, 3)) | 
|  | self.assertEqual(result, x.transpose(1, 2)) | 
|  | self.assertEqual(result.data_ptr(), x.data_ptr()) | 
|  |  | 
|  | def foo(x): | 
|  | return x * 2 | 
|  |  | 
|  | # Same in_dim as out_dim, vmap over operation | 
|  | result = vmap(foo, in_dims=1, out_dims=1)(x) | 
|  | self.assertEqual(result, x * 2) | 
|  |  | 
|  | # Different in_dim as out_dim, vmap over operation | 
|  | result = vmap(foo, in_dims=2, out_dims=1)(x) | 
|  | self.assertEqual(result.shape, (2, 5, 3)) | 
|  | self.assertEqual(result, (x * 2).transpose(1, 2)) | 
|  |  | 
|  | # Basic nested test. | 
|  | result = vmap(vmap(foo, 1, 1), 1, 1)(x) | 
|  | self.assertEqual(result, x * 2) | 
|  |  | 
|  | def test_accepts_nested_inputs(self): | 
|  | B0 = 2 | 
|  | x = torch.randn(2, 3) | 
|  | y = torch.randn(2, 3) | 
|  |  | 
|  | # Single layer of nesting | 
|  | out = vmap(lambda z: z[0] + z[1])((x, y)) | 
|  | self.assertEqual(out, x + y) | 
|  | out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y)) | 
|  | self.assertEqual(out, x + y) | 
|  | out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y)) | 
|  | self.assertEqual(out, x + y) | 
|  |  | 
|  | out = vmap(lambda z: z[0] + z[1])([x, y]) | 
|  | self.assertEqual(out, x + y) | 
|  | out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y]) | 
|  | self.assertEqual(out, x + y) | 
|  | out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y]) | 
|  | self.assertEqual(out, x + y) | 
|  |  | 
|  | out = vmap(lambda z: z['x'] + z['y'])({'x': x, 'y': y}) | 
|  | self.assertEqual(out, x + y) | 
|  | out = vmap(lambda z: z['x'] + z['y'], in_dims=(0,))({'x': x, 'y': y}) | 
|  | self.assertEqual(out, x + y) | 
|  | out = vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y}) | 
|  | self.assertEqual(out, x + y) | 
|  |  | 
|  | # Multiple layers of nesting | 
|  | out_fn = vmap(lambda z: z['x'][0] + z['x'][1][0] + z['y'][0] + z['y'][1]) | 
|  | out = out_fn({'x': [x, (x,)], 'y': [y, y]}) | 
|  | self.assertEqual(out, x + x + y + y) | 
|  |  | 
|  | def test_in_dims_wrong_type_err_msg(self): | 
|  | x = torch.randn(3) | 
|  | y = torch.randn(3) | 
|  | msg = r'expected `in_dims` to be int or a \(potentially nested\) tuple' | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(torch.mul, [0, 0])(x, y) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(torch.mul, set({0, 0}))(x, y) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(torch.mul, 'lol')(x, y) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y]) | 
|  | # The following should not throw | 
|  | vmap(torch.mul, (0, 0))(x, y) | 
|  |  | 
|  | def test_not_enough_in_dims_err_msg(self): | 
|  | x = torch.randn(3) | 
|  | y = torch.randn(3) | 
|  | msg = r'in_dims is not compatible with the structure of `inputs`' | 
|  |  | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(torch.mul, (0,))(x, y) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(torch.mul, (0, 0, 0))(x, y) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y]) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y]) | 
|  | # The following should not throw | 
|  | vmap(torch.mul, (0, 0))(x, y) | 
|  |  | 
|  | def test_integer_in_dim_but_not_tensor_input_err_msg(self): | 
|  | def foo(xy): | 
|  | return xy[0] * xy[1] | 
|  |  | 
|  | def bar(x, yz): | 
|  | return x * yz[0] * yz[1] | 
|  |  | 
|  | x = torch.randn(2, 3) | 
|  | y = torch.randn(2, 3) | 
|  |  | 
|  | # the following are errors in jax (and will always be errors) | 
|  | msg = 'Got in_dim=0 for an input but the input is of type' | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(torch.sum)(x, 0) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(torch.sum, (0, 0))(x, 0) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1]) | 
|  | # The following should not throw | 
|  | vmap(torch.sum, (0, None))(x, 0) | 
|  |  | 
|  | def test_in_dim_not_in_tensor_err_msg(self): | 
|  | def foo(x): | 
|  | return x * x | 
|  |  | 
|  | x = torch.randn(2, 3) | 
|  | y = torch.randn(2, 3) | 
|  |  | 
|  | msg = r'Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w' | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(foo)(torch.randn([])) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(foo, in_dims=(0,))(torch.randn([])) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(foo, in_dims=(-1,))(x) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(foo, in_dims=(2,))(y) | 
|  | with self.assertRaisesRegex(ValueError, msg): | 
|  | vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y]) | 
|  | # the following should not throw | 
|  | vmap(foo, in_dims=(0,))(torch.randn(2, 3)) | 
|  | vmap(foo, in_dims=(1,))(torch.randn(2, 3)) | 
|  |  | 
|  | def test_fallback_does_not_warn_by_default(self): | 
|  | # NB: One day we will implement a batching rule for torch.atan2. | 
|  | # If/when we do, this test should be replaced to test the fallback | 
|  | # path on another operator to avoid bitrot. | 
|  | op = torch.atan2 | 
|  | x = torch.randn(11) | 
|  | y = torch.randn(11) | 
|  | with warnings.catch_warnings(record=True) as wa: | 
|  | result = vmap(op)(x, y) | 
|  | # The single warning here is the "vmap is experimental" | 
|  | # warning, not a warning from the vmap fallback path. | 
|  | self.assertEqual(len(wa), 1) | 
|  |  | 
|  | def test_fallback_warns_when_warnings_are_enabled(self): | 
|  | # NB: One day we will implement a batching rule for torch.atan2. | 
|  | # If/when we do, this test should be replaced to test the fallback | 
|  | # path on another operator to avoid bitrot. | 
|  | op = torch.atan2 | 
|  | x = torch.randn(11) | 
|  | y = torch.randn(11) | 
|  | with warnings.catch_warnings(record=True) as wa: | 
|  | with EnableVmapFallbackWarnings(): | 
|  | result = vmap(op)(x, y) | 
|  | self.assertEqual(len(wa), 2) | 
|  | self.assertRegex(str(wa[-1].message), FALLBACK_REGEX) | 
|  |  | 
|  | def _assert_uses_vmap_fallback(self, vmap_args, inputs): | 
|  | with warnings.catch_warnings(record=True) as wa: | 
|  | with EnableVmapFallbackWarnings(): | 
|  | result = vmap(*vmap_args)(*inputs) | 
|  | self.assertEqual(len(wa), 2) | 
|  | self.assertRegex(str(wa[-1].message), FALLBACK_REGEX) | 
|  |  | 
|  | def test_fallback_zero_dim(self): | 
|  | # NB: One day we will implement a batching rule for torch.atan2. | 
|  | # If/when we do, this test should be replaced to test the fallback | 
|  | # path on another operator to avoid bitrot. | 
|  | op = torch.atan2 | 
|  | x = torch.randn(11) | 
|  | y = torch.randn(11) | 
|  | self._assert_uses_vmap_fallback((op,), (x, y)) | 
|  |  | 
|  | B0, B1 = 0, 3 | 
|  | x = torch.randn(B0, 11) | 
|  | y = torch.randn(11) | 
|  |  | 
|  | msg = 'The fallback path does not support vmap over dims of size 0' | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, (0, None))(x, y) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, (None, 0))(y, x) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op)(x, x) | 
|  |  | 
|  | x = torch.randn(B0, B1, 11) | 
|  | y = torch.randn(B1, 11) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, (0, None))(x, y) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, (None, 0))(y, x) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op)(x, x) | 
|  |  | 
|  | def test_fallback_atan2(self): | 
|  | # NB: One day we will implement a batching rule for torch.atan2. | 
|  | # If/when we do, this test should be replaced to test the fallback | 
|  | # path on another operator to avoid bitrot. | 
|  | op = torch.atan2 | 
|  |  | 
|  | x = torch.randn(5, 7, 11) | 
|  | y = torch.randn(5, 7, 11) | 
|  |  | 
|  | self._assert_uses_vmap_fallback((op,), (x, y)) | 
|  |  | 
|  | # fallback on torch.atan2 | 
|  | x = torch.randn(7, 11, 5) | 
|  | y = torch.randn(5, 7, 11) | 
|  | result = vmap(op, (2, 0))(x, y) | 
|  | self.assertEqual(result, op(x.permute(2, 0, 1), y)) | 
|  |  | 
|  | # fallback on torch.atan2, nested vmap | 
|  | x = torch.randn(7, 11, 5) | 
|  | y = torch.randn(5, 7, 11) | 
|  | result = vmap(vmap(op), (2, 0))(x, y) | 
|  | self.assertEqual(result, op(x.permute(2, 0, 1), y)) | 
|  |  | 
|  | # big batch size (total 10000) | 
|  | x = torch.randn(100, 10, 10, 5) | 
|  | y = torch.randn(100, 10, 10) | 
|  | result = vmap(vmap(vmap(op)))(x, y) | 
|  | self.assertEqual(result, op(x, y.view(100, 10, 10, 1))) | 
|  |  | 
|  | def test_fallback_masked_fill(self): | 
|  | # NB: One day we will implement a batching rule for masked_fill | 
|  | # If/when we do, this test should be replaced to test the fallback | 
|  | # path on another operator to avoid bitrot. | 
|  | def run_test(batch_size): | 
|  | B0 = batch_size | 
|  | x = torch.randn(B0, 7, 11, 13) | 
|  | dim = 0 | 
|  | index = torch.tensor([0, 4, 2]) | 
|  | values = torch.randn(B0, 3, 13) | 
|  |  | 
|  | self._assert_uses_vmap_fallback((torch.index_add, (0, None, None, 0)), (x, dim, index, values)) | 
|  |  | 
|  | result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values) | 
|  | expected = torch.index_add( | 
|  | x, dim + 1, index, values.view(B0, 3, 1, 13)) | 
|  | self.assertEqual(result, expected) | 
|  |  | 
|  | run_test(batch_size=5) | 
|  | run_test(batch_size=1237) | 
|  |  | 
|  | def test_fallback_multiple_returns(self): | 
|  | # NB: One day we will implement a batching rule for torch.var_mean | 
|  | # If/when we do, this test should be replaced to test the fallback | 
|  | # path on another operator to avoid bitrot. | 
|  | B0, B1, B2 = 2, 3, 1237 | 
|  | tensor = torch.randn(B0, 10) | 
|  |  | 
|  | self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,)) | 
|  |  | 
|  | # fallback correctness on torch.var_mean | 
|  | result = vmap(torch.var_mean)(tensor) | 
|  | expected = torch.var_mean(tensor, dim=1) | 
|  | self.assertEqual(result, expected) | 
|  |  | 
|  | # nested vmap | 
|  | tensor = torch.randn(B0, B1, 10) | 
|  | result = vmap(vmap(torch.var_mean))(tensor) | 
|  | expected = torch.var_mean(tensor, dim=2) | 
|  | self.assertEqual(result, expected) | 
|  |  | 
|  | # big batch size, nested vmap | 
|  | tensor = torch.randn(B0, B1, B2, 10) | 
|  | result = vmap(vmap(vmap(torch.var_mean)))(tensor) | 
|  | expected = torch.var_mean(tensor, dim=3) | 
|  | self.assertEqual(result, expected) | 
|  |  | 
|  | def test_inplace_fallback_unary(self): | 
|  | # Test the in-place fallback on an in-place method that takes no | 
|  | # additional Tensor arguments. This is the simplest case of the fallback. | 
|  | # NB: One day we will implement a batching rule for acos_. | 
|  | # If/when we do, this test should be replaced to test the fallback | 
|  | # path on another operator to avoid bitrot. | 
|  | op = Tensor.acos_ | 
|  | B0, B1, B2 = 2, 3, 10000 | 
|  |  | 
|  | x = torch.randn(B0, 5) | 
|  | self._assert_uses_vmap_fallback((op,), (x,)) | 
|  |  | 
|  | # Single vmap | 
|  | x_orig = torch.rand(B0, 5) | 
|  | x = x_orig.clone() | 
|  | result = vmap(op)(x) | 
|  | self.assertTrue(result is x) | 
|  | self.assertEqual(result, x_orig.acos()) | 
|  |  | 
|  | # Single vmap + different out_dim produces a view(!) | 
|  | x_orig = torch.rand(B0, 5) | 
|  | x = x_orig.clone() | 
|  | result = vmap(op, out_dims=(1,))(x) | 
|  | self.assertTrue(result._base is x) | 
|  | self.assertEqual(result, x_orig.t().acos()) | 
|  |  | 
|  | # Nested vmap | 
|  | x_orig = torch.randn(B0, B1, 5) | 
|  | x = x_orig.clone() | 
|  | result = vmap(vmap(op))(x) | 
|  | self.assertTrue(result is x) | 
|  | self.assertEqual(result, x_orig.acos()) | 
|  |  | 
|  | # Nested vmap, large batch size | 
|  | x_orig = torch.randn(B0, B1, B2, 5) | 
|  | x = x_orig.clone() | 
|  | result = vmap(vmap(vmap(op)))(x) | 
|  | self.assertTrue(result is x) | 
|  | self.assertEqual(result, x_orig.acos()) | 
|  |  | 
|  | def test_inplace_fallback_nary_same_levels(self): | 
|  | # NB: One day we will implement a batching rule for atan2_ | 
|  | # If/when we do, this test should be replaced to test the fallback | 
|  | # path on another operator to avoid bitrot. | 
|  | op = Tensor.atan2_ | 
|  | outplace_op = torch.atan2 | 
|  |  | 
|  | x = torch.randn(5, 7, 11) | 
|  | y = torch.randn(5, 7, 11) | 
|  | self._assert_uses_vmap_fallback((op,), (x, y)) | 
|  |  | 
|  | # Single vmap | 
|  | B0 = 5 | 
|  | x_orig = torch.randn(7, 11, B0) | 
|  | x = x_orig.clone() | 
|  | y = torch.randn(B0, 7, 11) | 
|  | vmap(op, (2, 0))(x, y) | 
|  | self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2))) | 
|  |  | 
|  | # Nested vmap | 
|  | B0, B1 = 5, 7 | 
|  | x_orig = torch.randn(B1, 11, B0) | 
|  | x = x_orig.clone() | 
|  | y = torch.randn(B0, B1, 11) | 
|  | vmap(vmap(op), (2, 0))(x, y) | 
|  | self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0]))) | 
|  |  | 
|  | # big batch size (total 10000) | 
|  | B0, B1, B2 = 100, 10, 10 | 
|  | x_orig = torch.randn(B0, B1, B2, 5) | 
|  | x = x_orig.clone() | 
|  | y = torch.randn(B0, B1, B2) | 
|  | result = vmap(vmap(vmap(op)))(x, y) | 
|  | self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1))) | 
|  |  | 
|  | def test_inplace_fallback_nary_different_levels(self): | 
|  | # NB: One day we will implement a batching rule for atan2_ | 
|  | # If/when we do, this test should be replaced to test the fallback | 
|  | # path on another operator to avoid bitrot. | 
|  | op = Tensor.atan2_ | 
|  | outplace_op = torch.atan2 | 
|  | B0, B1, B2 = 2, 3, 5 | 
|  |  | 
|  | x = torch.rand(B0, 7) | 
|  | y = torch.rand(7) | 
|  | self._assert_uses_vmap_fallback((op, (0, None)), (x, y)) | 
|  |  | 
|  | # op(left, right): All of the levels in right are found in left | 
|  | x_orig = torch.rand(B0, 7) | 
|  | x = x_orig.clone() | 
|  | y = torch.rand(7) | 
|  | vmap(op, in_dims=(0, None))(x, y) | 
|  | self.assertEqual(x, outplace_op(x_orig, y)) | 
|  |  | 
|  | x_orig = torch.rand(B0, B1, 7) | 
|  | x = x_orig.clone() | 
|  | y = torch.rand(B0, 7) | 
|  | vmap(vmap(op, in_dims=(0, None)))(x, y) | 
|  | self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7))) | 
|  |  | 
|  | # op(left, right): Some of the levels in right are not found in left | 
|  | msg = r'vmap: aten::atan2_\(self, \*extra_args\) is not possible' | 
|  | x = torch.rand(7) | 
|  | y = torch.rand(B0, 7) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=(None, 0))(x, y) | 
|  |  | 
|  | x = torch.rand(B1, 7) | 
|  | y = torch.rand(B0, 7) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y) | 
|  |  | 
|  | x = torch.rand(B1, 7) | 
|  | y = torch.rand(7, B0) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y) | 
|  |  | 
|  | x = torch.rand(B0, 7) | 
|  | y = torch.rand(B0, B1, 7) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(vmap(op, in_dims=(None, 0)))(x, y) | 
|  |  | 
|  | def test_backward_unsupported_interaction(self): | 
|  | x = torch.randn(3, requires_grad=True) | 
|  | y = torch.randn(5) | 
|  | grad = torch.randn_like(x) | 
|  | err_msg = r'backward\(\) called inside torch.vmap' | 
|  |  | 
|  | def backward_on_vmapped_tensor(x): | 
|  | x.sum().backward() | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, err_msg): | 
|  | vmap(backward_on_vmapped_tensor)(x) | 
|  |  | 
|  | def backward_with_vmapped_grad(x, grad): | 
|  | x.backward(grad) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, err_msg): | 
|  | vmap(backward_with_vmapped_grad)(x, grad) | 
|  |  | 
|  | def completely_unrelated_backward(y): | 
|  | x.sum().backward() | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, err_msg): | 
|  | vmap(completely_unrelated_backward)(y) | 
|  |  | 
|  | def test_grad_unsupported_interaction(self): | 
|  | input_tensor = torch.randn(3, requires_grad=True) | 
|  | err_msg = 'autograd.grad.* called inside torch.vmap' | 
|  |  | 
|  | captured = torch.randn(3, requires_grad=True) | 
|  |  | 
|  | def output_to_grad_is_vmapped(input_tensor): | 
|  | output = (captured * input_tensor).sum() | 
|  | return torch.autograd.grad([output], [captured])[0] | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, err_msg): | 
|  | vmap(output_to_grad_is_vmapped)(input_tensor) | 
|  |  | 
|  | output = (input_tensor ** 2).sum() | 
|  |  | 
|  | def input_to_grad_is_vmapped(input_tensor): | 
|  | return torch.autograd.grad([output], [input_tensor])[0] | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, err_msg): | 
|  | vmap(input_to_grad_is_vmapped)(input_tensor) | 
|  |  | 
|  | def test_batched_gradient_basic(self): | 
|  | N = 3 | 
|  | x = torch.randn(N, requires_grad=True) | 
|  | y = torch.randn(N) | 
|  |  | 
|  | def vjp_mul(v): | 
|  | return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0] | 
|  |  | 
|  | batched_v = torch.eye(N) | 
|  | jacobian = vmap(vjp_mul)(batched_v) | 
|  | self.assertEqual(jacobian, torch.diagflat(y)) | 
|  |  | 
|  | def test_functools_partial(self): | 
|  | x = torch.randn(3) | 
|  | y = torch.randn(2, 3) | 
|  | result = vmap(functools.partial(torch.mul, x))(y) | 
|  | self.assertEqual(result, x * y) | 
|  |  | 
|  | def test_nn_module(self): | 
|  | tensor = torch.randn(2, 3) | 
|  | model = torch.nn.Linear(3, 3, bias=False) | 
|  | result = vmap(model)(tensor) | 
|  | self.assertEqual(result, model(tensor)) | 
|  |  | 
|  | def test_fallback_with_undefined_grad(self): | 
|  | B0 = 7 | 
|  | x = torch.randn(2, 3, 4, 5, requires_grad=True) | 
|  | weight = torch.randn(3, 3, 1, 1) | 
|  | v = torch.randn(B0, 2, 3, 4, 5) | 
|  |  | 
|  | def get_vjp(v): | 
|  | result = torch.nn.functional.conv2d(x, weight) | 
|  | grad_x, = torch.autograd.grad(result, x, v) | 
|  | return grad_x | 
|  |  | 
|  | # Runs vmap(get_vjp)(v), which should not error out. | 
|  | # The backward formula for convolution returns an undefined | 
|  | # Tensor for grad_bias because the original bias does not exist. | 
|  | # | 
|  | # In the future we'll probably add a batching rule for convolution | 
|  | # backward. When this happens, we should modify this test to use a | 
|  | # different op (and/or create and use a dummy operator) to avoid bitrot. | 
|  | self._assert_uses_vmap_fallback([get_vjp], [v]) | 
|  |  | 
|  | def slice_inputs(inputs, bdims, i): | 
|  | result = [] | 
|  | for inp, bdim in zip(inputs, bdims): | 
|  | if bdim is None: | 
|  | result.append(inp) | 
|  | else: | 
|  | result.append(inp.select(bdim, i)) | 
|  | return tuple(result) | 
|  |  | 
|  |  | 
|  | def reference_vmap(op, inputs, in_dims=0, out_dims=0): | 
|  | if isinstance(in_dims, int): | 
|  | in_dims = (in_dims,) * len(inputs) | 
|  | bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None] | 
|  | assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes) | 
|  | bdim_size = bdim_sizes[0] | 
|  | results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size)) | 
|  |  | 
|  | assert len(results) > 0 | 
|  | op_has_single_return = not isinstance(results[0], tuple) | 
|  | if op_has_single_return: | 
|  | assert all(isinstance(result, torch.Tensor) for result in results) | 
|  | if isinstance(out_dims, int): | 
|  | out_dims = (out_dims,) * 1 | 
|  | return torch.stack(results, dim=out_dims[0]) | 
|  |  | 
|  | assert all(isinstance(result, tuple) for result in results) | 
|  | num_returns = len(results[0]) | 
|  | assert all(len(result) == num_returns for result in results) | 
|  | if isinstance(out_dims, int): | 
|  | out_dims = (out_dims,) * num_returns | 
|  | return tuple(torch.stack(result_shards, out_dim) | 
|  | for result_shards, out_dim in zip(zip(*results), out_dims)) | 
|  |  | 
|  |  | 
|  | class TensorFactory: | 
|  | @staticmethod | 
|  | def rand(size, device='cpu', dtype=torch.float): | 
|  | return torch.rand(size, device=device, dtype=dtype) | 
|  |  | 
|  | @staticmethod | 
|  | def randn(size, device='cpu', dtype=torch.float): | 
|  | return torch.randn(size, device=device, dtype=dtype) | 
|  |  | 
|  | @staticmethod | 
|  | def randp1(size, device='cpu', dtype=torch.float): | 
|  | return torch.rand(size, device=device, dtype=dtype) + 1 | 
|  |  | 
|  | # Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a | 
|  | # (slow) sequential map+stack fallback. | 
|  | # | 
|  | # check_view: Test if the first returned output is a view of the first input | 
|  | # check_propagates_grad: Test if the operation propagates gradients. | 
|  | def _vmap_test(self, op, inputs, in_dims=0, out_dims=0, | 
|  | check_view=False, check_propagates_grad=True): | 
|  | result = vmap(op, in_dims, out_dims)(*inputs) | 
|  | reference_result = reference_vmap(op, inputs, in_dims, out_dims) | 
|  | self.assertEqual(result, reference_result) | 
|  | op_has_single_return = not isinstance(result, tuple) | 
|  |  | 
|  | if check_view: | 
|  | result_as_tuple = (result,) if op_has_single_return else result | 
|  | for output in result_as_tuple: | 
|  | input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base | 
|  | self.assertTrue(output._base is input0_base, | 
|  | msg="result was not a view of the first input!") | 
|  |  | 
|  | if not check_propagates_grad: | 
|  | return | 
|  | # Assuming input[0] is a floating-point tensor. Check if the vmap | 
|  | # operation propagates the requires_grad flag to the zeroth output. | 
|  | # Some vmap operators are implemented in a way that assumes that | 
|  | # they are composite with respect to autograd. If the operator ever is | 
|  | # changed to not be composite with respect to autograd, then the | 
|  | # following check should fail. | 
|  | inputs_clone = list(inputs) | 
|  | inputs_clone[0] = inputs[0].clone().requires_grad_() | 
|  | result = vmap(op, in_dims, out_dims)(*inputs_clone) | 
|  | result_as_tuple = (result,) if op_has_single_return else result | 
|  | self.assertTrue(result[0].requires_grad) | 
|  |  | 
|  | def should_allow_vmap_fallback_usage(fn): | 
|  | return getattr(fn, '_allow_vmap_fallback_usage', False) | 
|  |  | 
|  | def allowVmapFallbackUsage(fn): | 
|  | fn._allow_vmap_fallback_usage = True | 
|  | return fn | 
|  |  | 
|  | # All tests of TestVmapBase check that the slow vmap fallback is never invoked. | 
|  | # This is so that we can incrementally add batching rules for operators to | 
|  | # replace the slow vmap fallback path for said operators. To skip this check, | 
|  | # please use the allowVmapFallbackUsage decorator. | 
|  | # | 
|  | # NB: Don't add tests to TestVmapBase directly, unless you want them to run | 
|  | # on every subclass of TestVmapBase. Add them to e.g. TestVmapOperators. | 
|  | # | 
|  | # NB: TestVmapBase is a nested class. This prevents test runners from picking | 
|  | # it up and running it. | 
|  | class Namespace: | 
|  | class TestVmapBase(TestCase): | 
|  | def __init__(self, method_name='runTest'): | 
|  | super().__init__(method_name) | 
|  |  | 
|  | test_method = getattr(self, method_name, None) | 
|  | if test_method is None: | 
|  | return | 
|  |  | 
|  | if not should_allow_vmap_fallback_usage(test_method): | 
|  | setattr(self, method_name, | 
|  | self._wrap_method_with_vmap_fallback_check(test_method)) | 
|  |  | 
|  | def _wrap_method_with_vmap_fallback_check(self, method): | 
|  | msg = ( | 
|  | 'Expected the test to not invoke the vmap fallback path, i.e., ' | 
|  | 'all of the operators being tested in this test should have batching ' | 
|  | 'rules implemented. If you are intentionally testing something to ' | 
|  | 'do with the fallback path, use allowVmapFallbackUsage. Otherwise, ' | 
|  | 'please make sure that batching rules are implemented for the ' | 
|  | 'operator(s) being tested.' | 
|  | ) | 
|  |  | 
|  | @functools.wraps(method) | 
|  | def wrapper(self, *args, **kwargs): | 
|  | with warnings.catch_warnings(record=True) as wa: | 
|  | warnings.simplefilter('always') | 
|  | with EnableVmapFallbackWarnings(): | 
|  | method(*args, **kwargs) | 
|  | for captured_warning in wa: | 
|  | self.assertNotRegex(str(captured_warning.message), FALLBACK_REGEX, msg) | 
|  | return types.MethodType(wrapper, self) | 
|  |  | 
|  | @allowVmapFallbackUsage | 
|  | def test_vmap_fallback_check_ok(self): | 
|  | # One day we'll implement a batching rule for torch.var_mean. | 
|  | # When that happens, please change the example to use an | 
|  | # operator that doesn't have a batching rule implemented. | 
|  | op_using_fallback = torch.var_mean | 
|  | vmap(op_using_fallback)(torch.rand(3)) | 
|  |  | 
|  | def test_vmap_fallback_check(self): | 
|  | @self._wrap_method_with_vmap_fallback_check | 
|  | def no_fallback(self): | 
|  | pass | 
|  |  | 
|  | # One day we'll implement a batching rule for torch.var_mean. | 
|  | # When that happens, please change the example to use an | 
|  | # operator that doesn't have a batching rule implemented. | 
|  | op_using_fallback = torch.var_mean | 
|  |  | 
|  | @self._wrap_method_with_vmap_fallback_check | 
|  | def uses_fallback(self): | 
|  | vmap(op_using_fallback)(torch.rand(3)) | 
|  |  | 
|  | no_fallback(self) | 
|  |  | 
|  | with self.assertRaises(AssertionError): | 
|  | uses_fallback(self) | 
|  |  | 
|  |  | 
|  | class TestVmapOperators(Namespace.TestVmapBase): | 
|  | def _vmap_test(self, *args, **kwargs): | 
|  | return _vmap_test(self, *args, **kwargs) | 
|  |  | 
|  | def _vmap_view_test(self, *args, **kwargs): | 
|  | self._vmap_test(*args, **kwargs, check_view=True) | 
|  |  | 
|  | def _test_unary(self, op, getter, device, *args, **kwargs): | 
|  | test = functools.partial(self._vmap_test, *args, **kwargs) | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | # Single vmap, various in_dims / out_dims | 
|  | test(op, [getter([B0, 3], device)]) | 
|  | test(op, [getter([2, 5, B0, 3], device)], in_dims=2) | 
|  | test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2) | 
|  |  | 
|  | # Doubly nested vmap | 
|  | test(vmap(op), [getter([B0, B1], device)]) | 
|  | test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2) | 
|  | test(vmap(op, in_dims=2), [getter([2, 5, B0, B1, 3], device)], | 
|  | in_dims=2, out_dims=2) | 
|  |  | 
|  | def test_unary_pointwise_ops(self): | 
|  | cases = [ | 
|  | (torch.abs, TensorFactory.randn), | 
|  | (torch.acos, TensorFactory.rand), | 
|  | (torch.asin, TensorFactory.rand), | 
|  | (torch.atan, TensorFactory.rand), | 
|  | (torch.ceil, TensorFactory.randn), | 
|  | (torch.cos, TensorFactory.rand), | 
|  | (torch.cosh, TensorFactory.rand), | 
|  | (torch.digamma, TensorFactory.rand), | 
|  | (torch.exp, TensorFactory.randn), | 
|  | (torch.expm1, TensorFactory.randn), | 
|  | (torch.floor, TensorFactory.randn), | 
|  | (torch.frac, TensorFactory.randn), | 
|  | (torch.lgamma, TensorFactory.rand), | 
|  | (torch.log, TensorFactory.randp1), | 
|  | (torch.log10, TensorFactory.randp1), | 
|  | (torch.log1p, TensorFactory.randp1), | 
|  | (torch.log2, TensorFactory.randp1), | 
|  | (torch.neg, TensorFactory.randn), | 
|  | (torch.reciprocal, TensorFactory.randp1), | 
|  | (torch.relu, TensorFactory.randn), | 
|  | (torch.round, TensorFactory.randn), | 
|  | (torch.rsqrt, TensorFactory.randp1), | 
|  | (torch.sigmoid, TensorFactory.randn), | 
|  | (torch.sign, TensorFactory.randn), | 
|  | (torch.sin, TensorFactory.rand), | 
|  | (torch.sinh, TensorFactory.rand), | 
|  | (torch.sqrt, TensorFactory.rand), | 
|  | (torch.tan, TensorFactory.rand), | 
|  | (torch.tanh, TensorFactory.rand), | 
|  | (torch.trunc, TensorFactory.randn), | 
|  | ] | 
|  | for op, getter in cases: | 
|  | self._test_unary(op, getter, 'cpu') | 
|  |  | 
|  | def test_clone(self): | 
|  | # Some basic tests | 
|  | self._test_unary(lambda x: x.clone(), TensorFactory.randn, 'cpu') | 
|  | self._test_unary(lambda x: x.clone(memory_format=torch.preserve_format), | 
|  | TensorFactory.randn, 'cpu') | 
|  | self._test_unary(lambda x: x.clone(memory_format=torch.contiguous_format), | 
|  | TensorFactory.randn, 'cpu') | 
|  |  | 
|  | # Test that the per-examples are contiguous when using torch.contiguous_format | 
|  | def clone_contiguous(x): | 
|  | return x.clone(memory_format=torch.contiguous_format) | 
|  |  | 
|  | B0, B1 = 3, 5 | 
|  | x = torch.randn(2, B0, 7) | 
|  | y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x) | 
|  | self.assertTrue(y.movedim(1, 0).is_contiguous()) | 
|  | self.assertTrue(y[:, 0, :].is_contiguous()) | 
|  |  | 
|  | x = torch.randn(2, B0, 7, B1) | 
|  | y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x) | 
|  | self.assertTrue(y.is_contiguous()) | 
|  | self.assertTrue(y[0][0].is_contiguous()) | 
|  |  | 
|  |  | 
|  | msg = r'only supported with memory_format torch.preserve_format or torch.contiguous_format' | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0)) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(torch.randn(B0)) | 
|  |  | 
|  | def test_binary_pointwise_ops(self): | 
|  | def get_number(getter): | 
|  | return getter([]).item() | 
|  |  | 
|  | def make_case(op, input_getter=TensorFactory.randn): | 
|  | return (op, input_getter) | 
|  |  | 
|  | cases = [ | 
|  | # Basic arithmetic | 
|  | make_case(torch.add), | 
|  | make_case(lambda x, y: x + y), | 
|  | make_case(torch.sub), | 
|  | make_case(lambda x, y: x - y), | 
|  | make_case(torch.mul), | 
|  | make_case(lambda x, y: x * y), | 
|  | make_case(torch.div, input_getter=TensorFactory.randp1), | 
|  | make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1), | 
|  | make_case(torch.pow, input_getter=TensorFactory.randp1), | 
|  | make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1), | 
|  | ] | 
|  | test = self._vmap_test | 
|  |  | 
|  | for op, getter in cases: | 
|  | device = 'cpu' | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | # Single vmap: op(Tensor, Tensor) | 
|  | test(op, (getter([B0, 3], device), getter([B0, 3], device))) | 
|  | test(op, (getter([B0], device), getter([B0, 2, 3], device))) | 
|  | test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1)) | 
|  | test(op, (getter([B0], device), getter([2, B0, 3], device)), | 
|  | in_dims=(0, 1), out_dims=1) | 
|  | test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None)) | 
|  | test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None)) | 
|  |  | 
|  | # Nested vmap: op(Tensor, Tensor) | 
|  | test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device))) | 
|  | test(vmap(op, in_dims=(None, 0)), | 
|  | (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None)) | 
|  |  | 
|  | # Python number overload: op(Tensor, Number) (and vice-versa) | 
|  | number = get_number(getter) | 
|  | self._test_unary(lambda t: op(t, number), getter, device) | 
|  | number = get_number(getter) | 
|  | self._test_unary(lambda t: op(number, t), getter, device) | 
|  |  | 
|  | # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor) | 
|  | test(op, (getter([B0], device), getter([B0], device, dtype=torch.double))) | 
|  | test(op, (getter([B0], device, dtype=torch.double), getter([B0], device))) | 
|  | test(op, (getter([B0], device), getter([B0], device))) | 
|  |  | 
|  | # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa) | 
|  | test(op, (getter([B0, 2], device), getter([B0], device, torch.double))) | 
|  | test(op, (getter([B0], device, torch.double), getter([B0, 2], device))) | 
|  |  | 
|  | if not torch.cuda.is_available(): | 
|  | continue | 
|  |  | 
|  | # TODO(rzou): fix the following | 
|  | # # Test cross-device scalars | 
|  | # number = get_number(getter) | 
|  | # self._test_unary(lambda t: op(t, number), getter, device='cuda') | 
|  | # self._test_unary(lambda t: op(number, t), getter, device='cuda') | 
|  | # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda') | 
|  |  | 
|  | def test_as_strided(self): | 
|  | def _test(sizes, strides, offset, tensor, lambd): | 
|  | result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor) | 
|  | expected = vmap(lambd)(tensor) | 
|  | self.assertTrue(result._base is expected._base) | 
|  | self.assertEqual(result, expected) | 
|  |  | 
|  | # single vmap test | 
|  | B0 = 5 | 
|  | tensors = [ | 
|  | # contiguous | 
|  | torch.randn(B0, 2, 3), | 
|  | # non-contiguous | 
|  | torch.randn(B0, 3, 2).transpose(1, 2), | 
|  | # non-zero storage offset | 
|  | torch.randn(2, B0, 2, 3)[1], | 
|  | # non-contiguous strides, zero storage offset | 
|  | torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0], | 
|  | # non-contiguous strides, non-zero storage offset | 
|  | torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1], | 
|  | ] | 
|  |  | 
|  | for x in tensors: | 
|  | S0, S1 = x.stride()[1:] | 
|  | offset = x.storage_offset() | 
|  |  | 
|  | # Broadcast | 
|  | _test([5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3)) | 
|  | # transpose | 
|  | _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1)) | 
|  | # select | 
|  | _test([2], [S0], offset + S1, x, lambda x: x[:, 1]) | 
|  |  | 
|  | # Nested vmap test | 
|  | B1 = 7 | 
|  | x = torch.randn(B1, B0, 2, 3) | 
|  | S0, S1 = x.stride()[2:] | 
|  | result = vmap(vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1)(x) | 
|  | expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x) | 
|  | self.assertTrue(result._base is expected._base) | 
|  | self.assertEqual(result, expected) | 
|  |  | 
|  | # Check that mal-formatted size/strides doesn't crash | 
|  | with self.assertRaisesRegex(RuntimeError, 'size and stride must have the same length'): | 
|  | x = torch.randn(B0, 2, 3).transpose(0, 1) | 
|  | vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x) | 
|  |  | 
|  | # Sanity check #1: we require the batch dims to be at the front of the | 
|  | # tensor (in memory layout). | 
|  | msg = 'batch dims being vmapped over are at the front of the tensor' | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | x = torch.randn(2, B0, 3).transpose(0, 1) | 
|  | vmap(lambda x: x.as_strided([2, 3], [B0 * 3, 1]))(x) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | x = torch.randn(B0, 2, 3, B1).movedim(3, 1) | 
|  | vmap(vmap(lambda x: x.as_strided([2, 3], [B1 * 3, B1])))(x) | 
|  |  | 
|  | # All the Sanity check #2{a,b,c} cases check that | 
|  | # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) | 
|  | # doesn't index memory that is out of bounds of xs[i]. This condition | 
|  | # is important to the correctness of the as_strided batching rule | 
|  | # (see NOTE: [When will the as_strided_batching_rule fail?]) | 
|  |  | 
|  | # Sanity check #2a: The maximum indexable location of | 
|  | # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) | 
|  | # is less than or equal to the maximum indexable location of xs[i]. | 
|  | msg = 'This is not supported inside of vmap' | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | x = torch.randn(B0, 3) | 
|  | vmap(lambda x: x.as_strided([3], [1], 1))(x) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | x = torch.randn(B0, 3, 5) | 
|  | vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | x = torch.randn(B0, B1, 3, 5) | 
|  | vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x) | 
|  |  | 
|  | # Sanity check #2b: The min indexable location of | 
|  | # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) | 
|  | # is greater than or equal to the min indexable location of xs[i]. | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | x = torch.randn(2, B0, 3)[1] | 
|  | vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x) | 
|  |  | 
|  | # Sanity check #2c: | 
|  | # xs[i] is a zero-dim tensor, but | 
|  | # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) | 
|  | # is not | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | x = torch.randn(B0, 0, 3) | 
|  | vmap(lambda x: x.as_strided([3], [1]))(x) | 
|  |  | 
|  | def test_bmm(self): | 
|  | op = torch.bmm | 
|  | test = self._vmap_test | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | # shape mismatch | 
|  | msg = "Shape mismatch" | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2)) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2)) | 
|  |  | 
|  | # left arg is vmapped | 
|  | test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None)) | 
|  | test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)), | 
|  | in_dims=(1, None)) | 
|  |  | 
|  | # right arg is vmapped | 
|  | test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0)) | 
|  | test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)), | 
|  | in_dims=(None, 1)) | 
|  |  | 
|  | # both args are vmapped | 
|  | test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3))) | 
|  | test(vmap(op), (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)), in_dims=(1, 0)) | 
|  | test(vmap(op, in_dims=(0, None)), | 
|  | (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), in_dims=(None, 0)) | 
|  |  | 
|  | def test_cat(self): | 
|  | test = self._vmap_test | 
|  | B0, B1 = 5, 7 | 
|  |  | 
|  | # Quick hack b/c vmap can't accept a list of tensors as an argument | 
|  | def get_op(dim): | 
|  | def op(*tensors): | 
|  | return torch.cat(tensors, dim=dim) | 
|  | return op | 
|  |  | 
|  | test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3))) | 
|  | test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0)) | 
|  | test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2)) | 
|  | test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2)) | 
|  | test(vmap(get_op(0), in_dims=(0, None)), | 
|  | (torch.rand(B1, 2), torch.rand(B0, 3)), in_dims=(None, 0)) | 
|  | test(vmap(get_op(0), in_dims=(0, 0)), | 
|  | (torch.rand(B1, 2), torch.rand(B0, B1, 3)), in_dims=(None, 0)) | 
|  |  | 
|  | def test_conj(self): | 
|  | op = torch.conj | 
|  |  | 
|  | def run_test(dtype): | 
|  | def get(shape): | 
|  | return torch.randn(shape, dtype=dtype) | 
|  | B0, B1 = 7, 11 | 
|  | test = self._vmap_test | 
|  |  | 
|  | # Single vmap, various in_dims / out_dims | 
|  | test(op, [get([B0, 3])]) | 
|  | test(op, [get([2, 5, B0, 3])], in_dims=2) | 
|  | test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) | 
|  |  | 
|  | # Doubly nested vmap | 
|  | test(vmap(op), [get([B0, B1])]) | 
|  | test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2) | 
|  | test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], | 
|  | in_dims=2, out_dims=2) | 
|  |  | 
|  | # correctness tests | 
|  | run_test(torch.float) | 
|  | run_test(torch.cfloat) | 
|  |  | 
|  | # check that torch.conj on a non-complex tensor returns the same tensor | 
|  | real_tensor = torch.randn(3) | 
|  | result = vmap(op)(real_tensor) | 
|  | self.assertEqual(result.data_ptr(), real_tensor.data_ptr()) | 
|  |  | 
|  | def test_contiguous(self): | 
|  | op = Tensor.contiguous | 
|  |  | 
|  | self._test_unary(op, TensorFactory.randn, 'cpu') | 
|  |  | 
|  | # check that contiguous returns the original tensor if the per-examples | 
|  | # are already contiguous | 
|  | B0 = 3 | 
|  | x = torch.randn(B0, 2, 5, 7) | 
|  | x = x.movedim(0, 2) | 
|  | result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x) | 
|  | self.assertTrue(result is x) | 
|  |  | 
|  | msg = 'NYI: querying is_contiguous inside of vmap for memory_format' | 
|  | tensor = torch.randn(B0, 3) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(functools.partial(op, memory_format=torch.channels_last))(tensor) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor) | 
|  |  | 
|  | def test_stride(self): | 
|  | B0 = 3 | 
|  |  | 
|  | x = torch.randn(B0, 2, 5, 7) | 
|  |  | 
|  | def foo(x): | 
|  | assert x.stride() == (7 * 5, 7, 1) | 
|  | return x | 
|  |  | 
|  | vmap(foo)(x) | 
|  |  | 
|  | x = torch.randn(2, B0, 5, 7).movedim(1, 0) | 
|  |  | 
|  | def bar(x): | 
|  | assert x.stride() == (7 * 5 * B0, 7, 1) | 
|  | return x | 
|  |  | 
|  | vmap(bar)(x) | 
|  |  | 
|  | def test_chunk(self): | 
|  | test = self._vmap_view_test | 
|  | op = torch.chunk | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  |  | 
|  | # tests for torch.split(self, split_size: int, dim) | 
|  | test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None)) | 
|  | test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None)) | 
|  | test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 4, 0), | 
|  | in_dims=(2, None, None)) | 
|  | test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), | 
|  | (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) | 
|  |  | 
|  | def test_clamp(self): | 
|  | clamp_cases = ( | 
|  | (lambda t: t.clamp(min=-0.5), TensorFactory.randn), | 
|  | (lambda t: t.clamp(max=0.5), TensorFactory.randn), | 
|  | (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn), | 
|  | (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn), | 
|  | (lambda t: t.clamp_max(max=0.5), TensorFactory.randn), | 
|  | ) | 
|  | for op, getter in clamp_cases: | 
|  | self._test_unary(op, getter, 'cpu') | 
|  |  | 
|  | def test_comparison_ops(self): | 
|  | test = functools.partial(self._vmap_test, check_propagates_grad=False) | 
|  |  | 
|  | getter = TensorFactory.randn | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | ops = ( | 
|  | torch.eq, lambda x, y: x == y, | 
|  | torch.gt, lambda x, y: x > y, | 
|  | torch.ge, lambda x, y: x >= y, | 
|  | torch.le, lambda x, y: x <= y, | 
|  | torch.lt, lambda x, y: x < y, | 
|  | torch.ne, lambda x, y: x != y, | 
|  | ) | 
|  |  | 
|  | for op in ops: | 
|  | # Single vmap: op(Tensor, Tensor) | 
|  | test(op, (getter([B0, 3]), getter([B0, 3]))) | 
|  | test(op, (getter([B0]), getter([B0, 2, 3]))) | 
|  | test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1)) | 
|  | test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1) | 
|  | test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None)) | 
|  | test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None)) | 
|  |  | 
|  | # Nested vmap: op(Tensor, Tensor) | 
|  | test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3]))) | 
|  | test(vmap(op, in_dims=(None, 0)), | 
|  | (getter([B0, 2, 3]), getter([B1, 3])), in_dims=(0, None)) | 
|  |  | 
|  | # test number as inputs | 
|  | number = getter([]).item() | 
|  | self._test_unary(lambda t: op(t, number), getter, 'cpu', check_propagates_grad=False) | 
|  |  | 
|  | def test_diagonal(self): | 
|  | tensor = torch.randn(3, 5, 7, 11, 13) | 
|  | test = self._vmap_view_test | 
|  | op = torch.diagonal | 
|  | test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None)) | 
|  | test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None)) | 
|  | test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None)) | 
|  | test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1) | 
|  | test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1) | 
|  | test(vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3), | 
|  | (tensor,), in_dims=1, out_dims=1) | 
|  |  | 
|  | def test_dot(self): | 
|  | op = torch.dot | 
|  | test = self._vmap_test | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | # shape mismatch | 
|  | msg = "Shape mismatch" | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2)) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2)) | 
|  |  | 
|  | # left arg is vmapped | 
|  | test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None)) | 
|  | test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 5), torch.rand(5)), | 
|  | in_dims=(1, None)) | 
|  |  | 
|  | # right arg is vmapped | 
|  | test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0)) | 
|  | test(vmap(op, in_dims=(None, 0)), (torch.rand(5), torch.rand(B1, B0, 5)), | 
|  | in_dims=(None, 1)) | 
|  |  | 
|  | # both args are vmapped | 
|  | test(op, (torch.rand(B0, 5), torch.rand(B0, 5))) | 
|  | test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)) | 
|  | test(vmap(op, in_dims=(0, None)), | 
|  | (torch.rand(B1, 5), torch.rand(B0, 5)), in_dims=(None, 0)) | 
|  |  | 
|  | def test_expand_as(self): | 
|  | op = torch.Tensor.expand_as | 
|  | test = self._vmap_view_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  | test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5))) | 
|  | test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None)) | 
|  | test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0)) | 
|  | test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5))) | 
|  | test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), in_dims=(0, 1)) | 
|  | test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None)) | 
|  | test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5))) | 
|  |  | 
|  | def test_fill_and_zero_inplace(self): | 
|  | test = functools.partial(self._vmap_test, check_propagates_grad=False) | 
|  | B0, B1 = 7, 11 | 
|  | ops = ( | 
|  | lambda t: t.fill_(0.1), | 
|  | lambda t: t.fill_(torch.tensor(0.2)), | 
|  | lambda t: t.zero_(), | 
|  | ) | 
|  |  | 
|  | for op in ops: | 
|  | # Single vmap, various in_dims / out_dims | 
|  | test(op, [TensorFactory.randn([B0, 3])]) | 
|  | test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2) | 
|  | test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) | 
|  |  | 
|  | # Doubly nested vmap | 
|  | test(vmap(op), [TensorFactory.randn([B0, B1])]) | 
|  | test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2) | 
|  | test(vmap(op, in_dims=2), [TensorFactory.randn([2, 5, B0, B1, 3])], | 
|  | in_dims=2, out_dims=2) | 
|  |  | 
|  | # test when value is a batched tensor for fill_ operator | 
|  | B0, B1 = 3, 5 | 
|  | test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)]) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, | 
|  | r"output with shape .+ doesn't match the broadcast shape"): | 
|  | # Runtime Error is thrown when the tensor being written to isn't being vmapped over | 
|  | vmap(Tensor.fill_, (None, 0))(TensorFactory.randn([B0, B1]), | 
|  | TensorFactory.randn([B0])) | 
|  |  | 
|  | def _test_complex_views(self, op, dtypes): | 
|  | test = self._vmap_view_test | 
|  |  | 
|  | def run_test(op, dtype): | 
|  | def get(shape): | 
|  | return torch.randn(shape, dtype=dtype) | 
|  |  | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | # Single vmap, various in_dims / out_dims | 
|  | test(op, [get([B0, 3])]) | 
|  | test(op, [get([3, B0])], in_dims=1) | 
|  | test(op, [get([2, 5, B0, 3])], in_dims=2) | 
|  | test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) | 
|  |  | 
|  | # Doubly nested vmap | 
|  | test(vmap(op), [get([B0, B1])]) | 
|  | test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4) | 
|  | test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], | 
|  | in_dims=2, out_dims=2) | 
|  |  | 
|  | for dtype in dtypes: | 
|  | run_test(op, dtype) | 
|  |  | 
|  | def test_real(self): | 
|  | self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble]) | 
|  |  | 
|  | def test_imag(self): | 
|  | self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble]) | 
|  |  | 
|  | def test_view_as_real(self): | 
|  | self._test_complex_views(torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble]) | 
|  |  | 
|  | def test_view_as_complex(self): | 
|  | def run_test(dtype): | 
|  | def get(shape): | 
|  | return torch.randn(shape, dtype=dtype) | 
|  |  | 
|  | op = torch.view_as_complex | 
|  | test = self._vmap_view_test | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | # Single vmap, various in_dims / out_dims | 
|  | test(op, [get([B0, 3, 2])]) | 
|  | test(op, [get([2, 5, B0, 3, 2])], in_dims=2) | 
|  | test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2) | 
|  |  | 
|  | # Doubly nested vmap | 
|  | test(vmap(op), [get([B0, B1, 2])]) | 
|  | test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2) | 
|  | test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])], | 
|  | in_dims=2, out_dims=2) | 
|  |  | 
|  | # Interesting case #1: Batch dim directly before dim of size 2 | 
|  | test(op, [get([3, B0, 2])], in_dims=1) | 
|  | test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2) | 
|  |  | 
|  | # Interesting case #2: Batch dim at end of tensor, success cases | 
|  | # view_as_complex requires that the dim with size 2 have stride 1 | 
|  | # in order for the view to function propertly | 
|  | test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1) | 
|  | test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)]) | 
|  | test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)]) | 
|  |  | 
|  | # Interesting case #3: Batch dim at end of tensor, failure cases | 
|  | msg = "Tensor must have a last dimension with stride 1" | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=1)(get([2, B0])) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1])) | 
|  |  | 
|  | # Invalid input: no dimension of size 2 | 
|  | msg = 'Input tensor must have one or more dimensions' | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op)(get([B0])) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(vmap(op))(get([B0, B1])) | 
|  |  | 
|  | # Invalid input: Batch dim has size 2, but the logical last dim does | 
|  | # not have size 2 | 
|  | msg = 'Tensor must have a last dimension of size 2' | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=1)(get([3, 2])) | 
|  |  | 
|  | for dtype in [torch.float, torch.double]: | 
|  | run_test(dtype) | 
|  |  | 
|  | def test_is_complex(self): | 
|  | ctensor = torch.randn(3, dtype=torch.cfloat) | 
|  | tensor = torch.randn(3) | 
|  |  | 
|  | def foo(x): | 
|  | if x.is_complex(): | 
|  | return torch.tensor(1) | 
|  | else: | 
|  | return torch.tensor(0) | 
|  |  | 
|  | self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1])) | 
|  | self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0])) | 
|  |  | 
|  | def test_is_floating_point(self): | 
|  | float_tensor = torch.tensor([1., 2., 3.]) | 
|  | long_tensor = torch.tensor([1, 2, 3]) | 
|  |  | 
|  | def foo(x): | 
|  | if x.is_floating_point(): | 
|  | return torch.tensor(1) | 
|  | else: | 
|  | return torch.tensor(0) | 
|  |  | 
|  | self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1])) | 
|  | self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0])) | 
|  |  | 
|  | def test_is_contiguous(self): | 
|  | def foo(x): | 
|  | if x.is_contiguous(): | 
|  | return torch.tensor(1.) | 
|  | else: | 
|  | return torch.tensor(0.) | 
|  |  | 
|  | B0, B1 = 3, 5 | 
|  |  | 
|  | # Single batch dim | 
|  | contig = torch.randn(B0, 2, 7) | 
|  | self.assertEqual(vmap(foo)(contig), torch.ones(B0)) | 
|  |  | 
|  | noncontig = torch.randn(2, B0, 7) | 
|  | self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0)) | 
|  |  | 
|  | noncontig = torch.randn(2, B0, 7).movedim(1, 0) | 
|  | self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0)) | 
|  |  | 
|  | noncontig = torch.randn(2, 7, B0) | 
|  | self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0)) | 
|  |  | 
|  | # Multiple batch dims | 
|  | contig = torch.randn(B0, B1, 3) | 
|  | self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) | 
|  |  | 
|  | contig = torch.randn(B1, B0, 3) | 
|  | self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1)) | 
|  |  | 
|  | contig = torch.randn(B1, B0, 3).movedim(0, 1) | 
|  | self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) | 
|  |  | 
|  | noncontig = torch.randn(B0, 3, B1) | 
|  | self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1)) | 
|  |  | 
|  | # is_contiguous on empty tensor is True | 
|  | def bar(x): | 
|  | assert x.is_contiguous() | 
|  | return x | 
|  |  | 
|  | vmap(bar)(torch.randn(B0, 0, 3)) | 
|  | vmap(bar, in_dims=1)(torch.randn(0, B0, 3)) | 
|  | vmap(bar)(torch.randn(B0, 0, 3).mT) | 
|  |  | 
|  | # is_contiguous with other memory formats | 
|  | def baz(x, memory_format): | 
|  | x.is_contiguous(memory_format=memory_format) | 
|  | return x | 
|  |  | 
|  | msg = 'NYI: querying is_contiguous inside of vmap for memory_format' | 
|  | tensor = torch.randn(B0, 2, 7, 3) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor) | 
|  |  | 
|  | def test_movedim(self): | 
|  | op = torch.movedim | 
|  | test = self._vmap_view_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  |  | 
|  | # movedim(tensor, int, int) variant | 
|  | test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None)) | 
|  | test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None)) | 
|  | test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 2, B0, 5), 0, 1), in_dims=(2, None, None)) | 
|  | test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)), | 
|  | (torch.rand(B1, 2, B0, 5, B2), 0, 1), in_dims=(2, None, None)) | 
|  |  | 
|  | # movedim(tensor, intlist, intlist) variant | 
|  | test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None)) | 
|  | test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None)) | 
|  | test(vmap(op, in_dims=(0, None, None)), | 
|  | (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]), in_dims=(2, None, None)) | 
|  | test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)), | 
|  | (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), in_dims=(2, None, None)) | 
|  |  | 
|  | def test_mm(self): | 
|  | op = torch.mm | 
|  | test = self._vmap_test | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | # shape mismatch | 
|  | msg = "Shape mismatch" | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2)) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2)) | 
|  |  | 
|  | # left arg is vmapped | 
|  | test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None)) | 
|  | test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)), | 
|  | in_dims=(1, None)) | 
|  |  | 
|  | # right arg is vmapped | 
|  | test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0)) | 
|  | test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)), | 
|  | in_dims=(None, 1)) | 
|  |  | 
|  | # both args are vmapped | 
|  | test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2))) | 
|  | test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)), in_dims=(1, 0)) | 
|  | test(vmap(op, in_dims=(0, None)), | 
|  | (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0)) | 
|  |  | 
|  | def test_mv(self): | 
|  | op = torch.mv | 
|  | test = self._vmap_test | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | # shape mismatch | 
|  | msg = "Shape mismatch" | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2)) | 
|  | with self.assertRaisesRegex(RuntimeError, msg): | 
|  | vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2)) | 
|  |  | 
|  | # left arg is vmapped | 
|  | test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None)) | 
|  | test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5)), | 
|  | in_dims=(1, None)) | 
|  |  | 
|  | # right arg is vmapped | 
|  | test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0)) | 
|  | test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5)), | 
|  | in_dims=(None, 1)) | 
|  |  | 
|  | # both args are vmapped | 
|  | test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5))) | 
|  | test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)) | 
|  | test(vmap(op, in_dims=(0, None)), | 
|  | (torch.rand(B1, 2, 5), torch.rand(B0, 5)), in_dims=(None, 0)) | 
|  |  | 
|  | def test_narrow(self): | 
|  | op = torch.narrow | 
|  | test = self._vmap_view_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  |  | 
|  | test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None)) | 
|  | test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None)) | 
|  | test(vmap(op, in_dims=(0, None, None, None)), | 
|  | (torch.rand(B1, 2, B0, 5), 1, 0, 0), in_dims=(2, None, None, None)) | 
|  | test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)), | 
|  | (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), in_dims=(2, None, None, None)) | 
|  |  | 
|  | def test_new_empty(self): | 
|  | # Empty is non-deterministic so we just check that the shape of the | 
|  | # output tensor is what we expect and that the vmap fallback isn't used. | 
|  | op = Tensor.new_empty | 
|  |  | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0)) | 
|  | self.assertEqual(result.shape, [B0, 2, 3]) | 
|  |  | 
|  | result = vmap(lambda x: op(x, []))(torch.randn(B0)) | 
|  | self.assertEqual(result.shape, [B0]) | 
|  |  | 
|  | result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1)) | 
|  | self.assertEqual(result.shape, [B0, B1, 2, 3]) | 
|  |  | 
|  | def test_new_empty_strided(self): | 
|  | # Empty is non-deterministic so we just check that the size and shape | 
|  | # of the output are what we expect and that the vmap fallback isn't used | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | def _test_single_vmap(size, stride, B0): | 
|  | x = torch.randn(B0) | 
|  | result = vmap(lambda x: x.new_empty_strided(size, stride))(x) | 
|  | S = torch.empty_strided(size, stride).storage().size() | 
|  | self.assertEqual(result.shape, [B0] + size) | 
|  | self.assertEqual(result.stride(), [S] + stride) | 
|  |  | 
|  | def _test_double_vmap(size, stride, B0, B1): | 
|  | x = torch.randn(B0, B1) | 
|  | result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x) | 
|  | S = torch.empty_strided(size, stride).storage().size() | 
|  | self.assertEqual(result.shape, [B0, B1] + size) | 
|  | self.assertEqual(result.stride(), [B1 * S, S] + stride) | 
|  |  | 
|  | x = torch.randn(B1, B0) | 
|  | result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)(x) | 
|  | S = x.new_empty_strided(size, stride).storage().size() | 
|  | self.assertEqual(result.shape, [B0, B1] + size) | 
|  | self.assertEqual(result.stride(), [B1 * S, S] + stride) | 
|  |  | 
|  | # contiguous case | 
|  | _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0) | 
|  | _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1) | 
|  |  | 
|  | # expanded | 
|  | _test_single_vmap([2, 3, 5], [0, 5, 1], B0) | 
|  | _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1) | 
|  |  | 
|  | # some of these cases are pretty strange, just verifying that if | 
|  | # empty_strided allows them then BatchedTensor.new_empty_strided | 
|  | # can as well | 
|  | for shape in [[2, 3, 4], [0, 2, 0]]: | 
|  | for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]: | 
|  | _test_single_vmap(shape, strides, B0) | 
|  | _test_double_vmap(shape, strides, B0, B1) | 
|  |  | 
|  | def test_new_zeros(self): | 
|  | op = Tensor.new_zeros | 
|  | test = functools.partial(self._vmap_test, check_propagates_grad=False) | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | test(lambda x: op(x, 2, 3), (torch.rand(B0),)) | 
|  | test(lambda x: op(x, []), (torch.rand(B0),)) | 
|  | test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),)) | 
|  |  | 
|  | def test_select(self): | 
|  | op = torch.select | 
|  | test = self._vmap_view_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  | test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None)) | 
|  | test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None)) | 
|  | test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) | 
|  | test(vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)), (torch.rand(B1, 2, B0, B2, 5),), in_dims=2) | 
|  |  | 
|  | def test_stack(self): | 
|  | test = self._vmap_test | 
|  | B0, B1 = 5, 7 | 
|  |  | 
|  | # Quick hack b/c vmap can't accept a list of tensors as an argument | 
|  | def get_op(dim): | 
|  | def op(*tensors): | 
|  | return torch.stack(tensors, dim=dim) | 
|  | return op | 
|  |  | 
|  | test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3))) | 
|  | test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0)) | 
|  | test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2)) | 
|  | test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2)) | 
|  | test(vmap(get_op(0), in_dims=(0, None)), | 
|  | (torch.rand(B1, 2), torch.rand(B0, 2)), in_dims=(None, 0)) | 
|  | test(vmap(get_op(0), in_dims=(0, 0)), | 
|  | (torch.rand(B1, 2), torch.rand(B0, B1, 2)), in_dims=(None, 0)) | 
|  |  | 
|  |  | 
|  | def test_slice(self): | 
|  | test = self._vmap_view_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  | test(lambda t: t[0:1], (torch.rand(B0, 3, 5),)) | 
|  | test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2) | 
|  | test(vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2) | 
|  | test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2), | 
|  | (torch.rand(3, 5, B0, B1, B2),), in_dims=2) | 
|  |  | 
|  | def test_squeeze(self): | 
|  | test = self._vmap_view_test | 
|  | op = torch.squeeze | 
|  | B0, B1 = 1, 11 | 
|  | test(op, (torch.rand(B0),)) | 
|  | test(op, (torch.rand(B0, 3, 5),)) | 
|  | test(op, (torch.rand(1, B0, 5),), in_dims=1) | 
|  | test(op, (torch.rand(B0, 0, 1, 5, 1),)) | 
|  | test(op, (torch.rand(B0, 1, 1, 1, 1),)) | 
|  | test(vmap(op), (torch.rand(B0, B1, 1),)) | 
|  | test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2) | 
|  |  | 
|  | def test_sum_dim(self): | 
|  | test = self._vmap_test | 
|  | B0, B1 = 5, 7 | 
|  |  | 
|  | # Single vmap, various in_dims / out_dims | 
|  | test(lambda x: x.sum(0), [torch.randn([B0])]) | 
|  | test(lambda x: x.sum(-1), [torch.randn([B0])]) | 
|  | test(lambda x: x.sum(0), [torch.randn([B0, 3])]) | 
|  | test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2) | 
|  | test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) | 
|  |  | 
|  | # Doubly nested vmap | 
|  | test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])]) | 
|  | test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])]) | 
|  | test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2) | 
|  | test(vmap(lambda x: x.sum(2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])], | 
|  | in_dims=2, out_dims=2) | 
|  |  | 
|  | def test_reshape(self): | 
|  | test = self._vmap_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  | op = torch.reshape | 
|  | test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True) | 
|  | test(op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False) | 
|  | test(vmap(lambda t: t.reshape([-1])), (torch.rand(B0, B1, 2, 5),), check_view=True) | 
|  | test(vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1), | 
|  | (torch.rand(3, B1, 2, B2, 5, B0),), in_dims=5, check_view=False) | 
|  |  | 
|  | def test_reshape_as(self): | 
|  | test = self._vmap_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  | op = torch.Tensor.reshape_as | 
|  | test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True) | 
|  | test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0), check_view=True) | 
|  | test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None), check_view=True) | 
|  |  | 
|  | test(op, (torch.rand(2, B0, 5), torch.rand(1, 1, 10)), in_dims=(1, None), check_view=False) | 
|  |  | 
|  | test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)), check_view=True) | 
|  | test(vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)), | 
|  | (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)), | 
|  | in_dims=(5, 0), check_view=False) | 
|  |  | 
|  | def test_result_type(self): | 
|  | def scalar_tensor_with_dtype(op): | 
|  | def wrapped(*args, **kwargs): | 
|  | dtype = op(*args, **kwargs) | 
|  | return torch.ones([], dtype=dtype) | 
|  | return wrapped | 
|  |  | 
|  | test = self._vmap_test | 
|  | op = scalar_tensor_with_dtype(torch.result_type) | 
|  |  | 
|  | B0 = 2 | 
|  |  | 
|  | test(op, (torch.randn(B0), torch.randn(B0, dtype=torch.float64)), | 
|  | check_propagates_grad=False) | 
|  | test(op, (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)), | 
|  | check_propagates_grad=False) | 
|  |  | 
|  | test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False) | 
|  | test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False) | 
|  |  | 
|  | test(lambda x: op(x, torch.tensor(1)), (torch.randn(B0),), | 
|  | check_propagates_grad=False) | 
|  | test(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)), | 
|  | (torch.randn(B0),), check_propagates_grad=False) | 
|  |  | 
|  | test(op, (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)), | 
|  | check_propagates_grad=False) | 
|  | test(op, (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)), | 
|  | check_propagates_grad=False) | 
|  |  | 
|  | test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False) | 
|  | test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False) | 
|  |  | 
|  | test(lambda x: op(x, torch.tensor(1)), (torch.randn(B0, 2),), | 
|  | check_propagates_grad=False) | 
|  | test(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)), | 
|  | (torch.randn(B0, 2),), check_propagates_grad=False) | 
|  |  | 
|  | test(op, (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)), | 
|  | check_propagates_grad=False) | 
|  | test(op, (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)), | 
|  | check_propagates_grad=False) | 
|  |  | 
|  | def test_tensor_split(self): | 
|  | test = self._vmap_view_test | 
|  | op = torch.tensor_split | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  |  | 
|  | # tests for torch.tensor_split(self, indices_or_sections: int, dim) | 
|  | test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None)) | 
|  | test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None)) | 
|  | test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0), | 
|  | in_dims=(2, None, None)) | 
|  | test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), | 
|  | (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) | 
|  |  | 
|  | # tests for torch.tensor_split(self, indices_or_sections: List[int], dim) | 
|  | test(op, (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1), in_dims=(0, None, None)) | 
|  | test(op, (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1), in_dims=(1, None, None)) | 
|  | test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0), | 
|  | in_dims=(2, None, None)) | 
|  | test(vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)), | 
|  | (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) | 
|  |  | 
|  | def test_split(self): | 
|  | test = self._vmap_view_test | 
|  | op = torch.split | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  |  | 
|  | # tests for torch.split(self, split_size: int, dim) | 
|  | test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None)) | 
|  | test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None)) | 
|  | test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0), | 
|  | in_dims=(2, None, None)) | 
|  | test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), | 
|  | (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) | 
|  |  | 
|  | # tests for torch.split(self, split_size: List[int], dim) | 
|  | test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None)) | 
|  | test(op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None)) | 
|  | test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0), | 
|  | in_dims=(2, None, None)) | 
|  | test(vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)), | 
|  | (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) | 
|  |  | 
|  | def test_trace(self): | 
|  | op = torch.trace | 
|  | test = self._vmap_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  |  | 
|  | test(op, (torch.rand(B0, 2, 5),)) | 
|  | test(op, (torch.rand(2, B0, 5),), in_dims=1) | 
|  | test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) | 
|  | test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) | 
|  |  | 
|  | def test_transpose(self): | 
|  | op = torch.transpose | 
|  | test = self._vmap_view_test | 
|  |  | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  | test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),)) | 
|  | test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),)) | 
|  | test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),)) | 
|  | test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1) | 
|  | test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) | 
|  | test(vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)), | 
|  | (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) | 
|  |  | 
|  | # Special case: scalar tensor | 
|  | for dim1, dim2 in itertools.product([0, -1], [0, -1]): | 
|  | x = torch.rand(B0) | 
|  | result = vmap(lambda x: op(x, dim1, dim2))(x) | 
|  | self.assertTrue(result is x) | 
|  |  | 
|  | def test_t(self): | 
|  | op = torch.t | 
|  | test = self._vmap_view_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  | test(op, (torch.rand(B0, 2, 5),)) | 
|  | test(op, (torch.rand(2, B0, 5),), in_dims=1) | 
|  | test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) | 
|  | test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) | 
|  |  | 
|  | def test_T_numpy(self): | 
|  | def op(t): | 
|  | return t.T | 
|  |  | 
|  | test = self._vmap_view_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  | test(op, (torch.rand(B0, 2, 3, 5),)) | 
|  | test(op, (torch.rand(B0),)) | 
|  | test(op, (torch.rand(2, B0, 3, 5),), in_dims=1) | 
|  | test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) | 
|  | test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2) | 
|  | test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2) | 
|  |  | 
|  | def test_to(self): | 
|  | test = self._vmap_test | 
|  | B0, B1 = 7, 11 | 
|  |  | 
|  | test(lambda t: t.to('cpu'), (torch.rand(B0),)) | 
|  | test(lambda t: t.to(torch.double), (torch.rand(B0),)) | 
|  | test(lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64))) | 
|  | test(lambda t, o: t.to(o), | 
|  | (torch.rand(B0), torch.randn(B0, dtype=torch.float64)), | 
|  | in_dims=(0, None)) | 
|  | test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),)) | 
|  |  | 
|  | # also test some casting methods | 
|  | test(lambda t: t.double(), (torch.rand(B0),)) | 
|  | test(lambda t: t.float(), (torch.rand(B0),)) | 
|  | test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False) | 
|  | test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False) | 
|  |  | 
|  | def test_unfold(self): | 
|  | op = torch.Tensor.unfold | 
|  | test = self._vmap_view_test | 
|  | B0, B1, B2 = 3, 2, 5 | 
|  |  | 
|  | test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None)) | 
|  | test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None)) | 
|  | test(vmap(op, in_dims=(0, None, None, None)), | 
|  | (torch.rand(B1, 7, B0, 11), 1, 5, 1), in_dims=(2, None, None, None)) | 
|  | test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)), | 
|  | (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4), in_dims=(2, None, None, None)) | 
|  |  | 
|  | def test_unbind(self): | 
|  | test = self._vmap_view_test | 
|  | op = torch.unbind | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  |  | 
|  | test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None)) | 
|  | test(op, (torch.rand(B0, 2, 0),)) | 
|  | test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None)) | 
|  | test(vmap(op, in_dims=(0, None)), (torch.rand(B1, 1023, B0, 5), 1), | 
|  | in_dims=(2, None)) | 
|  | test(vmap(vmap(lambda t: op(t, dim=1), in_dims=2)), | 
|  | (torch.rand(B1, 2, B0, 32, B2),), in_dims=2) | 
|  |  | 
|  | def test_view(self): | 
|  | test = self._vmap_view_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  | op = torch.Tensor.view | 
|  |  | 
|  | # We should error out if the view would produce an incorrect result | 
|  | with self.assertRaises(RuntimeError): | 
|  | vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10]) | 
|  |  | 
|  | test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None)) | 
|  | test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None)) | 
|  | test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),)) | 
|  | test(vmap(vmap(lambda t: t.reshape([-1])), in_dims=1), | 
|  | (torch.rand(B2, B0, B1, 3, 2, 5),), in_dims=1) | 
|  |  | 
|  | def test_view_as(self): | 
|  | test = self._vmap_view_test | 
|  | B0, B1, B2 = 7, 11, 13 | 
|  | op = torch.Tensor.view_as | 
|  |  | 
|  | # We should error out if the view would produce an incorrect result | 
|  | with self.assertRaises(RuntimeError): | 
|  | vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10)) | 
|  |  | 
|  | test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5))) | 
|  | test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0)) | 
|  | test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None)) | 
|  |  | 
|  | test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None)) | 
|  |  | 
|  | test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10))) | 
|  | test(vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)), | 
|  | (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)), | 
|  | in_dims=(2, 0)) | 
|  |  | 
|  | def test_no_random_op_support(self): | 
|  | B0 = 2 | 
|  |  | 
|  | captured = torch.rand(3) | 
|  |  | 
|  | random_ops = [ | 
|  | # out-of-place on BatchedTensor | 
|  | (torch.bernoulli, (torch.rand(B0, 1),)), | 
|  | (lambda t: torch.bernoulli(t, p=0.5), (torch.rand(B0, 1),)), | 
|  | (lambda t: torch.multinomial(t, 2), (torch.rand(B0, 3),)), | 
|  | (torch.normal, (torch.randn(B0, 1), torch.randn(B0, 1))), | 
|  | (lambda t: torch.normal(t, 1.), (torch.randn(B0, 1),)), | 
|  | (lambda t: torch.normal(0., t), (torch.randn(B0, 1),)), | 
|  | (torch.poisson, (torch.rand(B0, 1),)), | 
|  | (torch.rand_like, (torch.rand(B0, 1),)), | 
|  | (torch.randn_like, (torch.rand(B0, 1),)), | 
|  | (lambda t: torch.randint_like(t, 2), (torch.rand(B0, 1),)), | 
|  | (lambda t: torch.randint_like(t, 0, 2), (torch.rand(B0, 1),)), | 
|  |  | 
|  | # out-of-place on captured tensor | 
|  | (lambda t: torch.bernoulli(captured), (torch.rand(B0),)), | 
|  | (lambda t: torch.bernoulli(captured, p=0.5), (torch.rand(B0),)), | 
|  | (lambda t: torch.multinomial(captured, 2), (torch.rand(B0),)), | 
|  | (lambda t: torch.normal(captured, captured), (torch.randn(B0),)), | 
|  | (lambda t: torch.normal(captured, 1.), (torch.randn(B0),)), | 
|  | (lambda t: torch.normal(0., captured), (torch.randn(B0),)), | 
|  | (lambda t: torch.poisson(captured), (torch.rand(B0),)), | 
|  | (lambda t: torch.rand_like(captured), (torch.rand(B0),)), | 
|  | (lambda t: torch.randn_like(captured) , (torch.rand(B0),)), | 
|  | (lambda t: torch.randint_like(captured, 2), (torch.rand(B0),)), | 
|  | (lambda t: torch.randint_like(captured, 0, 2), (torch.rand(B0),)), | 
|  |  | 
|  | # in-place on BatchedTensor | 
|  | (lambda t: t.bernoulli_(), (torch.randn(B0, 1),)), | 
|  | (lambda t: t.cauchy_(), (torch.randn(B0, 1),)), | 
|  | (lambda t: t.exponential_(), (torch.randn(B0, 1),)), | 
|  | (lambda t: t.geometric_(0.5), (torch.randn(B0, 1),)), | 
|  | (lambda t: t.log_normal_(), (torch.randn(B0, 1),)), | 
|  | (lambda t: t.normal_(), (torch.randn(B0, 1),)), | 
|  | (lambda t: t.random_(), (torch.randn(B0, 1),)), | 
|  | (lambda t: t.random_(0, 2), (torch.randn(B0, 1),)), | 
|  | (lambda t: t.random_(2), (torch.randn(B0, 1),)), | 
|  | (lambda t: t.uniform_(), (torch.randn(B0, 1),)), | 
|  |  | 
|  | # in-place on captured tensor | 
|  | (lambda t: captured.bernoulli_(), (torch.randn(B0),)), | 
|  | (lambda t: captured.cauchy_(), (torch.randn(B0),)), | 
|  | (lambda t: captured.exponential_(), (torch.randn(B0),)), | 
|  | (lambda t: captured.geometric_(0.5), (torch.randn(B0),)), | 
|  | (lambda t: captured.log_normal_(), (torch.randn(B0),)), | 
|  | (lambda t: captured.normal_(), (torch.randn(B0),)), | 
|  | (lambda t: captured.random_(), (torch.randn(B0),)), | 
|  | (lambda t: captured.random_(0, 2), (torch.randn(B0),)), | 
|  | (lambda t: captured.random_(2), (torch.randn(B0),)), | 
|  | (lambda t: captured.uniform_(), (torch.randn(B0),)), | 
|  |  | 
|  | # factory functions | 
|  | (lambda t: torch.rand(1), (torch.randn(B0),)), | 
|  | (lambda t: torch.randn(1), (torch.randn(B0),)), | 
|  | (lambda t: torch.randint(5, [1]), (torch.randn(B0),)), | 
|  | (lambda t: torch.randperm(5), (torch.randn(B0),)), | 
|  | ] | 
|  | for op, args in random_ops: | 
|  | with self.assertRaisesRegex(RuntimeError, | 
|  | 'vmap: We do not yet support calling random operations'): | 
|  | vmap(op)(*args) | 
|  |  | 
|  | def construct_v(output, batch_size): | 
|  | return torch.randn(batch_size, *output.shape, | 
|  | dtype=output.dtype, device=output.device) | 
|  |  | 
|  | def as_tuple(x): | 
|  | if isinstance(x, tuple): | 
|  | return x | 
|  | elif isinstance(x, list): | 
|  | return tuple(x) | 
|  | else: | 
|  | return x, | 
|  |  | 
|  | def differentiable(args): | 
|  | return tuple(arg for arg in as_tuple(args) | 
|  | if isinstance(arg, torch.Tensor) and arg.requires_grad) | 
|  |  | 
|  | def _get_rand_no_zeros(*args, **kwargs): | 
|  | requires_grad = kwargs.get('requires_grad', False) | 
|  | kwargs_without_requires_grad = kwargs.copy() | 
|  | kwargs_without_requires_grad['requires_grad'] = False | 
|  | result = torch.rand(*args, **kwargs_without_requires_grad) | 
|  | return result.clamp_min_(0.1).requires_grad_(requires_grad) | 
|  |  | 
|  | class TestVmapBatchedGradient(Namespace.TestVmapBase): | 
|  | def _vmap_test(self, *args, **kwargs): | 
|  | return _vmap_test(self, *args, **kwargs) | 
|  |  | 
|  | # Tests batched gradient computation of outputs = op(*args, **kwargs) | 
|  | # by comparing it to a sequential map+stack fallback. | 
|  | # | 
|  | # output_process_fn: a function that maps the outputs to the part | 
|  | #       that should be differentiated. | 
|  | # batch_size: the batch dim size for the batched grad | 
|  | def _batched_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3): | 
|  | if kwargs is None: | 
|  | kwargs = {} | 
|  | outputs = op(*args, **kwargs) | 
|  | outputs = differentiable(output_process_fn(outputs)) | 
|  | batched_vectors = tuple(construct_v(out, batch_size) for out in outputs) | 
|  |  | 
|  | def vector_jacobian_product(*vectors): | 
|  | return torch.autograd.grad(outputs, differentiable(args), vectors, | 
|  | retain_graph=True) | 
|  | self._vmap_test(vector_jacobian_product, batched_vectors, | 
|  | check_propagates_grad=False) | 
|  |  | 
|  | # Tests batched second grad computation of outputs = op(*args, **kwargs). | 
|  | # by comparing it to a sequential map+stack fallback. | 
|  | # | 
|  | # output_process_fn: a function that maps the outputs to the part | 
|  | #       that should be differentiated. | 
|  | # batch_size: the batch dim size for the batched grad | 
|  | # | 
|  | # NB: we only test computing batched gradients in the second gradient | 
|  | # computation. One specific use case that does this is computing the hessian | 
|  | # matrix of a scalar-valued function; this is useful in Bayesian Logistic | 
|  | # Regression. | 
|  | # It might be useful to have a test that computes batched first gradients and | 
|  | # then uses those to compute batched second gradients in the future. | 
|  | def _batched_grad_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3): | 
|  | if kwargs is None: | 
|  | kwargs = {} | 
|  | outputs = op(*args, **kwargs) | 
|  | outputs = differentiable(output_process_fn(outputs)) | 
|  | ones = tuple(torch.ones_like(out) for out in outputs) | 
|  | # Same thing as summing together all of the outputs and calling .backward() | 
|  | first_grads = torch.autograd.grad(outputs, differentiable(args), ones, | 
|  | create_graph=True) | 
|  | first_grads = differentiable(first_grads) | 
|  | self.assertNotEqual( | 
|  | len(first_grads), 0, "None of the first grads depend on the input!") | 
|  |  | 
|  | batched_vectors = tuple(construct_v(grad, batch_size) for grad in first_grads) | 
|  |  | 
|  | def vector_hessian_product(*vectors): | 
|  | outputs = torch.autograd.grad(first_grads, differentiable(args), vectors, | 
|  | retain_graph=True, allow_unused=True) | 
|  | outputs = tuple(out for out in outputs if out is not None) | 
|  | assert len(outputs) > 0 | 
|  | return outputs | 
|  |  | 
|  | self._vmap_test(vector_hessian_product, batched_vectors, | 
|  | check_propagates_grad=False) | 
|  |  | 
|  | def _test_arithmetic(self, op, device, test_grad_grad=True): | 
|  | x = torch.randn(2, 3, requires_grad=True, device=device) | 
|  | y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) | 
|  | scalar = 3.14 | 
|  | self._batched_grad_test(op, (x, y)) | 
|  | self._batched_grad_test(op, (scalar, y)) | 
|  | self._batched_grad_test(op, (x, scalar)) | 
|  |  | 
|  | if test_grad_grad: | 
|  | self._batched_grad_grad_test(op, (x, y)) | 
|  |  | 
|  | def test_add(self, device): | 
|  | self._test_arithmetic(torch.add, device, test_grad_grad=False) | 
|  | self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False) | 
|  |  | 
|  | def test_sub(self, device): | 
|  | self._test_arithmetic(torch.sub, device, test_grad_grad=False) | 
|  | self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False) | 
|  |  | 
|  | def test_mul(self, device): | 
|  | self._test_arithmetic(torch.mul, device) | 
|  | self._test_arithmetic(lambda x, y: x * y, device) | 
|  |  | 
|  | def test_div(self, device): | 
|  | self._test_arithmetic(torch.div, device) | 
|  | self._test_arithmetic(lambda x, y: x / y, device) | 
|  |  | 
|  | @allowVmapFallbackUsage | 
|  | def test_binary_cross_entropy(self, device): | 
|  | x = F.sigmoid(torch.randn(3, 2, device=device, requires_grad=True)) | 
|  | target = torch.rand(3, 2, device=device) | 
|  |  | 
|  | op = functools.partial(F.binary_cross_entropy, target=target) | 
|  |  | 
|  | self._batched_grad_test(op, (x,), {}) | 
|  | self._batched_grad_grad_test(op, (x,), {}) | 
|  |  | 
|  | def test_expand(self, device): | 
|  | x = torch.randn(2, 3, device=device, requires_grad=True) | 
|  |  | 
|  | def op(x): | 
|  | return x.expand(5, 5, 2, 3) | 
|  | self._batched_grad_test(op, (x,)) | 
|  |  | 
|  | @allowVmapFallbackUsage | 
|  | def test_index(self, device): | 
|  | x = torch.randn(2, 3, requires_grad=True, device=device) | 
|  | index = torch.tensor([[0, 0], [1, 1]], device=device) | 
|  |  | 
|  | def op(x): | 
|  | y = x * x | 
|  | return y[index] | 
|  |  | 
|  | self._batched_grad_test(op, (x,)) | 
|  | self._batched_grad_grad_test(op, (x,)) | 
|  |  | 
|  | def test_lgamma(self, device): | 
|  | x = torch.randn(2, 3, requires_grad=True, device=device) | 
|  | self._batched_grad_test(Tensor.lgamma, (x,)) | 
|  | self._batched_grad_grad_test(Tensor.lgamma, (x,)) | 
|  |  | 
|  | def test_log(self, device): | 
|  | x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) | 
|  | self._batched_grad_test(torch.log, (x,)) | 
|  | self._batched_grad_grad_test(torch.log, (x,)) | 
|  |  | 
|  | def test_logsumexp(self, device): | 
|  | x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) | 
|  |  | 
|  | def op(x): | 
|  | return torch.logsumexp(x, -1) | 
|  |  | 
|  | self._batched_grad_test(op, (x,)) | 
|  | self._batched_grad_grad_test(op, (x,)) | 
|  |  | 
|  | def test_log1p(self, device): | 
|  | x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) | 
|  | self._batched_grad_test(torch.log1p, (x,)) | 
|  | self._batched_grad_grad_test(torch.log1p, (x,)) | 
|  |  | 
|  | @allowVmapFallbackUsage | 
|  | def test_max(self, device): | 
|  | x = torch.randn(2, 3, requires_grad=True, device=device) | 
|  | self._batched_grad_test(torch.max, (x,)) | 
|  |  | 
|  | @allowVmapFallbackUsage | 
|  | def test_median(self, device): | 
|  | x = torch.randn(2, 3, requires_grad=True, device=device) | 
|  | self._batched_grad_test(torch.median, (x,)) | 
|  |  | 
|  | @allowVmapFallbackUsage | 
|  | def test_min(self, device): | 
|  | x = torch.randn(2, 3, requires_grad=True, device=device) | 
|  | self._batched_grad_test(torch.min, (x,)) | 
|  |  | 
|  | def test_permute(self, device): | 
|  | x = torch.randn(2, 3, 5, requires_grad=True, device=device) | 
|  |  | 
|  | def op(x): | 
|  | return x.permute(2, 0, 1) | 
|  |  | 
|  | self._batched_grad_test(op, (x,)) | 
|  |  | 
|  | def test_reshape(self, device): | 
|  | x = torch.randn(2, 3, 5, requires_grad=True, device=device) | 
|  |  | 
|  | def op(x): | 
|  | return x.reshape([2 * 3, 5]) | 
|  |  | 
|  | self._batched_grad_test(op, (x,)) | 
|  |  | 
|  | def test_sigmoid(self, device): | 
|  | x = torch.randn(2, 3, requires_grad=True, device=device) | 
|  | self._batched_grad_test(Tensor.sigmoid, (x,)) | 
|  | self._batched_grad_grad_test(Tensor.sigmoid, (x,)) | 
|  |  | 
|  | def test_stack(self, device): | 
|  | x = torch.randn(2, 3, device=device, requires_grad=True) | 
|  | y = torch.randn(2, 3, device=device, requires_grad=True) | 
|  |  | 
|  | def op(x, y): | 
|  | return torch.stack([x, y]) | 
|  | self._batched_grad_test(op, (x, y)) | 
|  |  | 
|  | def test_select(self, device): | 
|  | x = torch.randn(2, 3, device=device, requires_grad=True) | 
|  | self._batched_grad_test(lambda x: x[1], (x,)) | 
|  | self._batched_grad_test(lambda x: x.select(1, 2), (x,)) | 
|  | self._batched_grad_test(lambda x: x.select(-1, 0), (x,)) | 
|  |  | 
|  | def test_slice(self, device): | 
|  | x = torch.randn(2, 3, 5, device=device, requires_grad=True) | 
|  | self._batched_grad_test(lambda x: x[0:1], (x,)) | 
|  | self._batched_grad_test(lambda x: x[:, 1:3], (x,)) | 
|  | self._batched_grad_test(lambda x: x[..., 1:3], (x,)) | 
|  |  | 
|  | def test_trace(self, device): | 
|  | x = torch.randn(2, 3, device=device, requires_grad=True) | 
|  | self._batched_grad_test(Tensor.trace, (x,)) | 
|  |  | 
|  | @skipCUDAIfNoMagma | 
|  | @allowVmapFallbackUsage | 
|  | def test_symeig(self, device): | 
|  | def op(x): | 
|  | return torch.symeig(x, eigenvectors=True)[0] | 
|  |  | 
|  | x = torch.randn(3, 3, device=device, requires_grad=True) | 
|  | self._batched_grad_test(op, (x,), {}) | 
|  | self._batched_grad_grad_test(op, (x,), {}) | 
|  |  | 
|  | def test_threshold(self, device): | 
|  | x = torch.randn(2, 3, device=device, requires_grad=True) | 
|  | self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,)) | 
|  |  | 
|  |  | 
|  | @allowVmapFallbackUsage | 
|  | def test_inplace_on_view(self, device): | 
|  | leaf = torch.randn(4, 5, requires_grad=True) | 
|  |  | 
|  | def func(leaf): | 
|  | # Make sure the function is non-trivially twice differentiable | 
|  | base = leaf * leaf | 
|  | view = base[0] | 
|  | view.cos_() | 
|  | return view | 
|  |  | 
|  | self._batched_grad_test(func, (leaf,), {}) | 
|  | self._batched_grad_grad_test(func, (leaf,), {}) | 
|  |  | 
|  | @allowVmapFallbackUsage | 
|  | def test_inplace_manyview(self, device): | 
|  | leaf = torch.randn(4, 4, 5, requires_grad=True) | 
|  |  | 
|  | def func(leaf): | 
|  | # Make sure the function is non-trivially twice differentiable | 
|  | base = leaf * leaf | 
|  | view = base.transpose(0, 2) | 
|  | view = view[1] | 
|  | view = view.diagonal() | 
|  | view = view[::2] | 
|  | view.cos_() | 
|  | return view | 
|  |  | 
|  | self._batched_grad_test(func, (leaf,), {}) | 
|  | self._batched_grad_grad_test(func, (leaf,), {}) | 
|  |  | 
|  | def test_diagonal(self, device): | 
|  | x = torch.randn(4, 5, device=device, requires_grad=True) | 
|  | self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,)) | 
|  |  | 
|  | x = torch.randn(3, 4, 5, device=device, requires_grad=True) | 
|  | self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,)) | 
|  |  | 
|  | @allowVmapFallbackUsage | 
|  | def test_unrelated_output(self, device): | 
|  | B0 = 3 | 
|  | x = torch.randn([], requires_grad=True) | 
|  | y = torch.randn([], requires_grad=True) | 
|  | gy = torch.randn(B0, requires_grad=True) | 
|  |  | 
|  | def vjp(v): | 
|  | res, = torch.autograd.grad(y, x, v, allow_unused=True) | 
|  | return torch.zeros_like(x) if res is None else res | 
|  |  | 
|  | result = vmap(vjp)(gy) | 
|  | self.assertEqual(result, torch.zeros(B0, *x.shape, device=device)) | 
|  |  | 
|  | @allowVmapFallbackUsage | 
|  | def test_unrelated_output_multiple_grad(self, device): | 
|  | B0 = 3 | 
|  | x = torch.randn([], requires_grad=True) | 
|  | y = torch.randn([], requires_grad=True) | 
|  | gy = torch.randn(B0, requires_grad=True) | 
|  |  | 
|  | def vjp(v): | 
|  | res, = torch.autograd.grad(y, x, v, allow_unused=True) | 
|  | return torch.zeros_like(x) if res is None else res | 
|  |  | 
|  | _ = vjp(gy[0]) | 
|  | result = vmap(vjp)(gy) | 
|  | self.assertEqual(result, torch.zeros(B0, *x.shape, device=device)) | 
|  |  | 
|  | instantiate_device_type_tests( | 
|  | TestVmapBatchedGradient, | 
|  | globals(), | 
|  | None, | 
|  | ) | 
|  |  | 
|  | if __name__ == '__main__': | 
|  | run_tests() |