| from torch.testing._internal.common_utils import TestCase, run_tests |
| import torch |
| from torch import vmap |
| |
class TestVmap(TestCase):
    """Tests for ``vmap``.

    Covers output validation (non-Tensor and list returns), input validation
    (mismatched mapped sizes, no inputs), batching of single/multiple inputs
    and outputs, nested ``vmap`` calls, and the error messages raised for ops
    that are not yet supported inside ``vmap``.
    """

    def test_non_tensor_output_raises(self):
        """Returning a non-Tensor, alone or inside a tuple, raises ValueError."""
        # The result is never inspected; the call only has to raise.
        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
            vmap(lambda x: 3.14)(torch.ones(3))

        def multiple_outputs(x):
            return x, 3

        # Element 1 of the returned tuple is an int, which must be rejected.
        with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
            vmap(multiple_outputs)(torch.ones(3))

    def test_different_map_dim_size_raises(self):
        """Inputs whose mapped dimensions differ in size (2 vs 3) raise ValueError."""
        x = torch.randn(2)
        y = torch.randn(3)
        expected_msg = 'Expected all tensors to have the same size in the mapped dimension'
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(torch.mul)(x, y)

    def test_func_with_no_inputs(self):
        """Calling a vmapped function with no inputs raises ValueError.

        This holds whether the wrapped function takes no parameters (``foo``)
        or declares one that the caller simply does not supply (``bar``).
        """
        expected_msg = 'got no inputs'

        def foo():
            return torch.randn(3)

        def bar(x):
            return torch.randn(3)

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(foo)()

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(bar)()

    def test_constant_function(self):
        """A function ignoring its input yields its constant once per batch entry."""
        output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
        self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))

    def test_single_input(self):
        """vmap over a single input matches applying the op to the whole tensor."""
        x = torch.randn(2, 3)

        def square(x):
            return x * x

        output = vmap(square)(x)
        self.assertEqual(output, x * x)

    def test_multiple_inputs(self):
        """vmap over two same-shaped inputs matches the unbatched elementwise op."""
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        output = vmap(torch.mul)(x, y)
        self.assertEqual(output, x * y)

    def test_multiple_outputs(self):
        """A function returning a tuple of Tensors yields a batched tuple."""
        def foo(x):
            return x * x, x * x * x

        x = torch.randn(3)
        outputs = vmap(foo)(x)
        self.assertEqual(outputs[0], x * x)
        self.assertEqual(outputs[1], x * x * x)

    def test_multiple_outputs_error_cases(self):
        """Tuples of Tensors are accepted as outputs; lists of Tensors are rejected."""
        # The parentheses are only for readability: `return (x, x)` is
        # equivalent to `return x, x`; both return a tuple of two Tensors.
        def returns_tuple_of_tensors(x):
            return (x, x)

        def returns_list_of_two_tensors(x):
            return [x, x]

        def returns_list_of_one_tensor(x):
            return [x]

        x = torch.randn(3)

        # should not throw
        vmap(returns_tuple_of_tensors)(x)

        # jax supports these, but we don't yet
        msg = "must only return Tensors, got type <class 'list'>"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_two_tensors)(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_one_tensor)(x)

    def test_nested_with_same_map_dim(self):
        """Nesting vmap two or three levels deep over the same inputs matches x * y."""
        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        output = vmap(vmap(torch.mul))(x, y)
        self.assertEqual(output, x * y)

        output = vmap(vmap(vmap(torch.mul)))(x, y)
        self.assertEqual(output, x * y)

    def test_nested_with_different_map_dim(self):
        """Nested vmaps over different inputs broadcast like an outer product.

        Each level contributes its own batch dimension, in nesting order:
        (2, 5, 3) for two levels and (2, 5, 7, 3) for three.
        """
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
        self.assertEqual(output.shape, (2, 5, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        z = torch.randn(7, 3)
        output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
        self.assertEqual(output.shape, (2, 5, 7, 3))
        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)

    def test_noop_in_inner_vmap(self):
        """An inner vmap that ignores its input returns the outer value, broadcast.

        The result has shape (3, 5): outer dim of ``x``, then the inner batch dim.
        """
        x = torch.randn(3)
        y = torch.randn(5)
        output = vmap(lambda x: vmap(lambda y: x)(y))(x)
        self.assertEqual(output, x.view(3, 1).expand(3, 5))

    def test_unsupported_op_err_msg(self):
        """Calling an op vmap does not support yet (aten::cos here) raises an NYI error."""
        def foo(x):
            return torch.cos(x)

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError, 'NYI: Calling aten::cos inside of vmap'):
            vmap(foo)(x)

    def test_unsupported_inplace_op_err_msg(self):
        """Calling an unsupported in-place op (cos_) inside vmap raises RuntimeError."""
        def foo(x):
            return x.cos_()

        x = torch.randn(3)
        # TODO(rzou): Yeah, this error message is pretty bad because the
        # dispatcher's fallback mechanism doesn't work for ops that don't support
        # boxing. Fix the error message at some point.
        with self.assertRaisesRegex(
                RuntimeError, 'Tried to call KernelFunction::call'):
            vmap(foo)(x)
| |
| |
# Only when this file is executed directly (not imported), hand control to
# the run_tests entry point imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()