[torch][repeat_interleave] remove stream synchronization if output size is given (#58417)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58417
Same as title: when the caller supplies `output_size`, `repeat_interleave` no longer has to read the total number of repeats back from the device to size its output, so the stream synchronization on the CUDA path can be skipped.
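
A minimal sketch of the new usage, mirroring the updated unit test below; the device selection and the concrete output size are illustrative:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
y = torch.tensor([[1, 2], [3, 4]], device=device)
repeats = torch.tensor([1, 2], device=device)

# Without output_size, the op must sum `repeats` on the device and copy the
# total back to the host before it can allocate the output, which forces a
# stream synchronization on CUDA.
a = torch.repeat_interleave(y, repeats, dim=0)

# When the caller already knows the total repeat count, passing it through
# `output_size` lets the op allocate the output directly and skip that sync.
b = torch.repeat_interleave(y, repeats, dim=0, output_size=3)

assert torch.equal(a, b)  # both equal [[1, 2], [3, 4], [3, 4]]
```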
Test Plan:
Rely on CI signal.
Update unit test to exercise new code path as well.
Reviewed By: ngimel
Differential Revision: D28482927
fbshipit-source-id: 3ec8682810ed5c8547b1e8d3869924480ce63dcd
diff --git a/test/test_torch.py b/test/test_torch.py
index f071ffe..d0ef81b 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -4181,14 +4181,25 @@
def test_repeat_interleave(self, device):
y = torch.tensor([[1, 2], [3, 4]], device=device)
for dtype in [torch.int, torch.long]:
+ lengths = torch.tensor([1, 2], dtype=dtype, device=device)
+ output_size = torch.sum(lengths)
a = torch.repeat_interleave(
y,
- torch.tensor([1, 2], dtype=dtype, device=device),
+ lengths,
dim=0,
)
self.assertEqual(a.dtype, y.dtype)
self.assertEqual(a.size(), torch.Size([3, 2]))
+ a_with_output = torch.repeat_interleave(
+ y,
+ lengths,
+ dim=0,
+ output_size=output_size,
+ )
+ self.assertEqual(a_with_output.dtype, y.dtype)
+ self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
+
@dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False)))
@dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
def test_bernoulli_p(self, device, dtype):