# blob: 5c77cf9d0b5c17db258f6c14a783106cf9d57a6e [file] [log] [blame]
# Owner(s): ["module: primTorch"]
from functools import partial
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA,
dtypes,
)
import torch._prims as prims
from torch._prims.executor import make_traced
class TestPrims(TestCase):
    """Tests for primitives in torch._prims, executed through make_traced."""

    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
        """Check prims.broadcast_in_dim against equivalent eager tensor ops.

        Each case maps the input's dimensions onto the ``broadcast_dimensions``
        positions of a target shape and compares the traced result with an
        eager tensor built via broadcast_to / expand_as / unsqueeze.
        """

        def _wrapper(a, shape, broadcast_dimensions):
            return prims.broadcast_in_dim(a, shape, broadcast_dimensions)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        # TODO: also run the "nvfuser" executor, i.e. iterate over
        # ("aten", "nvfuser"); it is disabled here (reason not recorded).
        for executor in ("aten",):
            fn = partial(traced, executor=executor)

            # Same shape: the identity mapping (0, 1) returns an equal tensor.
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a, shape, (0, 1))
            self.assertEqual(result.shape, a.shape)
            # is_contiguous must be *called*: the bound method object itself
            # is always truthy, which would make this assertion a no-op.
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: broadcast_dimensions that reorder dims must raise.
            with self.assertRaises(Exception):
                fn(a, shape, (1, 0))

            # Adding outermost dimensions: input dims map to target dims 2, 3.
            a = make_arg((5, 5))
            target_shape = (3, 3, 5, 5)
            result = fn(a, target_shape, (2, 3))
            self.assertEqual(result.shape, target_shape)
            self.assertEqual(a.broadcast_to(target_shape), result)

            # Expands: size-1 input dims are expanded to the target sizes.
            a = make_arg((1, 5, 1))
            target_shape = (3, 5, 7)
            result = fn(a, target_shape, (0, 1, 2))
            self.assertEqual(result.shape, target_shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes: target dim 2 is not in broadcast_dimensions, so a
            # size-1 dim is inserted there.
            a = make_arg((1, 2, 3))
            target_shape = (1, 2, 1, 3)
            result = fn(a, target_shape, (0, 1, 3))
            self.assertEqual(result.shape, target_shape)
            self.assertEqual(a.unsqueeze(2), result)

            # Adds outermost, expands, and unsqueezes: input dims land on
            # target dims 1, 3 and 4; target dims 0, 2 and 5 are new.
            a = make_arg((1, 2, 3))
            target_shape = (4, 1, 7, 2, 3, 3)
            result = fn(a, target_shape, (1, 3, 4))
            self.assertEqual(result.shape, target_shape)
            # Insert the new size-1 dims (shape becomes (1, 1, 1, 2, 3, 1))
            # without mutating `a`, then expand to the target shape.
            expected = a.unsqueeze(3).unsqueeze(1).unsqueeze(0)
            self.assertEqual(expected.expand_as(result), result)
# Generate the device-specific variants of TestPrims into this module's
# namespace (globals()) so the test runner can discover them.
instantiate_device_type_tests(TestPrims, globals())


if __name__ == "__main__":
    # Entry point when this file is executed directly as a script.
    run_tests()