Add conversion for splat of vectors of 2+D

This CL adds a missing lowering for splat of multi-dimensional vectors.
Additional support is also added to the runtime utils library to allow printing memrefs with such vectors.

PiperOrigin-RevId: 274794723
diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 0c162cb..4b7dec7 100644
--- a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -368,6 +368,85 @@
   }
 };
 
+//////////////// Support for Lowering operations on n-D vectors ////////////////
+namespace {
+// Helper struct to "unroll" operations on n-D vectors in terms of operations on
+// 1-D LLVM vectors.
+struct NDVectorTypeInfo {
+  // LLVM array struct which encodes n-D vectors.
+  LLVM::LLVMType llvmArrayTy;
+  // LLVM vector type which encodes the inner 1-D vector type.
+  LLVM::LLVMType llvmVectorTy;
+  // Multiplicity of llvmArrayTy to llvmVectorTy.
+  SmallVector<int64_t, 4> arraySizes;
+};
+} // namespace
+
+// For >1-D vector types, extracts the necessary information to iterate over all
+// 1-D subvectors in the underlying llrepresentation of the n-D vecotr
+// Iterates on the llvm array type until we hit a non-array type (which is
+// asserted to be an llvm vector type).
+static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
+                                                LLVMTypeConverter &converter) {
+  assert(vectorType.getRank() > 1 && "extpected >1D vector type");
+  NDVectorTypeInfo info;
+  info.llvmArrayTy =
+      converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
+  if (!info.llvmArrayTy)
+    return info;
+  info.arraySizes.reserve(vectorType.getRank() - 1);
+  auto llvmTy = info.llvmArrayTy;
+  while (llvmTy.isArrayTy()) {
+    info.arraySizes.push_back(llvmTy.getArrayNumElements());
+    llvmTy = llvmTy.getArrayElementType();
+  }
+  if (!llvmTy.isVectorTy())
+    return info;
+  info.llvmVectorTy = llvmTy;
+  return info;
+}
+
+// Express `linearIndex` in terms of coordinates of `basis`.
+// Returns the empty vector when linearIndex is out of the range [0, P] where
+// P is the product of all the basis coordinates.
+//
+// Prerequisites:
+//   Basis is an array of nonnegative integers (signed type inherited from
+//   vector shape type).
+static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
+                                              unsigned linearIndex) {
+  SmallVector<int64_t, 4> res;
+  res.reserve(basis.size());
+  for (unsigned basisElement : llvm::reverse(basis)) {
+    res.push_back(linearIndex % basisElement);
+    linearIndex = linearIndex / basisElement;
+  }
+  if (linearIndex > 0)
+    return {};
+  std::reverse(res.begin(), res.end());
+  return res;
+}
+
+// Iterate of linear index, convert to coords space and insert splatted 1-D
+// vector in each position.
+template <typename Lambda>
+void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
+                     Lambda fun) {
+  unsigned ub = 1;
+  for (auto s : info.arraySizes)
+    ub *= s;
+  for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
+    auto coords = getCoordinates(info.arraySizes, linearIndex);
+    // Linear index is out of bounds, we are done.
+    if (coords.empty())
+      break;
+    assert(coords.size() == info.arraySizes.size());
+    auto position = builder.getIndexArrayAttr(coords);
+    fun(position);
+  }
+}
+////////////// End Support for Lowering operations on n-D vectors //////////////
+
 // Basic lowering implementation for one-to-one rewriting from Standard Ops to
 // LLVM Dialect Ops.
 template <typename SourceOp, typename TargetOp>
@@ -415,27 +494,6 @@
   }
 };
 
-// Express `linearIndex` in terms of coordinates of `basis`.
-// Returns the empty vector when linearIndex is out of the range [0, P] where
-// P is the product of all the basis coordinates.
-//
-// Prerequisites:
-//   Basis is an array of nonnegative integers (signed type inherited from
-//   vector shape type).
-static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
-                                              unsigned linearIndex) {
-  SmallVector<int64_t, 4> res;
-  res.reserve(basis.size());
-  for (unsigned basisElement : llvm::reverse(basis)) {
-    res.push_back(linearIndex % basisElement);
-    linearIndex = linearIndex / basisElement;
-  }
-  if (linearIndex > 0)
-    return {};
-  std::reverse(res.begin(), res.end());
-  return res;
-}
-
 template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
   static_assert(
       std::is_base_of<
@@ -490,33 +548,16 @@
       return this->matchSuccess();
     }
 
-    // Unroll iterated array type until we hit a non-array type.
-    auto llvmTy = llvmArrayTy;
-    SmallVector<int64_t, 4> arraySizes;
-    while (llvmTy.isArrayTy()) {
-      arraySizes.push_back(llvmTy.getArrayNumElements());
-      llvmTy = llvmTy.getArrayElementType();
-    }
-    assert(llvmTy.isVectorTy() && "unexpected n-ary op over non-vector type");
-    auto llvmVectorTy = llvmTy;
+    auto vectorType = op->getResult(0)->getType().dyn_cast<VectorType>();
+    if (!vectorType)
+      return this->matchFailure();
+    auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, this->lowering);
+    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+    if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
+      return this->matchFailure();
 
-    // Iteratively extract a position coordinates with basis `arraySize` from a
-    // `linearIndex` that is incremented at each step. This terminates when
-    // `linearIndex` exceeds the range specified by `arraySize`.
-    // This has the effect of fully unrolling the dimensions of the n-D array
-    // type, getting to the underlying vector element.
     Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
-    unsigned ub = 1;
-    for (auto s : arraySizes)
-      ub *= s;
-    for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
-      auto coords = getCoordinates(arraySizes, linearIndex);
-      // Linear index is out of bounds, we are done.
-      if (coords.empty())
-        break;
-
-      auto position = rewriter.getIndexArrayAttr(coords);
-
+    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
       // For this unrolled `position` corresponding to the `linearIndex`^th
       // element, extract operand vectors
       SmallVector<Value *, OpCount> extractedOperands;
@@ -528,7 +569,7 @@
           loc, llvmVectorTy, extractedOperands, op->getAttrs());
       desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
                                                   newVal, position);
-    }
+    });
     rewriter.replaceOp(op, desc);
     return this->matchSuccess();
   }
@@ -1263,6 +1304,58 @@
   }
 };
 
+// The Splat operation is lowered to an insertelement + a shufflevector
+// operation. Splat to only 2+-d vector result types are lowered by the
+// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
+struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
+  using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto splatOp = cast<SplatOp>(op);
+    OperandAdaptor<SplatOp> adaptor(operands);
+    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
+    if (!resultType || resultType.getRank() == 1)
+      return matchFailure();
+
+    // First insert it into an undef vector so we can shuffle it.
+    auto loc = op->getLoc();
+    auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, lowering);
+    auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
+    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+    if (!llvmArrayTy || !llvmVectorTy)
+      return matchFailure();
+
+    // Construct returned value.
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+
+    // Construct a 1-D vector with the splatted value that we insert in all the
+    // places within the returned descriptor.
+    Value *vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
+    auto zero = rewriter.create<LLVM::ConstantOp>(
+        loc, lowering.convertType(rewriter.getIntegerType(32)),
+        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+    Value *v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
+                                                      adaptor.input(), zero);
+
+    // Shuffle the value across the desired number of elements.
+    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
+    SmallVector<int32_t, 4> zeroValues(width, 0);
+    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
+
+    // Iterate of linear index, convert to coords space and insert splatted 1-D
+    // vector in each position.
+    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
+                                                  position);
+    });
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 static void ensureDistinctSuccessors(Block &bb) {
@@ -1352,6 +1445,7 @@
       SelectOpLowering,
       SignExtendIOpLowering,
       SplatOpLowering,
+      SplatNdOpLowering,
       StoreOpLowering,
       SubFOpLowering,
       SubIOpLowering,