Refactor linalg.view lowering to LLVM - NFC

This CL fuses the emission of the size and stride information into a single loop and makes it more explicit which indexings are stepped over when computing the insertion positions in the view descriptor. The refactoring was motivated by an index calculation bug in the stride computation.
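
For illustration only, the sketch below shows the fused size/stride computation on plain data rather than on LLVM dialect values. The types Range, Indexing, ViewDesc and the function projectView are hypothetical stand-ins, not the real MLIR ops or builders; the point is merely that a single pass over the indexings emits both the projected sizes and the projected strides while skipping point (non-range) indexings.

    #include <cassert>
    #include <cstdint>
    #include <optional>
    #include <vector>

    // Hypothetical stand-ins for the range descriptor and the projected view
    // descriptor; the real lowering manipulates LLVM dialect values instead.
    struct Range { int64_t min, max, step; };
    using Indexing = std::optional<Range>;  // nullopt models a point (non-range) indexing

    struct ViewDesc {
      std::vector<int64_t> sizes, strides;
    };

    // Single pass over the indexings: each range indexing contributes one size
    // (max - min) and one stride (base stride * step) to the projected view;
    // point indexings are skipped because they are projected away.
    ViewDesc projectView(const std::vector<int64_t> &baseStrides,
                         const std::vector<Indexing> &indexings) {
      assert(baseStrides.size() == indexings.size());
      ViewDesc desc;
      for (size_t i = 0, e = indexings.size(); i < e; ++i) {
        if (!indexings[i])
          continue;
        const Range &r = *indexings[i];
        desc.sizes.push_back(r.max - r.min);
        desc.strides.push_back(baseStrides[i] * r.step);
      }
      return desc;
    }

For example, slicing a rank-2 base view with strides {S0, S1} by (range, point) yields a rank-1 view whose size is max0 - min0 and whose stride is S0 * step0, which is what the fused loop in the CL emits.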

PiperOrigin-RevId: 263341610
diff --git a/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index ac71c32..908191c 100644
--- a/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -388,6 +388,7 @@
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    SliceOpOperandAdaptor adaptor(operands);
     auto sliceOp = cast<SliceOp>(op);
     auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
     auto viewType = sliceOp.getBaseViewType();
@@ -408,56 +409,45 @@
     // Declare the view descriptor and insert data ptr.
     Value *desc = undef(viewDescriptorTy);
     desc = insertvalue(viewDescriptorTy, desc,
-                       getViewPtr(viewType, operands[0]), pos(0));
+                       getViewPtr(viewType, adaptor.view()), pos(0));
 
     // TODO(ntv): extract sizes and emit asserts.
     SmallVector<Value *, 4> strides(viewType.getRank());
-    for (int dim = 0, e = viewType.getRank(); dim < e; ++dim) {
-      strides[dim] = extractvalue(int64Ty, operands[0], pos({3, dim}));
+    for (int i = 0, e = viewType.getRank(); i < e; ++i) {
+      strides[i] = extractvalue(int64Ty, adaptor.view(), pos({3, i}));
     }
 
     // Compute and insert base offset.
-    Value *baseOffset = extractvalue(int64Ty, operands[0], pos(1));
-    for (int j = 0, e = viewType.getRank(); j < e; ++j) {
-      Value *indexing = operands[1 + j];
+    Value *baseOffset = extractvalue(int64Ty, adaptor.view(), pos(1));
+    for (int i = 0, e = viewType.getRank(); i < e; ++i) {
+      Value *indexing = adaptor.indexings()[i];
       Value *min =
-          sliceOp.indexing(j)->getType().isa<RangeType>()
+          sliceOp.indexing(i)->getType().isa<RangeType>()
               ? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0)))
               : indexing;
-      Value *product = mul(min, strides[j]);
+      Value *product = mul(min, strides[i]);
       baseOffset = add(baseOffset, product);
     }
     desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
 
-    // Compute and insert view sizes (max - min along the range).  Skip the
-    // non-range operands as they will be projected away from the view.
-    int i = 0, j = 0;
-    for (Value *index : sliceOp.indexings()) {
-      if (!index->getType().isa<RangeType>()) {
-        ++j;
-        continue;
+    // Compute and insert view sizes (max - min along the range) and strides.
+    // Skip the non-range operands as they will be projected away from the view.
+    int numNewDims = 0;
+    for (auto en : llvm::enumerate(sliceOp.indexings())) {
+      Value *indexing = en.value();
+      if (indexing->getType().isa<RangeType>()) {
+        int i = en.index();
+        Value *rangeDescriptor = adaptor.indexings()[i];
+        Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+        Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+        Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+        Value *size = sub(max, min);
+        Value *stride = mul(strides[i], step);
+        desc = insertvalue(viewDescriptorTy, desc, size, pos({2, numNewDims}));
+        desc =
+            insertvalue(viewDescriptorTy, desc, stride, pos({3, numNewDims}));
+        ++numNewDims;
       }
-
-      Value *rangeDescriptor = operands[1 + j];
-      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
-      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
-      Value *size = sub(max, min);
-
-      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
-      ++i;
-      ++j;
-    }
-
-    // Compute and insert view strides.  Step over the strides that correspond
-    // to non-range operands as they are projected away from the view.
-    i = 0;
-    for (int j = 0, e = strides.size(); j < e; ++j) {
-      if (!sliceOp.indexing(j)->getType().isa<RangeType>())
-        continue;
-      Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
-      Value *stride = mul(strides[j], step);
-      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
-      ++i;
     }
 
     rewriter.replaceOp(op, desc);
diff --git a/test/Linalg/llvm.mlir b/test/Linalg/llvm.mlir
index a56b631..9fa05af 100644
--- a/test/Linalg/llvm.mlir
+++ b/test/Linalg/llvm.mlir
@@ -76,10 +76,10 @@
 //  CHECK-NEXT:   %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 //  CHECK-NEXT:   %{{.*}} = llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
 //  CHECK-NEXT:   %{{.*}} = llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
-//  CHECK-NEXT:   %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
-//  CHECK-NEXT:   %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 //  CHECK-NEXT:   %{{.*}} = llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
+//  CHECK-NEXT:   %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
 //  CHECK-NEXT:   %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+//  CHECK-NEXT:   %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 //  CHECK-NEXT:   %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 
 func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {