Add canonicalization pattern for linalg.dim

This CL introduces canonicalization patterns for linalg.dim.
This allows the dimensions of chains of view, slice and subview operations to be simplified.
Down the line, when mixed with cse, this also allows better composition of linalg tiling and fusion by tracking operations that give the same result (not in this CL).

PiperOrigin-RevId: 262365865
diff --git a/include/mlir/Linalg/IR/LinalgOps.td b/include/mlir/Linalg/IR/LinalgOps.td
index 55a8108..bbbbfad 100644
--- a/include/mlir/Linalg/IR/LinalgOps.td
+++ b/include/mlir/Linalg/IR/LinalgOps.td
@@ -146,6 +146,8 @@
     }
     ViewType getViewType() { return getOperand()->getType().cast<ViewType>(); }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
diff --git a/lib/Linalg/IR/LinalgOps.cpp b/lib/Linalg/IR/LinalgOps.cpp
index 4feb22b..60820ae 100644
--- a/lib/Linalg/IR/LinalgOps.cpp
+++ b/lib/Linalg/IR/LinalgOps.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Linalg/IR/LinalgTypes.h"
 #include "mlir/Linalg/Utils/Utils.h"
@@ -42,6 +43,81 @@
 using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
 
+namespace {
+/// Rewrite linalg.dim of a view, slice or subview as `max - min` of its range.
+struct SimplifyDimOp : public OpRewritePattern<linalg::DimOp> {
+  using OpRewritePattern<linalg::DimOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(linalg::DimOp dimOp,
+                                     PatternRewriter &rewriter) const override;
+};
+} // end namespace
+
+/// Rewrites a linalg.dim of a view, slice or subview result into `max - min`
+/// of the range indexing that dimension. Fails unless the range step is 1.
+PatternMatchResult
+SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
+                               PatternRewriter &rewriter) const {
+  auto *viewProducingOp = dimOp.view()->getDefiningOp();
+  auto subView = dyn_cast_or_null<SubViewOp>(viewProducingOp);
+  auto slice = dyn_cast_or_null<SliceOp>(viewProducingOp);
+  auto view = dyn_cast_or_null<ViewOp>(viewProducingOp);
+  if (!subView && !slice && !view)
+    return matchFailure();
+
+  unsigned dim = dimOp.getIndex();
+  Value *min, *max, *step;
+  // Records min, max, step of a range produced by linalg.range. Returns false
+  // for block arguments and other producers, which cannot be traversed.
+  auto recordRange = [&](Value *rangeValue) {
+    auto range = dyn_cast_or_null<RangeOp>(rangeValue->getDefiningOp());
+    if (!range)
+      return false;
+    std::tie(min, max, step) =
+        std::make_tuple(range.min(), range.max(), range.step());
+    return true;
+  };
+  if (view) {
+    if (!recordRange(view.getIndexing(dim)))
+      return matchFailure();
+  } else if (subView) {
+    auto range = subView.getRange(dim);
+    std::tie(min, max, step) =
+        std::make_tuple(range.min, range.max, range.step);
+  } else {
+    // Taking the dim of a slice must take a range (since other dims have been
+    // rank-reduced).
+    if (!recordRange(slice.getRanges()[dim]))
+      return matchFailure();
+  }
+
+  // Only support constant steps of 1 atm.
+  auto constant = dyn_cast_or_null<ConstantIndexOp>(step->getDefiningOp());
+  if (!constant || constant.getValue() != 1)
+    return matchFailure();
+
+  // Circumvent affine constraints:
+  //   emit an affine_apply when possible, otherwise emit a `subi`.
+  bool validAffineMin = isValidDim(min) || isValidSymbol(min) ||
+                        isa_and_nonnull<ConstantIndexOp>(min->getDefiningOp());
+  bool validAffineMax = isValidDim(max) || isValidSymbol(max) ||
+                        isa_and_nonnull<ConstantIndexOp>(max->getDefiningOp());
+
+  OpBuilder b(dimOp);
+  ScopedContext scope(b, dimOp.getLoc());
+  // Emit `subi`.
+  if (!validAffineMin || !validAffineMax) {
+    rewriter.replaceOp(dimOp, {subi(max, min)}, {dimOp.view()});
+    return matchSuccess();
+  }
+
+  // Emit affine_apply.
+  using edsc::op::operator-;
+  rewriter.replaceOp(dimOp, {ValueHandle(max) - ValueHandle(min)},
+                     {dimOp.view()});
+  return matchSuccess();
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // LoadOp.
 ////////////////////////////////////////////////////////////////////////////////
@@ -501,6 +577,14 @@
                                        result->types));
 }
 
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+void mlir::linalg::DimOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<SimplifyDimOp>(context); // Only linalg.dim pattern for now.
+}
+
 static void print(OpAsmPrinter *p, linalg::DimOp op) {
   *p << op.getOperationName() << " " << *op.getOperand() << ", "
      << op.getIndex();
diff --git a/test/Linalg/canonicalize.mlir b/test/Linalg/canonicalize.mlir
new file mode 100644
index 0000000..65e1d54
--- /dev/null
+++ b/test/Linalg/canonicalize.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-DAG: #[[SUB:.*]] = ()[s0, s1] -> (s0 - s1)
+
+func @fold_constants(%arg0: !linalg.buffer<?xf32>) -> (index, index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %c5 = constant 5 : index
+  %R02 = linalg.range %c0:%c2:%c1 : !linalg.range
+  %R03 = linalg.range %c0:%c3:%c1 : !linalg.range
+  %R04 = linalg.range %c0:%c4:%c1 : !linalg.range
+  %R12 = linalg.range %c1:%c2:%c1 : !linalg.range
+  %R13 = linalg.range %c1:%c3:%c1 : !linalg.range
+  %R14 = linalg.range %c1:%c4:%c1 : !linalg.range
+
+  %v = linalg.view %arg0[%R02, %R14] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+  // Expected 2: max - min of %R02, i.e. 2 - 0.
+  %v0 = linalg.dim %v, 0 : !linalg.view<?x?xf32>
+  // Expected 3: max - min of %R14, i.e. 4 - 1.
+  %v1 = linalg.dim %v, 1 : !linalg.view<?x?xf32>
+
+  %s = linalg.slice %v[%c1, %R12] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
+  // Expected 1: %R12 is the only range indexing of the slice, i.e. 2 - 1.
+  %s0 = linalg.dim %s, 0 : !linalg.view<?xf32>
+
+  %sv = linalg.subview %v[%v0, %v1, %c1, %c2, %c4, %c1] : !linalg.view<?x?xf32>
+  // Expected 1: %v1 - %v0, i.e. 3 - 2 once %v0 and %v1 are folded.
+  %sv0 = linalg.dim %sv, 0 : !linalg.view<?x?xf32>
+  // Expected 2: %c4 - %c2, i.e. 4 - 2.
+  %sv1 = linalg.dim %sv, 1 : !linalg.view<?x?xf32>
+
+  return %v0, %v1, %s0, %sv0, %sv1 : index, index, index, index, index
+}
+
+// CHECK-LABEL: fold_constants
+//   CHECK-DAG:   %[[c1:.*]] = constant 1 : index
+//   CHECK-DAG:   %[[c2:.*]] = constant 2 : index
+//   CHECK-DAG:   %[[c3:.*]] = constant 3 : index
+//       CHECK:   return %[[c2]], %[[c3]], %[[c1]], %[[c1]], %[[c2]]
+
+
+func @fold_indices(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %R = linalg.range %arg1:%arg3:%c1 : !linalg.range
+
+  %v = linalg.view %arg0[%R, %R] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+  // Expected %arg3 - %arg1 (max - min of %R).
+  %v0 = linalg.dim %v, 0 : !linalg.view<?x?xf32>
+  // Expected %arg3 - %arg1 (max - min of %R).
+  %v1 = linalg.dim %v, 1 : !linalg.view<?x?xf32>
+
+  %arg1_p_arg2 = addi %arg1, %arg2: index
+  %arg1_p_arg2_affine = affine.apply (i, j) -> (i + j) (%arg1, %arg2)
+  %sv = linalg.subview %v[%arg1, %arg1_p_arg2, %c1, %arg1, %arg1_p_arg2_affine, %c1] : !linalg.view<?x?xf32>
+  // Ideally %arg2, but the emitted affine.apply cannot fold through the addi.
+  %sv0 = linalg.dim %sv, 0 : !linalg.view<?x?xf32>
+  // Expected %arg2: the subtraction composes with the affine.apply above.
+  %sv1 = linalg.dim %sv, 1 : !linalg.view<?x?xf32>
+
+  return %v0, %v1, %sv0, %sv1 : index, index, index, index
+}
+
+// CHECK-LABEL: fold_indices
+//       CHECK: (%[[arg0:.*]]: !linalg.buffer<?xf32>, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index
+//       CHECK:   %[[r0:.*]] = affine.apply #[[SUB]]()[%[[arg3]], %[[arg1]]]
+//       CHECK:   %[[r1:.*]] = affine.apply #[[SUB]]()[%[[arg3]], %[[arg1]]]
+//       CHECK:   %[[add:.*]] = addi %[[arg1]], %[[arg2]] : index
+//       CHECK:   %[[aff:.*]] = affine.apply #[[SUB]]()[%[[add]], %[[arg1]]]
+//       CHECK:   return %[[r0]], %[[r1]], %[[aff]], %[[arg2]]
\ No newline at end of file