Add conversion for splat of vectors of 2+D
This CL adds a missing lowering for splat of multi-dimensional vectors.
Additional support is also added to the runtime utils library to allow printing memrefs with such vectors.
PiperOrigin-RevId: 274794723
diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 0c162cb..4b7dec7 100644
--- a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -368,6 +368,85 @@
}
};
+//////////////// Support for Lowering operations on n-D vectors ////////////////
+namespace {
+// Helper struct to "unroll" operations on n-D vectors in terms of operations on
+// 1-D LLVM vectors.
+struct NDVectorTypeInfo {
+ // LLVM array struct which encodes n-D vectors.
+ LLVM::LLVMType llvmArrayTy;
+ // LLVM vector type which encodes the inner 1-D vector type.
+ LLVM::LLVMType llvmVectorTy;
+ // Sizes of the nested array dimensions of llvmArrayTy, outermost first.
+ SmallVector<int64_t, 4> arraySizes;
+};
+} // namespace
+
+// For >1-D vector types, extracts the necessary information to iterate over all
+// 1-D subvectors in the underlying LLVM representation of the n-D vector.
+// Iterates on the LLVM array type until we hit a non-array type; if that type
+// is not an LLVM vector type, `llvmVectorTy` is left unset for callers to check.
+static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
+ LLVMTypeConverter &converter) {
+ assert(vectorType.getRank() > 1 && "extpected >1D vector type");
+ NDVectorTypeInfo info;
+ info.llvmArrayTy =
+ converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
+ if (!info.llvmArrayTy)
+ return info;
+ info.arraySizes.reserve(vectorType.getRank() - 1);
+ auto llvmTy = info.llvmArrayTy;
+ while (llvmTy.isArrayTy()) {
+ info.arraySizes.push_back(llvmTy.getArrayNumElements());
+ llvmTy = llvmTy.getArrayElementType();
+ }
+ if (!llvmTy.isVectorTy())
+ return info;
+ info.llvmVectorTy = llvmTy;
+ return info;
+}
+
+// Express `linearIndex` in terms of coordinates of `basis`.
+// Returns the empty vector when linearIndex is out of the range [0, P) where
+// P is the product of all the basis elements.
+//
+// Prerequisites:
+// Basis is an array of nonnegative integers (signed type inherited from
+// vector shape type).
+static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
+ unsigned linearIndex) {
+ SmallVector<int64_t, 4> res;
+ res.reserve(basis.size());
+ for (unsigned basisElement : llvm::reverse(basis)) {
+ res.push_back(linearIndex % basisElement);
+ linearIndex = linearIndex / basisElement;
+ }
+ if (linearIndex > 0)
+ return {};
+ std::reverse(res.begin(), res.end());
+ return res;
+}
+
+// Iterate over the linear index, convert it to coordinate space and invoke
+// `fun` with the resulting position attribute for each 1-D subvector.
+template <typename Lambda>
+void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
+ Lambda fun) {
+ unsigned ub = 1;
+ for (auto s : info.arraySizes)
+ ub *= s;
+ for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
+ auto coords = getCoordinates(info.arraySizes, linearIndex);
+ // Linear index is out of bounds, we are done.
+ if (coords.empty())
+ break;
+ assert(coords.size() == info.arraySizes.size());
+ auto position = builder.getIndexArrayAttr(coords);
+ fun(position);
+ }
+}
+////////////// End Support for Lowering operations on n-D vectors //////////////
+
// Basic lowering implementation for one-to-one rewriting from Standard Ops to
// LLVM Dialect Ops.
template <typename SourceOp, typename TargetOp>
@@ -415,27 +494,6 @@
}
};
-// Express `linearIndex` in terms of coordinates of `basis`.
-// Returns the empty vector when linearIndex is out of the range [0, P] where
-// P is the product of all the basis coordinates.
-//
-// Prerequisites:
-// Basis is an array of nonnegative integers (signed type inherited from
-// vector shape type).
-static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
- unsigned linearIndex) {
- SmallVector<int64_t, 4> res;
- res.reserve(basis.size());
- for (unsigned basisElement : llvm::reverse(basis)) {
- res.push_back(linearIndex % basisElement);
- linearIndex = linearIndex / basisElement;
- }
- if (linearIndex > 0)
- return {};
- std::reverse(res.begin(), res.end());
- return res;
-}
-
template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
static_assert(
std::is_base_of<
@@ -490,33 +548,16 @@
return this->matchSuccess();
}
- // Unroll iterated array type until we hit a non-array type.
- auto llvmTy = llvmArrayTy;
- SmallVector<int64_t, 4> arraySizes;
- while (llvmTy.isArrayTy()) {
- arraySizes.push_back(llvmTy.getArrayNumElements());
- llvmTy = llvmTy.getArrayElementType();
- }
- assert(llvmTy.isVectorTy() && "unexpected n-ary op over non-vector type");
- auto llvmVectorTy = llvmTy;
+ auto vectorType = op->getResult(0)->getType().dyn_cast<VectorType>();
+ if (!vectorType)
+ return this->matchFailure();
+ auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, this->lowering);
+ auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+ if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
+ return this->matchFailure();
- // Iteratively extract a position coordinates with basis `arraySize` from a
- // `linearIndex` that is incremented at each step. This terminates when
- // `linearIndex` exceeds the range specified by `arraySize`.
- // This has the effect of fully unrolling the dimensions of the n-D array
- // type, getting to the underlying vector element.
Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
- unsigned ub = 1;
- for (auto s : arraySizes)
- ub *= s;
- for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
- auto coords = getCoordinates(arraySizes, linearIndex);
- // Linear index is out of bounds, we are done.
- if (coords.empty())
- break;
-
- auto position = rewriter.getIndexArrayAttr(coords);
-
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
SmallVector<Value *, OpCount> extractedOperands;
@@ -528,7 +569,7 @@
loc, llvmVectorTy, extractedOperands, op->getAttrs());
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
newVal, position);
- }
+ });
rewriter.replaceOp(op, desc);
return this->matchSuccess();
}
@@ -1263,6 +1304,58 @@
}
};
+// The Splat operation is lowered to an insertelement + a shufflevector
+// operation. Only splats to 2+-d vector result types are lowered by
+// SplatNdOpLowering; the 1-d case is handled by SplatOpLowering.
+struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
+ using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto splatOp = cast<SplatOp>(op);
+ OperandAdaptor<SplatOp> adaptor(operands);
+ VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
+ if (!resultType || resultType.getRank() == 1)
+ return matchFailure();
+
+ // Gather the LLVM array and 1-D vector types backing the n-D result type.
+ auto loc = op->getLoc();
+ auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, lowering);
+ auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
+ auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+ if (!llvmArrayTy || !llvmVectorTy)
+ return matchFailure();
+
+ // Construct returned value.
+ Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+
+ // Construct a 1-D vector with the splatted value that we insert in all the
+ // places within the returned descriptor.
+ Value *vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
+ auto zero = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(rewriter.getIntegerType(32)),
+ rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+ Value *v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
+ adaptor.input(), zero);
+
+ // Shuffle the value across the desired number of elements.
+ int64_t width = resultType.getDimSize(resultType.getRank() - 1);
+ SmallVector<int32_t, 4> zeroValues(width, 0);
+ ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+ v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
+
+ // Iterate over the linear index, convert to coords space and insert the
+ // splatted 1-D vector in each position.
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
+ position);
+ });
+ rewriter.replaceOp(op, desc);
+ return matchSuccess();
+ }
+};
+
} // namespace
static void ensureDistinctSuccessors(Block &bb) {
@@ -1352,6 +1445,7 @@
SelectOpLowering,
SignExtendIOpLowering,
SplatOpLowering,
+ SplatNdOpLowering,
StoreOpLowering,
SubFOpLowering,
SubIOpLowering,