[pt2] add meta function for `logcumsumexp` (#98683)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98683
Approved by: https://github.com/ezyang
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index da26caa..c732163 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2578,7 +2578,6 @@
     xfail('linalg.tensorsolve', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('linalg.vander', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
-    xfail('logcumsumexp', ''),  # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition
     xfail('logdet', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('lu', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
     xfail('lu_solve', ''),  # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
diff --git a/test/test_meta.py b/test/test_meta.py
index 1abd9c4..b8e0226 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -614,7 +614,6 @@
     torch.histogram : {f64, f32},
     torch.histogramdd : {f64, f32},
     torch.kthvalue : {f64, i32, i64, u8, i16, bf16, i8, f32},
-    torch.logcumsumexp : {f64, bf16, f32, c64, c128},
     torch.median : {f64, i32, i64, u8, i16, bf16, i8, f32},
     torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
     torch.multinomial : {f64, bf16, f32},
@@ -719,7 +718,6 @@
     torch.kthvalue: {f16},  # aten::kthvalue.values
     torch.linalg.householder_product: {f32, f64},  # aten::linalg_householder_product, aten::linalg_householder_product.out
     torch.linalg.solve_triangular: {f32, f64},  # aten::linalg_solve_triangular, aten::linalg_solve_triangular.out
-    torch.logcumsumexp: {bf16, f16},  # aten::_logcumsumexp, aten::_logcumsumexp.out
     torch.matrix_exp: {f16},  # aten::linalg_matrix_exp
     torch.median: {f16},  # aten::median, aten::median.dim_values
     torch.multinomial: {f16},  # aten::multinomial, aten::multinomial.out
@@ -860,8 +858,6 @@
     aten.histogram.bin_ct : {f32, f64},
     aten.histogram.bins_tensor : {f32, f64},
     aten.kthvalue.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
-    aten.logcumsumexp.default : {bf16, f32, f64, c64, c128},
-    aten.logcumsumexp.out : {bf16, f32, f64, c64, c128},
     aten.max_pool3d_with_indices.default : {f32, f64},
     aten.max_unpool2d.default : {f32, f64},
     aten.max_unpool3d.default : {f32, f64},
@@ -936,8 +932,6 @@
     aten.linalg_solve_triangular.out: {f32, f64},  # aten::linalg_solve_triangular.out
     aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
     aten.log_sigmoid_forward.output : {bf16, f16, f64, f32},  # aten::log_sigmoid_forward.output
-    aten.logcumsumexp.default: {bf16, f16},  # aten::_logcumsumexp
-    aten.logcumsumexp.out: {bf16, f16},  # aten::_logcumsumexp.out
     aten.max_pool3d_with_indices.default: {bf16, f16},  # aten::max_pool3d_with_indices
     aten.max_unpool2d.default: {f16},  # aten::max_unpool2d
     aten.max_unpool3d.default: {f16},  # aten::max_unpool3d
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 493f896..77e4956 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1428,7 +1428,6 @@
     xfail('linalg.tensorsolve', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
-    xfail('logcumsumexp', ''),  # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition
     xfail('logdet', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('lu', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
     xfail('lu_solve', ''),  # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 34b78e0..128c89c 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -101,6 +101,14 @@
     return values, indices
 
 
+@register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out])
+@out_wrapper()
+def logcumsumexp(self, dim):
+    # Checks that dim is within bounds
+    maybe_wrap_dim(dim, self.ndim)
+    return torch.empty_like(self).contiguous()
+
+
 @register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
 @out_wrapper()
 def meta_fft_c2c(self, dim, normalization, forward):