Revert "Wrap indirect indexing on CUDA (#105055)"
This reverts commit 85c673e6b25173e2697a0dd741a9b2ebb33dec1d.
Reverted https://github.com/pytorch/pytorch/pull/105055 on behalf of https://github.com/peterbell10 because it causes a failure in inductor_torchbench ([comment](https://github.com/pytorch/pytorch/pull/105055#issuecomment-1688871947))
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index f8d405c..f4efe6b 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -7194,63 +7194,6 @@
self.assertTrue(max_live_tensors == 2)
- @skipIfRocm
- def test_neg_index(self):
- def test(fn, inps, has_assert: bool, has_wrapping=True):
- for dynamic in (True, False):
- fn_opt = torch.compile(dynamic=dynamic)(fn)
- code = run_and_get_triton_code(fn_opt, *inps)
- self.assertTrue(("tl.where" in code) is has_wrapping)
- self.assertTrue(("device_assert" in code) is has_assert)
- self.assertEqual(fn(*inps), fn_opt(*inps))
-
- def indirect(a, b):
- return a[b - 1]
-
- a = torch.rand(1024, device="cuda")
- b = torch.zeros(4, dtype=torch.long, device="cuda")
- test(indirect, (a, b), has_assert=True)
-
- def direct(x):
- return x[:, -1]
-
- a = torch.rand(1, 64, 32, device="cuda")
- test(direct, (a,), has_assert=False, has_wrapping=False)
-
- def flip(a, b):
- return a[b]
-
- a = torch.rand(1024, device="cuda")
- b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device="cuda")
- test(flip, (a, b), has_assert=True)
-
- # Constant propagate a constant that's negative
- def flip_with_index_constant(a):
- b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device="cuda")
- return a[b]
-
- # Wrapping is constant-folded
- test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False)
-
- # Operation where we can't prove that the index is always positive or negative
- def pos_and_neg(a):
- b = torch.arange(start=1, end=-a.numel() - 1, step=-1, device="cuda")
- return a[b]
-
- # It has wrapping but no assert
- test(pos_and_neg, (a,), has_assert=False, has_wrapping=True)
-
- # We currently don't do constant propagation with float constants
- def flip_with_index(a):
- b = 1.0 * torch.arange(
- start=-1, end=-a.numel() - 1, step=-1, device="cuda"
- )
- b = b.int()
- return a[b]
-
- # Constant is propagated as we can prove that the result is always negative.
- test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False)
-
# See https://github.com/pytorch/pytorch/issues/100348
def test_inductor_detach_view(self):
def fn(x: torch.Tensor) -> torch.Tensor:
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index a461161..487ef9a 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -314,12 +314,6 @@
def _print_CleanDiv(self, expr):
return self._print_FloorDiv(expr) # type: ignore[attr-defined]
- def _print_GreaterThan(self, expr):
- # GreaterThan: >=
- # StrictlyGreaterThan: >
- # Go figure...
- return " >= ".join(map(self.paren, map(self._print, expr.args)))
-
class PythonPrinter(ExprPrinter):
def _print_ModularIndexing(self, expr):
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 62f7f85..471448a 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -294,12 +294,6 @@
def _print_Integer(self, expr):
return f"{int(expr)}L"
- def _print_Where(self, expr):
- c = self.paren(self.doprint(expr.args[0]))
- p = self.paren(self.doprint(expr.args[1]))
- q = self.paren(self.doprint(expr.args[2]))
- return f"{c} ? {p} : {q}"
-
def _print_ModularIndexing(self, expr):
x, div, mod = expr.args
x = self.paren(self.doprint(x))
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 0fc8b7f..9fd9c19 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -65,12 +65,6 @@
def _helper_sqrt(self, expr):
return f"tl.math.sqrt({self.paren(self._print(expr))}.to(tl.float32))"
- def _print_Where(self, expr):
- c = self.doprint(expr.args[0])
- p = self.doprint(expr.args[1])
- q = self.doprint(expr.args[2])
- return f"tl.where({c}, {p}, {q})"
-
def _print_Min(self, expr):
nargs = len(expr.args)
if len(expr.args) == 1:
@@ -1192,8 +1186,6 @@
self._load_mask = prior
def indirect_indexing(self, var, size, check=True):
- # TODO(lezcano) This code should be lifted to codegen/common.py.
- # This should be easy, as now CSE variables carry bounds info
class IndirectAssertLine(DeferredLineBase):
def __init__(self, line, var, mask, size_map):
self.var = var
@@ -1231,28 +1223,6 @@
def _new_line(self, line):
return IndirectAssertLine(line, self.var, self.mask, self.size_map)
- if var.bounds.lower < 0:
- new_bounds = ValueRanges.unknown()
- if var.bounds != ValueRanges.unknown() and isinstance(size, sympy.Number):
- # Take the negative part of the bound and add size to it
- # Then take union of that and the positive part
- # This is a tighter bound than that of a generic ops.where, as we have info on the cond
- neg = var.bounds & ValueRanges(-sympy.oo, -1)
- new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
- # We don't have a good way of representing the empty range
- if var.bounds.upper >= 0:
- pos = var.bounds & ValueRanges(0, sympy.oo)
- new_bounds = new_bounds | pos
-
- stm = f"{var} + {self.index_to_str(size)}"
- # Mixed negative and non-negative
- if var.bounds.upper >= 0:
- stm = f"tl.where({var} < 0, {stm}, {var})"
- new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
-
- new_var.update_on_args("index_wrap", (var,), {})
- var = new_var
-
generate_assert = (
(check or config.debug_index_asserts)
and config.triton.assert_indirect_indexing
diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py
index bbadb02..7449205 100644
--- a/torch/_inductor/index_propagation.py
+++ b/torch/_inductor/index_propagation.py
@@ -27,7 +27,7 @@
import torch
from torch._prims_common import is_boolean_dtype, is_integer_dtype
-from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
+from torch.utils._sympy.functions import FloorDiv, ModularIndexing
@dataclass
@@ -229,12 +229,7 @@
def indirect_indexing(
self, index: Union[Any, IndexPropVar], size: Any, check: bool = True
) -> Any:
- # nb. We do index + Where(...) rather than Where(idx >= 0, idx, idx + sz) because we don't have CSE
- # for SymPy expressions, so we don't want to repeat idx too much
-
# indirect_indexing returns a sympy value, so no need to wrap in IndexPropVar here
if isinstance(index, IndexPropVar) and index.is_symbolic:
- # If we are turning a indirect indexing into direct, we need to wrap it.
- index = index.value.expr
- return index + Where(index >= 0, 0, size)
+ return index.value.expr
return self.fallback("indirect_indexing", (index, size, check), {}).value
diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py
index 4ce4adf..30b8b86 100644
--- a/torch/utils/_sympy/functions.py
+++ b/torch/utils/_sympy/functions.py
@@ -137,19 +137,6 @@
if isinstance(base, FloorDiv):
return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
-class Where(sympy.Function):
- """
- Good ol' ternary operator
- """
-
- nargs = (3,)
-
- @classmethod
- def eval(cls, c, p, q):
- if c == sympy.true:
- return p
- elif c == sympy.false:
- return q
class Mod(sympy.Function):
"""
diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py
index 3eaf4e1..672eae8 100644
--- a/torch/utils/_sympy/interp.py
+++ b/torch/utils/_sympy/interp.py
@@ -14,7 +14,7 @@
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
import torch
-from .functions import CleanDiv, FloorDiv, Mod, ModularIndexing, Where
+from .functions import CleanDiv, FloorDiv, Mod, ModularIndexing
# TODO: Dedupe this with SYMPY_INTERP
@@ -42,7 +42,6 @@
TrueDiv: "truediv",
FloorDiv: "floordiv",
CleanDiv: "div",
- Where: "where",
sympy.Add: "add",
sympy.Mul: "mul",
Pow: "pow",
diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py
index 87db5f7..ba6a2fa 100644
--- a/torch/utils/_sympy/value_ranges.py
+++ b/torch/utils/_sympy/value_ranges.py
@@ -82,12 +82,9 @@
x = simple_sympify(x)
return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper)
- def tighten(self, other) -> "ValueRanges":
+ def tighten(self, other: "ValueRanges"):
"""Given two ValueRanges, returns their intersection"""
- return self & other
-
- # Intersection
- def __and__(self, other) -> "ValueRanges":
+ # Some invariants
if other == ValueRanges.unknown():
return self
if self == ValueRanges.unknown():
@@ -99,16 +96,9 @@
range = ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper))
return range
- # Union
- def __or__(self, other) -> "ValueRanges":
- if ValueRanges.unknown() in (self, other):
- return ValueRanges.unknown()
- assert self.is_bool == other.is_bool, (self, other)
- if self.is_bool:
- range = ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper))
- else:
- range = ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper))
- return range
+ # Intersection
+ def __and__(self, other):
+ return ValueRanges(lower=max(self.lower, other.lower), upper=min(self.upper, other.upper))
def is_singleton(self) -> bool:
return self.lower == self.upper
@@ -445,17 +435,6 @@
return ValueRanges.unknown()
return ValueRanges.increasing_map(x, sympy.sqrt)
- @staticmethod
- def where(a, b, c):
- b = ValueRanges.wrap(b)
- c = ValueRanges.wrap(c)
- assert a.is_bool
- assert b.is_bool == c.is_bool
- if b.is_bool:
- return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
- else:
- return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))
-
class ValueRangeAnalysis(SymPyValueRangeAnalysis):
def __init__(self):
@@ -549,6 +528,17 @@
def sub(cls, a, b):
return cls.add(a, cls.neg(b))
+ @staticmethod
+ def where(a, b, c):
+ b = ValueRanges.wrap(b)
+ c = ValueRanges.wrap(c)
+ assert a.is_bool
+ assert b.is_bool == c.is_bool
+ if b.is_bool:
+ return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
+ else:
+ return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))
+
def __getattr__(self, name):
log.debug("unhandled ValueRange op %s", name)
return self.default_handler