repeat_interleaves meta function
Taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78602
Approved by: https://github.com/mruberry
diff --git a/test/test_meta.py b/test/test_meta.py
index 4532bb3..f6487ae 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -440,6 +440,9 @@
if func is torch.tensor_split:
# Use original indices_or_sections, this argument is data dependent
meta_args = (meta_args[0], args[1]) + meta_args[2:]
+ elif func is torch.ops.aten.repeat_interleave.Tensor:
+ if kwargs.get("output_size", None) is None:
+ meta_args = args
try:
# Suppress warnings, this doesn't matter for test_meta.py
# but it does matter if you want to use this decorator
@@ -840,7 +843,6 @@
aten.polar.default: {f64, f32},
aten.prelu.default: {bf16, f64, f32},
aten.relu.default: {i64, bf16, u8, f32, i8, f64, i16, i32},
- aten.repeat_interleave.Tensor: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.roll.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.rrelu_with_noise.default: {bf16, f64, f32},
aten.searchsorted.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index cc6b8f9..6a3edb5 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -135,3 +135,11 @@
def meta_adaptive_avg_pool3d(self, output_size):
check(self.ndim == 4 or self.ndim == 5, f"Expected 4D or 5D tensor, but got {self.shape}")
return self.new_empty(self.shape[:-3] + tuple(output_size))
+
+@torch.library.impl(meta_lib, "repeat_interleave.Tensor")
+def meta_repeat_interleave_Tensor(repeats, output_size=None):
+ if output_size is None:
+ raise RuntimeError(
+ "cannot repeat_interleave a meta tensor without output_size"
+ )
+ return repeats.new_empty(output_size)