Revert "Wrap indirect indexing on CUDA (#105055)"

This reverts commit 85c673e6b25173e2697a0dd741a9b2ebb33dec1d.

Reverted https://github.com/pytorch/pytorch/pull/105055 on behalf of https://github.com/peterbell10 because it causes a failure in inductor_torchbench ([comment](https://github.com/pytorch/pytorch/pull/105055#issuecomment-1688871947))
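
This revert restores the pre-#105055 behavior:

* test/inductor/test_torchinductor.py: drops `test_neg_index`, which checked
  generated Triton code for `tl.where` wrapping and `device_assert` guards.
* torch/_inductor/codegen/{common,cpp,triton}.py: removes the
  `_print_GreaterThan` and `_print_Where` printers and the negative-index
  wrapping path in `TritonKernel.indirect_indexing`.
* torch/_inductor/index_propagation.py: symbolic indirect indices are passed
  through unwrapped instead of being rewritten as
  `index + Where(index >= 0, 0, size)`.
* torch/utils/_sympy/functions.py and interp.py: removes the `Where` sympy
  function and its interpreter table entry.
* torch/utils/_sympy/value_ranges.py: restores the plain `__and__`
  intersection, drops `__or__`, and moves `where` from
  `SymPyValueRangeAnalysis` back to `ValueRangeAnalysis`.

For context, a minimal sketch of the wrapping behavior being reverted (a
hypothetical standalone Python illustration, not Inductor's actual codegen,
which emitted the equivalent `tl.where(idx < 0, idx + size, idx)` in Triton):

    import torch

    def wrap_indirect_index(idx: torch.Tensor, size: int) -> torch.Tensor:
        # Remap negative entries so that idx == -1 reads element size - 1,
        # mirroring the generated `tl.where(idx < 0, idx + size, idx)`.
        return torch.where(idx < 0, idx + size, idx)

    a = torch.rand(1024)
    b = torch.arange(start=-1, end=-a.numel() - 1, step=-1)  # all negative
    assert torch.equal(a[wrap_indirect_index(b, a.numel())], torch.flip(a, [0]))
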
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index f8d405c..f4efe6b 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -7194,63 +7194,6 @@
 
             self.assertTrue(max_live_tensors == 2)
 
-        @skipIfRocm
-        def test_neg_index(self):
-            def test(fn, inps, has_assert: bool, has_wrapping=True):
-                for dynamic in (True, False):
-                    fn_opt = torch.compile(dynamic=dynamic)(fn)
-                    code = run_and_get_triton_code(fn_opt, *inps)
-                    self.assertTrue(("tl.where" in code) is has_wrapping)
-                    self.assertTrue(("device_assert" in code) is has_assert)
-                    self.assertEqual(fn(*inps), fn_opt(*inps))
-
-            def indirect(a, b):
-                return a[b - 1]
-
-            a = torch.rand(1024, device="cuda")
-            b = torch.zeros(4, dtype=torch.long, device="cuda")
-            test(indirect, (a, b), has_assert=True)
-
-            def direct(x):
-                return x[:, -1]
-
-            a = torch.rand(1, 64, 32, device="cuda")
-            test(direct, (a,), has_assert=False, has_wrapping=False)
-
-            def flip(a, b):
-                return a[b]
-
-            a = torch.rand(1024, device="cuda")
-            b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device="cuda")
-            test(flip, (a, b), has_assert=True)
-
-            # Constant propagate a constant that's negative
-            def flip_with_index_constant(a):
-                b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device="cuda")
-                return a[b]
-
-            # Wrapping is constant-folded
-            test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False)
-
-            # Operation where we can't prove that the index is always positive or negative
-            def pos_and_neg(a):
-                b = torch.arange(start=1, end=-a.numel() - 1, step=-1, device="cuda")
-                return a[b]
-
-            # It has wrapping but no assert
-            test(pos_and_neg, (a,), has_assert=False, has_wrapping=True)
-
-            # We currently don't do constant propagation with float constants
-            def flip_with_index(a):
-                b = 1.0 * torch.arange(
-                    start=-1, end=-a.numel() - 1, step=-1, device="cuda"
-                )
-                b = b.int()
-                return a[b]
-
-            # Constant is propagated as we can prove that the result is always negative.
-            test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False)
-
         # See https://github.com/pytorch/pytorch/issues/100348
         def test_inductor_detach_view(self):
             def fn(x: torch.Tensor) -> torch.Tensor:
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index a461161..487ef9a 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -314,12 +314,6 @@
     def _print_CleanDiv(self, expr):
         return self._print_FloorDiv(expr)  # type: ignore[attr-defined]
 
-    def _print_GreaterThan(self, expr):
-        # GreaterThan:          >=
-        # StrictlyGreaterThan:  >
-        # Go figure...
-        return " >= ".join(map(self.paren, map(self._print, expr.args)))
-
 
 class PythonPrinter(ExprPrinter):
     def _print_ModularIndexing(self, expr):
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 62f7f85..471448a 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -294,12 +294,6 @@
     def _print_Integer(self, expr):
         return f"{int(expr)}L"
 
-    def _print_Where(self, expr):
-        c = self.paren(self.doprint(expr.args[0]))
-        p = self.paren(self.doprint(expr.args[1]))
-        q = self.paren(self.doprint(expr.args[2]))
-        return f"{c} ? {p} : {q}"
-
     def _print_ModularIndexing(self, expr):
         x, div, mod = expr.args
         x = self.paren(self.doprint(x))
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 0fc8b7f..9fd9c19 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -65,12 +65,6 @@
     def _helper_sqrt(self, expr):
         return f"tl.math.sqrt({self.paren(self._print(expr))}.to(tl.float32))"
 
-    def _print_Where(self, expr):
-        c = self.doprint(expr.args[0])
-        p = self.doprint(expr.args[1])
-        q = self.doprint(expr.args[2])
-        return f"tl.where({c}, {p}, {q})"
-
     def _print_Min(self, expr):
         nargs = len(expr.args)
         if len(expr.args) == 1:
@@ -1192,8 +1186,6 @@
             self._load_mask = prior
 
     def indirect_indexing(self, var, size, check=True):
-        # TODO(lezcano) This code should be lifted to codegen/common.py.
-        # This should be easy, as now CSE variables carry bounds info
         class IndirectAssertLine(DeferredLineBase):
             def __init__(self, line, var, mask, size_map):
                 self.var = var
@@ -1231,28 +1223,6 @@
             def _new_line(self, line):
                 return IndirectAssertLine(line, self.var, self.mask, self.size_map)
 
-        if var.bounds.lower < 0:
-            new_bounds = ValueRanges.unknown()
-            if var.bounds != ValueRanges.unknown() and isinstance(size, sympy.Number):
-                # Take the negative part of the bound and add size to it
-                # Then take union of that and the positive part
-                # This is a tighter bound than that of a generic ops.where, as we have info on the cond
-                neg = var.bounds & ValueRanges(-sympy.oo, -1)
-                new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
-                # We don't have a good way of representing the empty range
-                if var.bounds.upper >= 0:
-                    pos = var.bounds & ValueRanges(0, sympy.oo)
-                    new_bounds = new_bounds | pos
-
-            stm = f"{var} + {self.index_to_str(size)}"
-            # Mixed negative and non-negative
-            if var.bounds.upper >= 0:
-                stm = f"tl.where({var} < 0, {stm}, {var})"
-            new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
-
-            new_var.update_on_args("index_wrap", (var,), {})
-            var = new_var
-
         generate_assert = (
             (check or config.debug_index_asserts)
             and config.triton.assert_indirect_indexing
diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py
index bbadb02..7449205 100644
--- a/torch/_inductor/index_propagation.py
+++ b/torch/_inductor/index_propagation.py
@@ -27,7 +27,7 @@
 
 import torch
 from torch._prims_common import is_boolean_dtype, is_integer_dtype
-from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
+from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 
 
 @dataclass
@@ -229,12 +229,7 @@
     def indirect_indexing(
         self, index: Union[Any, IndexPropVar], size: Any, check: bool = True
     ) -> Any:
-        # nb. We do index + Where(...) rather than Where(idx >= 0, idx, idx + sz) because we don't have CSE
-        #     for SymPy expressions, so we don't want to repeat idx too much
-
         # indirect_indexing returns a sympy value, so no need to wrap in IndexPropVar here
         if isinstance(index, IndexPropVar) and index.is_symbolic:
-            # If we are turning a indirect indexing into direct, we need to wrap it.
-            index = index.value.expr
-            return index + Where(index >= 0, 0, size)
+            return index.value.expr
         return self.fallback("indirect_indexing", (index, size, check), {}).value
diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py
index 4ce4adf..30b8b86 100644
--- a/torch/utils/_sympy/functions.py
+++ b/torch/utils/_sympy/functions.py
@@ -137,19 +137,6 @@
         if isinstance(base, FloorDiv):
             return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
 
-class Where(sympy.Function):
-    """
-    Good ol' ternary operator
-    """
-
-    nargs = (3,)
-
-    @classmethod
-    def eval(cls, c, p, q):
-        if c == sympy.true:
-            return p
-        elif c == sympy.false:
-            return q
 
 class Mod(sympy.Function):
     """
diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py
index 3eaf4e1..672eae8 100644
--- a/torch/utils/_sympy/interp.py
+++ b/torch/utils/_sympy/interp.py
@@ -14,7 +14,7 @@
 from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
 
 import torch
-from .functions import CleanDiv, FloorDiv, Mod, ModularIndexing, Where
+from .functions import CleanDiv, FloorDiv, Mod, ModularIndexing
 
 
 # TODO: Dedupe this with SYMPY_INTERP
@@ -42,7 +42,6 @@
         TrueDiv: "truediv",
         FloorDiv: "floordiv",
         CleanDiv: "div",
-        Where: "where",
         sympy.Add: "add",
         sympy.Mul: "mul",
         Pow: "pow",
diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py
index 87db5f7..ba6a2fa 100644
--- a/torch/utils/_sympy/value_ranges.py
+++ b/torch/utils/_sympy/value_ranges.py
@@ -82,12 +82,9 @@
         x = simple_sympify(x)
         return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper)
 
-    def tighten(self, other) -> "ValueRanges":
+    def tighten(self, other: "ValueRanges"):
         """Given two ValueRanges, returns their intersection"""
-        return self & other
-
-    # Intersection
-    def __and__(self, other) -> "ValueRanges":
+        # Some invariants
         if other == ValueRanges.unknown():
             return self
         if self == ValueRanges.unknown():
@@ -99,16 +96,9 @@
             range = ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper))
         return range
 
-    # Union
-    def __or__(self, other) -> "ValueRanges":
-        if ValueRanges.unknown() in (self, other):
-            return ValueRanges.unknown()
-        assert self.is_bool == other.is_bool, (self, other)
-        if self.is_bool:
-            range = ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper))
-        else:
-            range = ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper))
-        return range
+    # Intersection
+    def __and__(self, other):
+        return ValueRanges(lower=max(self.lower, other.lower), upper=min(self.upper, other.upper))
 
     def is_singleton(self) -> bool:
         return self.lower == self.upper
@@ -445,17 +435,6 @@
             return ValueRanges.unknown()
         return ValueRanges.increasing_map(x, sympy.sqrt)
 
-    @staticmethod
-    def where(a, b, c):
-        b = ValueRanges.wrap(b)
-        c = ValueRanges.wrap(c)
-        assert a.is_bool
-        assert b.is_bool == c.is_bool
-        if b.is_bool:
-            return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
-        else:
-            return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))
-
 
 class ValueRangeAnalysis(SymPyValueRangeAnalysis):
     def __init__(self):
@@ -549,6 +528,17 @@
     def sub(cls, a, b):
         return cls.add(a, cls.neg(b))
 
+    @staticmethod
+    def where(a, b, c):
+        b = ValueRanges.wrap(b)
+        c = ValueRanges.wrap(c)
+        assert a.is_bool
+        assert b.is_bool == c.is_bool
+        if b.is_bool:
+            return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
+        else:
+            return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))
+
     def __getattr__(self, name):
         log.debug("unhandled ValueRange op %s", name)
         return self.default_handler
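
Note on the value_ranges.py hunks above: the revert keeps the range rule for
`where` unchanged and only moves it back to `ValueRangeAnalysis`. A hedged
sketch of the numeric case, using plain Python numbers instead of sympy
expressions for brevity:

    from dataclasses import dataclass

    @dataclass
    class Interval:
        lower: float
        upper: float

    def where_range(b: Interval, c: Interval) -> Interval:
        # Either branch of the ternary may be taken, so the result must
        # cover both operand ranges (ValueRanges.where, non-bool case).
        return Interval(min(b.lower, c.lower), max(b.upper, c.upper))

    assert where_range(Interval(0, 3), Interval(5, 9)) == Interval(0, 9)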