Convert expr - c * (expr floordiv c) to expr mod c in AffineExpr
- Detect 'mod' at construction time and use it to replace the combination of
floordiv, mul, and subtract when possible; when 'c' is a power of two, this reduces
the number of operations, and in general the result is more compact and readable.
Update simplifyAdd for this.
On a side note:
- with the affine expr flattening we have, a mod expression like d0 mod c
would be flattened into d0 - c * q, c * q <= d0 <= c*q + c - 1, with 'q'
being added as the local variable (q = d0 floordiv c); consequently, a mod
was turned into a floordiv whenever the expression was reconstructed,
i.e., it came back as d0 - c * (d0 floordiv c); with this change, we recover
the mod.
- rename SimplifyAffineExpr -> SimplifyAffineStructures (pass had been renamed but
the file hadn't been).
PiperOrigin-RevId: 228258120
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 1303888..579a6b6 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -1262,6 +1262,29 @@
}
}
+ // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This
+ // leads to a much more efficient form when 'c' is a power of two, and in
+ // general a more compact and readable form.
+
+ // Process '(expr floordiv c) * (-c)'.
+ AffineBinaryOpExpr rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
+ if (!rBinOpExpr)
+ return nullptr;
+
+ auto lrhs = rBinOpExpr.getLHS();
+ auto rrhs = rBinOpExpr.getRHS();
+
+ // Process lrhs, which is 'expr floordiv c'.
+ AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
+ if (!lrBinOpExpr)
+ return nullptr;
+
+ auto llrhs = lrBinOpExpr.getLHS();
+ auto rlrhs = lrBinOpExpr.getRHS();
+
+ if (lhs == llrhs && rlrhs == -rrhs) {
+ return lhs % rlrhs;
+ }
return nullptr;
}
diff --git a/lib/Transforms/SimplifyAffineExpr.cpp b/lib/Transforms/SimplifyAffineStructures.cpp
similarity index 95%
rename from lib/Transforms/SimplifyAffineExpr.cpp
rename to lib/Transforms/SimplifyAffineStructures.cpp
index 086e889..bd39e47 100644
--- a/lib/Transforms/SimplifyAffineExpr.cpp
+++ b/lib/Transforms/SimplifyAffineStructures.cpp
@@ -1,4 +1,4 @@
-//===- SimplifyAffineExpr.cpp - MLIR Affine Structures Class-----*- C++ -*-===//
+//===- SimplifyAffineStructures.cpp - ---------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -15,7 +15,7 @@
// limitations under the License.
// =============================================================================
//
-// This file implements a pass to simplify affine expressions.
+// This file implements a pass to simplify affine structures.
//
//===----------------------------------------------------------------------===//
diff --git a/test/Transforms/simplify.mlir b/test/Transforms/simplify-affine-structures.mlir
similarity index 88%
rename from test/Transforms/simplify.mlir
rename to test/Transforms/simplify-affine-structures.mlir
index 7c18474..9a34f38 100644
--- a/test/Transforms/simplify.mlir
+++ b/test/Transforms/simplify-affine-structures.mlir
@@ -12,20 +12,17 @@
#map4 = (d0, d1) -> (d0 floordiv 2, (3*d0 + 2*d1 + d0) floordiv 2, (50*d1 + 100) floordiv 100)
// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (0, d0 * 5 + 3)
#map5 = (d0, d1) -> ((4*d0 + 8*d1) ceildiv 2 mod 2, (2 + d0 + (8*d0 + 2) floordiv 2))
-// The flattening based simplification is currently regressive on modulo
-// expression simplification in the simple case (d0 mod 8 would be turn into d0
-// - 8 * (d0 floordiv 8); however, in other cases like d1 - d1 mod 8, it
-// would be simplified to an arithmetically simpler and more intuitive 8 * (d1
-// floordiv 8). In general, we have a choice of using either mod or floordiv
-// to express the same expression in mathematically equivalent ways, and making that
-// choice to minimize the number of terms or to simplify arithmetic is a TODO.
-// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0 - (d0 floordiv 8) * 8, (d1 floordiv 8) * 8)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0 mod 8, (d1 floordiv 8) * 8)
#map6 = (d0, d1) -> (d0 mod 8, d1 - d1 mod 8)
// Test map with nested floordiv/mod. Simply should scale by GCD.
// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> ((d0 * 72 + d1) floordiv 2304, (d0 * 72 + d1 - ((d0 * 9216 + d1 * 128) floordiv 294912) * 2304) floordiv 1152)
#map7 = (d0, d1) -> ((d0 * 9216 + d1 * 128) floordiv 294912, ((d0 * 9216 + d1 * 128) mod 294912) floordiv 147456)
+// floordiv/mul/sub to mod conversion
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0 mod 32, d0 - (d0 floordiv 8) * 4, (d1 mod 16) floordiv 256, d0 mod 7)
+#map8 = (d0, d1) -> (d0 - (32 * (d0 floordiv 32)), d0 - (4 * (d0 floordiv 8)), (d1 - (16 * (d1 floordiv 16))) floordiv 256, d0 - 7 * (d0 floordiv 7))
+
// CHECK-DAG: [[SET_EMPTY_2D:#set[0-9]+]] = (d0, d1) : (1 == 0)
// CHECK-DAG: #set1 = (d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, d0 * -1 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)
// CHECK-DAG: #set2 = (d0, d1)[s0, s1] : (1 == 0)
@@ -59,14 +56,15 @@
func @test() {
for %n0 = 0 to 127 {
for %n1 = 0 to 7 {
- %x = affine_apply #map0(%n0, %n1)
- %y = affine_apply #map1(%n0, %n1)
- %z = affine_apply #map2(%n0, %n1)
- %w = affine_apply #map3(%n0, %n1)
- %u = affine_apply #map4(%n0, %n1)
- %v = affine_apply #map5(%n0, %n1)
- %t = affine_apply #map6(%n0, %n1)
- %s = affine_apply #map7(%n0, %n1)
+ %a = affine_apply #map0(%n0, %n1)
+ %b = affine_apply #map1(%n0, %n1)
+ %c = affine_apply #map2(%n0, %n1)
+ %d = affine_apply #map3(%n0, %n1)
+ %e = affine_apply #map4(%n0, %n1)
+ %f = affine_apply #map5(%n0, %n1)
+ %g = affine_apply #map6(%n0, %n1)
+ %h = affine_apply #map7(%n0, %n1)
+ %i = affine_apply #map8(%n0, %n1)
}
}
return