Avoid overflow when lowering linalg.slice

linalg.subview used to lower to a slice with a bounded range resulting in correct bounded accesses. However linalg.slice could still index out of bounds. This CL moves the bounding to linalg.slice.

LLVM select and cmp ops gain a more idiomatic builder.

PiperOrigin-RevId: 264897125
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 10533cc..fcba2b7 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -164,6 +164,13 @@
   let llvmBuilder = [{
     $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
+  let builders = [OpBuilder<
+    "Builder *b, OperationState *result, ICmpPredicate predicate, Value *lhs, "
+    "Value *rhs", [{
+      LLVMDialect *dialect = &lhs->getType().cast<LLVMType>().getDialect();
+      build(b, result, LLVMType::getInt1Ty(dialect),
+            b->getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
+    }]>];
   let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
   let printer = [{ printICmpOp(p, *this); }];
 }
@@ -386,6 +393,11 @@
                  LLVM_Type:$falseValue)>,
       LLVM_Builder<
           "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
+  let builders = [OpBuilder<
+    "Builder *b, OperationState *result, Value *condition, Value *lhs, "
+    "Value *rhs", [{
+      build(b, result, lhs->getType(), condition, lhs, rhs);
+    }]>];
   let parser = [{ return parseSelectOp(parser, result); }];
   let printer = [{ printSelectOp(p, *this); }];
 }
@@ -550,5 +562,4 @@
   }];
 }
 
-
 #endif // LLVMIR_OPS
diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
index 1e8f076..b6e0430 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -415,7 +415,8 @@
 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
 ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
-///      and stride corresponding to the
+///      and stride corresponding to the region of memory within the bounds of
+///      the parent view.
 ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
 /// The linalg.slice op is replaced by the alloca'ed pointer.
 class SliceOpConversion : public LLVMOpLowering {
@@ -446,6 +447,8 @@
     auto ib = rewriter.getInsertionBlock();
     rewriter.setInsertionPointToStart(
         &op->getParentOfType<FuncOp>().getBlocks().front());
+    Value *zero =
+        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
     Value *one =
         constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
     // Alloca with proper alignment.
@@ -470,12 +473,10 @@
     Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
     for (int i = 0, e = viewType.getRank(); i < e; ++i) {
       Value *indexing = adaptor.indexings()[i];
-      Value *min =
-          sliceOp.indexing(i)->getType().isa<RangeType>()
-              ? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0)))
-              : indexing;
-      Value *product = mul(min, strides[i]);
-      baseOffset = add(baseOffset, product);
+      Value *min = indexing;
+      if (sliceOp.indexing(i)->getType().isa<RangeType>())
+        min = extractvalue(int64Ty, indexing, pos(0));
+      baseOffset = add(baseOffset, mul(min, strides[i]));
     }
     desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
 
@@ -485,13 +486,21 @@
     for (auto en : llvm::enumerate(sliceOp.indexings())) {
       Value *indexing = en.value();
       if (indexing->getType().isa<RangeType>()) {
-        int i = en.index();
-        Value *rangeDescriptor = adaptor.indexings()[i];
+        int rank = en.index();
+        Value *rangeDescriptor = adaptor.indexings()[rank];
         Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
         Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
         Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+        Value *baseSize =
+            extractvalue(int64Ty, baseDesc, pos({kSizePosInView, rank}));
+        // Bound upper by base view upper bound.
+        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
+                          baseSize);
         Value *size = sub(max, min);
-        Value *stride = mul(strides[i], step);
+        // Bound lower by zero.
+        size =
+            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
+        Value *stride = mul(strides[rank], step);
         desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
         desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
         ++numNewDims;
@@ -703,16 +712,8 @@
     ScopedContext scope(b, op.getLoc());
     auto *view = op.getView();
     SmallVector<Value *, 8> ranges;
-    for (auto en : llvm::enumerate(op.getRanges())) {
-      using edsc::op::operator<;
-      using linalg::intrinsics::dim;
-      unsigned rank = en.index();
-      auto sliceRange = en.value();
-      auto size = dim(view, rank);
-      ValueHandle ub(sliceRange.max);
-      auto max = edsc::intrinsics::select(size < ub, size, ub);
-      ranges.push_back(range(sliceRange.min, max, sliceRange.step));
-    }
+    for (auto sliceRange : op.getRanges())
+      ranges.push_back(range(sliceRange.min, sliceRange.max, sliceRange.step));
     op.replaceAllUsesWith(slice(view, ranges));
     op.erase();
   });