[inductor][cpp] improve vector contiguous checks for FloorDiv and ModularIndexing (#117221)

Fix https://github.com/pytorch/pytorch/issues/114488

This PR enables contiguous vector loads in cases where `FloorDiv` and `ModularIndexing` sub-expressions can be reduced within the vectorized loop.

Take the index expression in the test case `test_vec_contiguous_ModularIndexing` as an example.
`14336*x0 + 256*x1 + 128*((x2//256)) + ModularIndexing(x2, 1, 128) + 7168*ModularIndexing(x2, 128, 2)` can be reduced to `14336*x0 + 256*x1 + x2 + 128*x2_div_c0 + 7168*x2_mod_c0 + x2_mod_c1`, where `x2` is the vectorized loop variable and the vector length is 16. Since the reduced expression is linear in `x2` with a stride of 1, it can be loaded with a contiguous vectorized load. Check the code comment for more details:
https://github.com/pytorch/pytorch/pull/117221/files#diff-5ab7b0235e2076a5fc6629ba0b109208940f5b94f5c13babc3e0f87cf4fcec82R317-R329
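For intuition, here is a small standalone check (an illustration only, not part of the patch) that numerically verifies why the `FloorDiv`/`ModularIndexing` terms reduce this way: within any 16-lane vector range `16*a .. 16*a + 15`, `x2 // 256` and `(x2 // 128) % 2` are lane-invariant, while `x2 % 128` differs from `x2` by a lane-invariant constant.

```python
# Sanity check (illustration only): within one vector iteration of
# length 16, the loop variable x2 takes the values 16*a .. 16*a + 15.
for a in range(1000):
    lanes = range(16 * a, 16 * a + 16)
    # FloorDiv(x2, 256) == x2 // 256 is lane-invariant since 256 % 16 == 0,
    # so it can be treated as a free variable for stride analysis.
    assert len({x // 256 for x in lanes}) == 1
    # ModularIndexing(x2, 128, 2) == (x2 // 128) % 2 is also lane-invariant.
    assert len({(x // 128) % 2 for x in lanes}) == 1
    # ModularIndexing(x2, 1, 128) == x2 % 128 equals x2 plus a lane-invariant
    # offset since 128 % 16 == 0, hence the `x2 + x2_mod_c1` term above.
    assert len({(x % 128) - x for x in lanes}) == 1
```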

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117221
Approved by: https://github.com/jansel
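As a usage sketch (assuming this patch is applied), the new `simplify_index_in_vec_range` helper can be invoked directly to reproduce the reduction above; the `x2_div_c*`/`x2_mod_c*` symbols are the free variables it introduces for analysis purposes:

```python
import sympy

from torch._inductor.codegen.cpp import simplify_index_in_vec_range
from torch.utils._sympy.functions import FloorDiv, ModularIndexing

x0, x1, x2 = sympy.symbols("x0 x1 x2", integer=True, nonnegative=True)
index = (
    14336 * x0
    + 256 * x1
    + 128 * FloorDiv(x2, 256)
    + ModularIndexing(x2, 1, 128)
    + 7168 * ModularIndexing(x2, 128, 2)
)
# With vec_length=16, every FloorDiv/ModularIndexing term above reduces to a
# free variable (ModularIndexing(x2, 1, 128) additionally contributes x2),
# so the stride of the index w.r.t. x2 becomes the constant 1:
print(simplify_index_in_vec_range(index, x2, 16))
# -> 14336*x0 + 256*x1 + x2 + 128*x2_div_c0 + 7168*x2_mod_c0 + x2_mod_c1
```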
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index a369131..925d502 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -2724,10 +2724,47 @@
             return y.softmax(dim=-1)
 
         x = torch.randn(128, 2048)
+        opt_fn = torch.compile(fn)
         metrics.reset()
-        self.common(fn, (x,))
+        _, code = run_and_get_cpp_code(opt_fn, x)
+        self.assertTrue(same(fn(x), opt_fn(x)))
         # 4 kernels for max, exp, sum and div
         assert metrics.generated_cpp_vec_kernel_count == 4
+        FileCheck().check_count(
+            "Vectorized<int>::loadu(tmpbuf.data())", 0, exactly=True
+        ).run(code)
+
+    def test_vec_contiguous_ModularIndexing(self):
+        # https://github.com/pytorch/pytorch/issues/114488
+        class M(torch.nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.norm = torch.nn.LayerNorm(dim * 4)
+
+            def forward(self, x):
+                # the pattern from swin_base_patch4_window7_224
+                B, H, W, C = x.shape
+                x = (
+                    x.reshape(B, H // 2, 2, W // 2, 2, C)
+                    .permute(0, 1, 3, 4, 2, 5)
+                    .flatten(3)
+                )
+                x = self.norm(x)
+                return x
+
+        x = torch.randn(1, 56, 56, 128)
+        m = M(128)
+        opt_m = torch.compile(m)
+        with torch.no_grad():
+            metrics.reset()
+            _, code = run_and_get_cpp_code(opt_m, x)
+            self.assertTrue(same(m(x), opt_m(x)))
+            # Two kernels: one for the reduction, one for the pointwise ops
+            assert metrics.generated_cpp_vec_kernel_count == 2
+            # Only one kernel has a non-contiguous load
+            FileCheck().check_count(
+                "Vectorized<float>::loadu(tmpbuf.data())", 1, exactly=True
+            ).run(code)
 
 
 if __name__ == "__main__":
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index fb2a5aa..f2f4014 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -16,7 +16,7 @@
 from torch._inductor import dependencies
 from torch._inductor.ir import StorageBox, TensorBox
 from torch._prims_common import is_float_dtype
-from torch.utils._sympy.functions import FloorDiv
+from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
 
 from .. import codecache, config, ir, metrics
@@ -311,6 +311,69 @@
     return sympy.simplify(new_index - index)
 
 
+@functools.lru_cache
+def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int):
+    """
+    Simplifies the index expression within the range of a vectorized loop.
+    Given a vectorized loop variable `var` in the range of a loop with `vec_length`,
+    this function transforms the `index` into an equivalent form. It handles
+    simplifications for cases where `var` can be expressed as `vec_length * a + b`,
+    where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences
+    of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations.
+
+    NOTE:
+    The simplified index expression is intended for analysis purposes only, not
+    for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables
+    which are not dependent on the loop variable `var` in the vectorized range. Check
+    https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details.
+
+    Examples:
+    1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then
+       `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable
+       when `div` is divisible by 16.
+    2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free
+       variable when `mod` is divisible by 16.
+    """
+
+    div_freevar_id = 0
+    mod_freevar_id = 0
+
+    def visit_indexing_div(divisor):
+        nonlocal div_freevar_id
+        result = FloorDiv(var, divisor)
+        if sympy.gcd(divisor, vec_length) == vec_length:
+            result = sympy.Symbol(f"{var}_div_c{div_freevar_id}")
+            div_freevar_id += 1
+        return result
+
+    def visit_modular_indexing(divisor, modulus):
+        nonlocal mod_freevar_id
+        result = ModularIndexing(var, divisor, modulus)
+        if sympy.gcd(divisor, vec_length) == vec_length:
+            result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
+            mod_freevar_id += 1
+        elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length:
+            result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
+            mod_freevar_id += 1
+        return result
+
+    original_index = index
+
+    div = sympy.Wild("divisor")
+    if index.has(FloorDiv):
+        index = index.replace(FloorDiv(var, div), visit_indexing_div)
+
+    mod = sympy.Wild("modulus")
+    if index.has(ModularIndexing):
+        index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing)
+
+    index = sympy.simplify(index)
+    if index != original_index:
+        return simplify_index_in_vec_range(index, var, vec_length)
+
+    return index
+
+
 class CppPrinter(ExprPrinter):
     def _print_Integer(self, expr):
         return f"{int(expr)}L"
@@ -1355,12 +1418,15 @@
         assert isinstance(V.kernel, CppVecKernel)
         index = V.kernel.rename_indexing(expr)
         tiling_var = V.kernel.itervars[V.kernel.tiling_idx]
-        if V.kernel.index_is_vector_invariant(index):
-            return CppOverrides.index_expr(expr, dtype)
-        if stride_at(
-            tiling_var, index
-        ).is_number and not V.kernel.index_indirect_depends_on(index, tiling_var):
-            stride = stride_at(tiling_var, index)
+        index_vec_simplified = simplify_index_in_vec_range(
+            index, tiling_var, V.kernel.tiling_factor
+        )
+        stride = stride_at(tiling_var, index_vec_simplified)
+        if stride.is_number and not V.kernel.index_indirect_depends_on(
+            index, tiling_var
+        ):
+            if stride == 0:
+                return CppOverrides.index_expr(expr, dtype)
             value = ops.to_dtype(cexpr(index), dtype)
             if isinstance(value, OpsValue):
                 value = value.value
@@ -1711,18 +1777,6 @@
         self.tiling_idx = tiling_idx
         metrics.generated_cpp_vec_kernel_count += 1
 
-    def index_is_vector_invariant(self, index: sympy.Expr):
-        """`index` is either independent from the tiling itervar or unchanged in the vector range"""
-        tiling_var = self.itervars[self.tiling_idx]
-        if not self.index_depends_on(index, tiling_var):
-            return True
-        if not self.index_indirect_depends_on(index, tiling_var):
-            vec_range = [
-                sympy_subs(index, {tiling_var: i}) for i in range(self.tiling_factor)
-            ]
-            return all(expr == vec_range[0] for expr in vec_range)
-        return False
-
     def _get_vec_load_line(
         self,
         var: str,
@@ -1883,12 +1937,16 @@
         index = self.rename_indexing(index)
         dtype = V.graph.get_dtype(name)
         tiling_var = self.itervars[self.tiling_idx]
-        if self.index_is_vector_invariant(index):
+        index_vec_simplified = simplify_index_in_vec_range(
+            index, tiling_var, self.tiling_factor
+        )
+        stride = stride_at(tiling_var, index_vec_simplified)
+        if stride == 0:
             # load scalar and lazily broadcast it on demand
             return super().load(name, index)
-        non_contiguous = stride_at(
-            tiling_var, index
-        ) != 1 or self.index_indirect_depends_on(index, tiling_var)
+        non_contiguous = stride != 1 or self.index_indirect_depends_on(
+            index, tiling_var
+        )
         if non_contiguous:
             csevar = self.load_non_contiguous(var, index, dtype)
         else: