# blob: c54a72f3b2751a2078e7fd348a5c4853bc8e6b9c [file] [log] [blame]
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
from torch import vmap
class TestVmapAPI(TestCase):
def test_non_tensor_output_raises(self):
with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
output = vmap(lambda x: 3.14)(torch.ones(3))
def multiple_outputs(x):
return x, 3
with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
vmap(multiple_outputs)(torch.ones(3))
def test_different_map_dim_size_raises(self):
x = torch.randn(2)
y = torch.randn(3)
expected_msg = 'Expected all tensors to have the same size in the mapped dimension'
with self.assertRaisesRegex(ValueError, expected_msg):
vmap(torch.mul)(x, y)
def test_func_with_no_inputs(self):
expected_msg = 'got no inputs'
def foo():
return torch.randn(3)
def bar(x):
return torch.randn(3)
with self.assertRaisesRegex(ValueError, expected_msg):
vmap(foo)()
with self.assertRaisesRegex(ValueError, expected_msg):
vmap(bar)()
def test_constant_function(self):
output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))
def test_single_input(self):
x = torch.randn(2, 3)
def square(x):
return x * x
output = vmap(square)(x)
self.assertEqual(output, x * x)
def test_multiple_inputs(self):
x = torch.randn(2, 3)
y = torch.randn(2, 3)
output = vmap(torch.mul)(x, y)
self.assertEqual(output, x * y)
def test_multiple_outputs(self):
def foo(x):
return x * x, x * x * x
x = torch.randn(3)
outputs = vmap(foo)(x)
self.assertEqual(outputs[0], x * x)
self.assertEqual(outputs[1], x * x * x)
def test_multiple_outputs_error_cases(self):
# This is the same thing as
# def returns_tuple_of_tensors(x):
# return x, x
def returns_tuple_of_tensors(x):
return (x, x)
def returns_list_of_two_tensors(x):
return [x, x]
def returns_list_of_one_tensor(x):
return [x]
x = torch.randn(3)
# should not throw
vmap(returns_tuple_of_tensors)(x)
# jax supports these, but we don't yet
msg = "must only return Tensors, got type <class 'list'>"
with self.assertRaisesRegex(ValueError, msg):
vmap(returns_list_of_two_tensors)(x)
with self.assertRaisesRegex(ValueError, msg):
vmap(returns_list_of_one_tensor)(x)
def test_nested_with_same_map_dim(self):
x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
output = vmap(vmap(torch.mul))(x, y)
self.assertEqual(output, x * y)
output = vmap(vmap(vmap(torch.mul)))(x, y)
self.assertEqual(output, x * y)
def test_nested_with_different_map_dim(self):
x = torch.randn(2, 3)
y = torch.randn(5, 3)
output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
self.assertEqual(output.shape, (2, 5, 3))
self.assertEqual(output, x.view(2, 1, 3) * y)
z = torch.randn(7, 3)
output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
self.assertEqual(output.shape, (2, 5, 7, 3))
self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)
def test_noop_in_inner_vmap(self):
x = torch.randn(3)
y = torch.randn(5)
output = vmap(lambda x: vmap(lambda y: x)(y))(x)
self.assertEqual(output, x.view(3, 1).expand(3, 5))
def test_unsupported_op_err_msg(self):
def foo(x):
return torch.cos(x)
x = torch.randn(3)
with self.assertRaisesRegex(RuntimeError, 'NYI: Calling aten::cos inside of vmap'):
vmap(foo)(x)
def test_unsupported_inplace_op_err_msg(self):
def foo(x):
return x.cos_()
x = torch.randn(3)
# TODO(rzou): Yeah, this error message is pretty bad because the
# dispatcher's fallback mechanism doesn't work for ops that don't support
# boxing. Fix the error message at some point.
with self.assertRaisesRegex(
RuntimeError, 'Tried to call KernelFunction::call'):
vmap(foo)(x)
def test_nonzero_out_dims(self):
# Basic test
tensor = torch.randn(2, 3)
result = vmap(lambda x: x, out_dims=1)(tensor)
self.assertEqual(result, tensor.permute(1, 0))
self.assertEqual(result.data_ptr(), tensor.data_ptr())
# Test that the batch dimension gets permuted to dim 2
tensor = torch.randn(2, 3, 5, 7)
result = vmap(lambda x: x, out_dims=2)(tensor)
self.assertEqual(result, tensor.permute(1, 2, 0, 3))
self.assertEqual(result.data_ptr(), tensor.data_ptr())
# negative out_dim
tensor = torch.randn(2, 3, 5, 7)
result = vmap(lambda x: x, out_dims=-1)(tensor)
self.assertEqual(result, tensor.permute(1, 2, 3, 0))
self.assertEqual(result.data_ptr(), tensor.data_ptr())
# check that out_dims works on ALL outputs
tensor = torch.randn(2, 3, 5, 7)
other = torch.randn(2, 3, 5, 7)
result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3)))
# use out_dims with the maximum vmap-able tensor dims (64 dims)
ndims = 64
shape = [2] + [1] * (ndims - 1)
expected_shape = [1, 1, 2] + [1] * (ndims - 3)
tensor = torch.randn(shape)
result = vmap(lambda x: x, out_dims=2)(tensor)
self.assertEqual(result.shape, expected_shape)
# test something that is not the identity function
def foo(x, y):
return x, x * y, x * y * y
x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
result = vmap(foo, out_dims=1)(x, y)
self.assertEqual(
result,
(x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2)))
def test_multiple_out_dims(self):
def foo(x):
return x, x
def bar(x, y):
return x, x, x, x * y
x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
result = vmap(foo, out_dims=(0, 1))(x)
self.assertEqual(result, (x, x.permute(1, 0, 2)))
result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
expected = (
x.permute(1, 2, 0),
x,
x.permute(1, 0, 2),
(x * y).permute(1, 2, 0),
)
self.assertEqual(result, expected)
def test_nested_out_dims(self):
y = torch.randn(2, 3, 5, 7)
# Inner vmap has non-zero out_dim
result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
self.assertEqual(result.shape, (2, 5, 3, 7))
self.assertEqual(result, y.permute(0, 2, 1, 3))
# all vmaps have non-zero out_dim
result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
self.assertEqual(result.shape, (5, 2, 3, 7))
self.assertEqual(result, y.permute(2, 0, 1, 3))
# throwing in some negative out_dims
result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
self.assertEqual(result.shape, (5, 7, 3, 2))
self.assertEqual(result, y.permute(2, 3, 1, 0))
# testing fn that isn't the identity
x = torch.randn(2, 3)
y = torch.randn(5, 3)
result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
self.assertEqual(result.shape, (3, 2, 5))
self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))
def test_out_dims_edge_case(self):
def foo(x):
return x
# Test that we accept out_dims=(1,) for a function with one output.
tensor = torch.randn(2, 3)
expected = vmap(foo, out_dims=1)(tensor)
result = vmap(foo, out_dims=(1,))(tensor)
self.assertEqual(result, expected)
def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
msg = '`out_dims` must be an int or a tuple of int'
tensor = torch.randn(2, 3)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims='lol')(tensor)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=('lol',))(tensor)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=None)(tensor)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=(None,))(tensor)
def test_out_dims_and_num_outputs_mismatch_err_msg(self):
msg = '`out_dims` must have one dim per output'
x = torch.randn(2, 3, 5)
# Too many out_dims
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=(0, 0))(x)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)
# Too few out_dims
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: (x, x), out_dims=(0,))(x)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)
def test_out_dim_out_of_bounds_err_msg(self):
# TODO(rzou): This error message isn't that great. It comes straight
# from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
# the error message in the future in C++
msg = 'Dimension out of range'
x = torch.randn(2, 3, 5)
with self.assertRaisesRegex(IndexError, msg):
vmap(lambda x: x, out_dims=3)(x)
with self.assertRaisesRegex(IndexError, msg):
vmap(lambda x: x, out_dims=-4)(x)
# Run the tests only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()