[pt2] add meta function for `logcumsumexp` (#98683)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98683
Approved by: https://github.com/ezyang
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index da26caa..c732163 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2578,7 +2578,6 @@
xfail('linalg.tensorsolve', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.vander', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
- xfail('logcumsumexp', ''), # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition
xfail('logdet', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
diff --git a/test/test_meta.py b/test/test_meta.py
index 1abd9c4..b8e0226 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -614,7 +614,6 @@
torch.histogram : {f64, f32},
torch.histogramdd : {f64, f32},
torch.kthvalue : {f64, i32, i64, u8, i16, bf16, i8, f32},
- torch.logcumsumexp : {f64, bf16, f32, c64, c128},
torch.median : {f64, i32, i64, u8, i16, bf16, i8, f32},
torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
torch.multinomial : {f64, bf16, f32},
@@ -719,7 +718,6 @@
torch.kthvalue: {f16}, # aten::kthvalue.values
torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product, aten::linalg_householder_product.out
torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular, aten::linalg_solve_triangular.out
- torch.logcumsumexp: {bf16, f16}, # aten::_logcumsumexp, aten::_logcumsumexp.out
torch.matrix_exp: {f16}, # aten::linalg_matrix_exp
torch.median: {f16}, # aten::median, aten::median.dim_values
torch.multinomial: {f16}, # aten::multinomial, aten::multinomial.out
@@ -860,8 +858,6 @@
aten.histogram.bin_ct : {f32, f64},
aten.histogram.bins_tensor : {f32, f64},
aten.kthvalue.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
- aten.logcumsumexp.default : {bf16, f32, f64, c64, c128},
- aten.logcumsumexp.out : {bf16, f32, f64, c64, c128},
aten.max_pool3d_with_indices.default : {f32, f64},
aten.max_unpool2d.default : {f32, f64},
aten.max_unpool3d.default : {f32, f64},
@@ -936,8 +932,6 @@
aten.linalg_solve_triangular.out: {f32, f64}, # aten::linalg_solve_triangular.out
aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
aten.log_sigmoid_forward.output : {bf16, f16, f64, f32}, # aten::log_sigmoid_forward.output
- aten.logcumsumexp.default: {bf16, f16}, # aten::_logcumsumexp
- aten.logcumsumexp.out: {bf16, f16}, # aten::_logcumsumexp.out
aten.max_pool3d_with_indices.default: {bf16, f16}, # aten::max_pool3d_with_indices
aten.max_unpool2d.default: {f16}, # aten::max_unpool2d
aten.max_unpool3d.default: {f16}, # aten::max_unpool3d
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 493f896..77e4956 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1428,7 +1428,6 @@
xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
- xfail('logcumsumexp', ''), # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition
xfail('logdet', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 34b78e0..128c89c 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -101,6 +101,14 @@
return values, indices
+@register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out])
+@out_wrapper()
+def logcumsumexp(self, dim):
+ # Checks that dim is within bounds
+ maybe_wrap_dim(dim, self.ndim)
+ return torch.empty_like(self).contiguous()
+
+
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):