[inductor] Handle more edge cases in slice and slice_scatter (#117377)
Fixes #117110
When slicing, we can end up with start and end indices that are out of bounds; Python slicing handles this by clamping them to the valid range. There is also the case where end < start, which should result in an empty slice.
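For reference, Python's slicing semantics on a plain list (tensor slicing should match):

    x = list(range(10))
    x[5:100]    # [5, 6, 7, 8, 9] -- end clamps to len(x)
    x[2:0]      # [] -- end < start gives an empty slice
    x[10:20]    # [] -- start clamps to len(x)
    x[-20:-16]  # [] -- negative indices wrap, then clamp to 0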
The isoneutral_mixing failure hits the second case: `start=2, end=0`, which in `slice_scatter` became `src_size[dim] = -2`.
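A minimal sketch of that case in eager mode (inputs chosen for illustration, mirroring the new test below rather than the original isoneutral_mixing code):

    import torch

    a = torch.arange(10, dtype=torch.float)
    b = torch.empty(0)
    # start=2, end=0 is an empty slice, so the scatter is a no-op
    out = torch.ops.aten.slice_scatter.default(a, b, 0, 2, 0, 1)
    assert torch.equal(out, a)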
This PR improves slice's edge-case handling and factors the start/end normalization code out so it can be shared with slice_scatter.
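In plain Python the shared normalization looks roughly like this (a sketch; the real `normalize_start_end` below works on sympy expressions and guards its choices through sizevars):

    def normalize_start_end(dim_size, start, end):
        def wrap(val):
            # negative indices count from the end, as in Python
            return val + dim_size if val < 0 else val

        # invariant afterwards: 0 <= start <= end <= dim_size
        start = 0 if start is None else min(max(wrap(start), 0), dim_size)
        end = dim_size if end is None else min(max(wrap(end), start), dim_size)
        return start, end

    normalize_start_end(10, 2, 0)      # (2, 2) -> empty slice
    normalize_start_end(10, -20, -16)  # (0, 0) -> empty slice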
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117377
Approved by: https://github.com/lezcano
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 57237e7..7030ae2 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -2546,6 +2546,29 @@
(torch.randn([2, 20, 2]),),
)
+    # It's a view so it doesn't generate a kernel
+ @expectedFailureCodegenDynamic
+ def test_slice3(self):
+ def fn(a, b):
+ return torch.ops.aten.slice.Tensor(a, 0, 0, -b)
+
+ x = torch.rand(48, 3, 512, 512)
+ self.common(fn, (x, 2))
+
+ @expectedFailureCodegenDynamic
+ def test_slice4(self):
+ # empty slices that require clamping the start or end
+ def fn(a):
+ return (
+ aten.slice.Tensor(a, 0, 2, 0, 1),
+ aten.slice.Tensor(a, 0, a.shape[0], a.shape[0] + 10, 1),
+ aten.slice.Tensor(a, 0, -20, 0, 1),
+ aten.slice.Tensor(a, 0, -20, -16, 1),
+ )
+
+ x = torch.rand(10)
+ self.common(fn, (x,))
+
def test_split_with_sizes(self):
def fn(a, sizes):
return [t + 1.0 for t in torch.split(a * 2.0, sizes, -1)]
@@ -5635,6 +5658,20 @@
],
)
+ def test_slice_scatter5(self):
+ # empty slices that require clamping the start or end
+ def fn(a, b):
+ return (
+ aten.slice_scatter.default(a, b, 0, 2, 0, 1),
+ aten.slice_scatter.default(a, b, 0, a.shape[0], a.shape[0] + 10, 1),
+ aten.slice_scatter.default(a, b, 0, -20, 0, 1),
+ aten.slice_scatter.default(a, b, 0, -20, -16, 1),
+ )
+
+ a = torch.arange(10, dtype=torch.float)
+ b = torch.empty(0)
+ self.common(fn, [a, b])
+
def test_scatter1(self):
def fn(a, dim, index, b):
return aten.scatter(a, dim, index, b)
@@ -7772,15 +7809,6 @@
x = torch.randn(2, 2)
self.common(fn, (x,), atol=0, rtol=0)
- # It's a view so it doens't generate a kernel
- @expectedFailureCodegenDynamic
- def test_slice(self):
- def fn(a, b):
- return torch.ops.aten.slice.Tensor(a, 0, 0, -b)
-
- x = torch.rand(48, 3, 512, 512)
- self.common(fn, (x, 2))
-
def test_inplace_resize_as(self):
def fn(x, y):
x.resize_as_(y)
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index fad4cec..ba84611 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -2232,6 +2232,35 @@
class SliceView(View):
@classmethod
+ def normalize_start_end(cls, x, dim, start, end):
+ """
+ Normalize start and end such that both are in the range
+ [0, x.get_size()[dim]] and start <= end.
+ """
+ sizevars = V.graph.sizevars
+ dim_size = x.get_size()[dim]
+
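+        # With unbacked symbols there are no hints to compare against, so keep
+        # the clamp symbolic via Min/Max; otherwise resolve the comparison now
+        # through sizevars, which installs guards on the choice.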
+ if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):
+
+ def clamp(x, lower, upper):
+ return sympy.Min(sympy.Max(x, lower), upper)
+
+ else:
+
+ def clamp(x, lower, upper):
+ return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper)
+
+ def clamp_wrap(val, lower, upper, default):
+ if val is None:
+ return default
+ val = cls.handle_negative_index(val, dim_size)
+ return clamp(val, lower, upper)
+
+ start = clamp_wrap(start, 0, dim_size, 0)
+ end = clamp_wrap(end, start, dim_size, dim_size)
+ return start, end
+
+ @classmethod
def create(cls, x, dim, start, end, step=1):
step = sympy.expand(step)
assert step > 0
@@ -2244,15 +2273,7 @@
sizevars = V.graph.sizevars
new_size = list(x.get_size())
- start = cls.handle_negative_index(start, new_size[dim])
- end = cls.handle_negative_index(end, new_size[dim])
-
- if free_unbacked_symbols(start) or free_unbacked_symbols(end):
- end = sympy.Min(end, new_size[dim])
- start = sympy.Min(start, end)
- else:
- end = sizevars.evaluate_min(end, new_size[dim])
- start = sizevars.evaluate_min(start, end)
+ start, end = cls.normalize_start_end(x, dim, start, end)
new_size[dim] = FloorDiv(end - start + (step - 1), step)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 1bc6246..f2368f2 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -883,10 +883,6 @@
assert isinstance(x, TensorBox)
dim = _validate_dim(x, dim, 0)
dim_size = x.get_size()[dim]
- if V.graph.sizevars.evaluate_expr(sympy.Lt(start + dim_size, 0)):
- start = 0
- if V.graph.sizevars.evaluate_expr(sympy.Lt(end + dim_size, 0)):
- end = 0
return TensorBox(ir.SliceView.create(x.data, dim, start, end, step))
@@ -2403,14 +2399,8 @@
x_loader = x.make_loader()
dim = _validate_dim(x, dim, 0)
dim_size = x.get_size()[dim]
- if start is not None and V.graph.sizevars.evaluate_expr(sympy.Lt(start, 0)):
- start = start + dim_size
- if end is not None and V.graph.sizevars.evaluate_expr(sympy.Lt(end, 0)):
- end = end + dim_size
- if start is None:
- start = 0
- if end is None or V.graph.sizevars.statically_known_leq(x.get_size()[dim], end):
- end = dim_size
+
+ start, end = ir.SliceView.normalize_start_end(x, dim, start, end)
src_size = list(x.get_size())
src_size[dim] = FloorDiv(end - start + (step - 1), step)
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
index c99fcaa..a417dd0 100644
--- a/torch/_inductor/sizevars.py
+++ b/torch/_inductor/sizevars.py
@@ -383,6 +383,13 @@
self.guard_leq(right, left)
return right
+ def evaluate_max(self, left: Expr, right: Expr) -> Expr:
+ """return the larger of left and right, and guard on that choice"""
+        # Always choose the opposite of evaluate_min for consistency:
+        # this means evaluate_min(a, b) and evaluate_max(a, b) produce the same guards
+ min_val = self.evaluate_min(left, right)
+ return right if min_val is left else left
+
def evaluate_static_shape(self, left: Expr) -> int:
right = self.size_hint(left)
self.guard_equals(left, sympy.Integer(right))