Revert "Add lowering for logcumsumexp (#118753)"
This reverts commit 5a77ee65879b58e99911fd53d92ddb55a1c234eb.
Reverted https://github.com/pytorch/pytorch/pull/118753 on behalf of https://github.com/jeffdaily due to it breaking ROCm CI, which was not seen until the trunk job ([comment](https://github.com/pytorch/pytorch/pull/118753#issuecomment-1935074235))
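
For context, aten.logcumsumexp computes the log of the running sum of exponentials along a dimension; with this revert, Inductor stops generating its own scan kernel for the op and routes it back through the ATen fallback (see the make_fallback change in torch/_inductor/lowering.py below). A minimal sketch of the operator's semantics, using only the public torch API:

    import torch

    x = torch.randn(4, 5)

    # logcumsumexp(x, dim) == log(cumsum(exp(x), dim)), but evaluated in a
    # numerically stable way so large-magnitude entries do not overflow.
    stable = torch.logcumsumexp(x, dim=1)
    naive = torch.log(torch.cumsum(torch.exp(x), dim=1))

    torch.testing.assert_close(stable, naive)
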
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 6fad1d2..3e24d74 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -287,8 +287,6 @@
*,
atol=None,
rtol=None,
- grad_atol=None,
- grad_rtol=None,
check_lowp=True,
exact_dtype=True,
nopython=True,
@@ -465,8 +463,8 @@
self.assertEqual(
actual_grad,
expect_grad,
- atol=grad_atol or atol,
- rtol=grad_rtol or rtol,
+ atol=atol,
+ rtol=rtol,
equal_nan=True,
exact_dtype=exact_dtype,
)
@@ -483,8 +481,6 @@
*,
atol=None,
rtol=None,
- grad_atol=None,
- grad_rtol=None,
check_lowp=True,
exact_dtype=True,
nopython=True,
@@ -511,8 +507,6 @@
kwargs,
atol=atol,
rtol=rtol,
- grad_atol=grad_atol,
- grad_rtol=grad_rtol,
exact_dtype=exact_dtype,
nopython=nopython,
reference_in_float=reference_in_float,
@@ -1300,24 +1294,6 @@
a = torch.rand(())
self.common(fn, (a,))
- def test_logcumsumexp(self):
- def fn(x):
- return x.logcumsumexp(0), x.logcumsumexp(1)
-
- # Persistent reductions
- self.common(fn, (torch.rand(16, 32),), check_lowp=not TEST_WITH_ROCM)
- self.common(fn, (torch.rand(20, 30),), check_lowp=not TEST_WITH_ROCM)
-
- # Non-persistent reduction
- self.common(fn, (torch.rand(100, 4000),), check_lowp=not TEST_WITH_ROCM)
-
- def test_logcumsumexp_zero_dim(self):
- def fn(x):
- return x.logcumsumexp(0), x.logcumsumexp(-1)
-
- a = torch.rand(())
- self.common(fn, (a,))
-
def test_clamp(self):
def fn(a, b):
return (a.clamp(-0.1, 0.1), b.clamp(0), torch.clamp(a + b, max=0))
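
The grad_atol/grad_rtol keyword arguments removed above let a test loosen only the backward-pass comparison while keeping the forward tolerances tight, falling back to atol/rtol when unset. A rough sketch of that fallback pattern (hypothetical grad_tolerances helper, not the actual test harness), reusing the 8e-4/1e-3 values from the opinfo entry removed below:

    import torch

    # Hypothetical helper mirroring the removed pattern: gradient tolerances
    # default to the forward-pass tolerances when not given explicitly.
    def grad_tolerances(atol, rtol, grad_atol=None, grad_rtol=None):
        return (grad_atol if grad_atol is not None else atol,
                grad_rtol if grad_rtol is not None else rtol)

    atol, rtol = grad_tolerances(atol=1e-5, rtol=1e-5, grad_atol=8e-4, grad_rtol=1e-3)

    a = torch.randn(8, requires_grad=True)
    b = a.detach().clone().requires_grad_(True)
    (a * 2).sum().backward()
    (b * 2).sum().backward()
    torch.testing.assert_close(a.grad, b.grad, atol=atol, rtol=rtol)
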
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index e065b14..b3a7b48 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -168,8 +168,6 @@
"test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda")),
- "test_logcumsumexp_dynamic_shapes": TestFailure(("cpu",)),
- "test_logcumsumexp_zero_dim_dynamic_shapes": TestFailure(("cpu",)),
"test_max_pool2d6_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_max_pool2d8_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_max_pool2d_with_indices_backward5_dynamic_shapes": TestFailure(
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index ffe129a..11173fa 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -322,7 +322,6 @@
("cauchy", "cuda"): {"reference_in_float": True},
("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002},
("cumprod", "cuda"): {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002},
- ("logcumsumexp", "cuda"): {"grad_atol": 8e-4, "grad_rtol": 0.001},
("exponential", "cuda"): {"reference_in_float": True},
("geometric", "cuda"): {"reference_in_float": True},
("kron", "cuda", f16): {"reference_in_float": True},
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 716e6dd..99a7dd9 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2264,6 +2264,7 @@
make_fallback(aten._linalg_solve_ex)
make_fallback(aten.linalg_solve_triangular)
make_fallback(aten._linalg_svd)
+make_fallback(aten.logcumsumexp)
make_fallback(aten.lu_unpack)
make_fallback(aten.max_pool3d_with_indices)
make_fallback(aten.max_unpool2d)
@@ -5002,7 +5003,6 @@
fallback_cumsum = fallback_handler(aten.cumsum.default)
fallback_cumprod = fallback_handler(aten.cumprod.default)
-fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default)
@register_lowering(aten.cumsum)
@@ -5043,26 +5043,6 @@
return result
-@register_lowering(aten.logcumsumexp)
-def logcumsumexp(x, dim):
- def log_add_exp_helper(a, b):
- min_v = ops.minimum(a, b)
- max_v = ops.maximum(a, b)
- mask = (min_v != max_v) | (~ops.isinf(min_v))
- return ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a)
-
- dtype = x.get_dtype()
- if len(x.get_size()) == 0:
- assert dim in [0, -1]
- return clone(x)
-
- kwargs = _make_scan_inner(x, axis=dim, dtype=dtype)
- result = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper, init=float("-inf"))
- if result is None:
- return fallback_logcumsumexp(x, dim=dim)
- return result
-
-
@register_lowering(aten.prod)
def prod(x, axis=None, keepdims=False, *, dtype=None):
if (
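
The lowering removed above implemented logcumsumexp as an associative scan (ir.Scan.create with init=float("-inf")) whose combine function is a numerically stable log-add-exp: log(exp(a) + exp(b)) = max(a, b) + log1p(exp(min(a, b) - max(a, b))), with the mask guarding the -inf/-inf case so it produces -inf instead of NaN. A rough eager-mode sketch of that combine step in plain torch (not the Inductor ops IR):

    import torch

    def log_add_exp(a, b):
        # Stable log(exp(a) + exp(b)): factor out the larger argument.
        min_v = torch.minimum(a, b)
        max_v = torch.maximum(a, b)
        # If both inputs are -inf, min_v - max_v is NaN; return `a` (also -inf)
        # instead, mirroring the mask in the removed log_add_exp_helper.
        mask = (min_v != max_v) | ~torch.isinf(min_v)
        return torch.where(mask, torch.log1p(torch.exp(min_v - max_v)) + max_v, a)

    # Running the combine as a sequential scan reproduces torch.logcumsumexp.
    x = torch.tensor([0.5, -1.0, float("-inf"), 2.0])
    acc, out = x[0], [x[0]]
    for v in x[1:]:
        acc = log_add_exp(acc, v)
        out.append(acc)
    torch.testing.assert_close(torch.stack(out), torch.logcumsumexp(x, dim=0))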