Lower linalg.indexed_generic to loops.
PiperOrigin-RevId: 281169885
Change-Id: I3951c74ff03faebc335fcf3a084411a81f58cde4
diff --git a/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 867b8ab..b137764 100644
--- a/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -108,12 +108,11 @@
}
template <typename GenericOpType>
-LogicalResult verifyBlockArgs(GenericOpType op, Block &block, unsigned nViews,
- unsigned nLoops, unsigned nInputViews);
+LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
-template <>
-LogicalResult verifyBlockArgs(GenericOp op, Block &block, unsigned nViews,
- unsigned nLoops, unsigned nInputViews) {
+template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
+ auto nViews = op.getNumInputsAndOutputs();
+ auto nInputViews = op.getNumInputs();
if (block.getNumArguments() != nViews)
return op.emitError(
"op expected number of block arguments to match number of views");
@@ -129,10 +128,10 @@
return success();
}
-template <>
-LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block,
- unsigned nViews, unsigned nLoops,
- unsigned nInputViews) {
+template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
+ auto nInputViews = op.getNumInputs();
+ auto nLoops = op.getNumLoops();
+ auto nViews = op.getNumInputsAndOutputs();
if (block.getNumArguments() != nViews + nLoops)
return op.emitError(
"op expected number of block arguments to match number of views + "
@@ -158,6 +157,76 @@
}
template <typename GenericOpType>
+LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
+
+template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
+ auto nViews = op.getNumInputsAndOutputs();
+ auto nInputViews = op.getNumInputs();
+ if (funType.getNumInputs() != nViews)
+ return op.emitError("op expected fun arguments to match number of views");
+ if (funType.getNumResults() != op.getNumOutputs())
+ return op.emitError(
+ "op expected fun results to match number of output views");
+
+ for (auto en : llvm::enumerate(op.indexing_maps())) {
+ auto idx = en.index();
+ auto view = (idx < nInputViews) ? op.getInputViewType(idx)
+ : op.getOutputViewType(idx - nInputViews);
+ if (funType.getInput(idx) != view.getElementType())
+ return op.emitError("op expected fun argument ")
+ << idx << " of the same type as elemental type "
+ << view.getElementType() << " of view " << idx;
+
+ if (idx >= nInputViews) {
+ auto resultIdx = idx - nInputViews;
+ if (funType.getResult(resultIdx) != view.getElementType())
+ return op.emitError("op expected fun result ")
+ << resultIdx << " of the same type as elemental type "
+ << view.getElementType() << " of view " << idx;
+ }
+ }
+ return success();
+}
+
+template <>
+LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
+ auto nLoops = op.getNumLoops();
+ auto nInputViews = op.getNumInputs();
+ auto nOutputs = op.getNumOutputs();
+ auto nViews = op.getNumInputsAndOutputs();
+ if (funType.getNumInputs() != nViews + nLoops)
+ return op.emitError(
+ "op expected fun arguments to match number of views + number of loops");
+ if (funType.getNumResults() != nOutputs)
+ return op.emitError(
+ "op expected fun results to match number of output views");
+ for (unsigned i = 0; i < nLoops; ++i) {
+ if (!funType.getInput(i).isIndex())
+ return op.emitError("op expected fun argument ")
+ << i << " to be of IndexType";
+ }
+ for (auto en : llvm::enumerate(op.indexing_maps())) {
+ auto idx = en.index();
+ auto funIdx = nLoops + idx;
+ auto view = (idx < nInputViews) ? op.getInputViewType(idx)
+ : op.getOutputViewType(idx - nInputViews);
+ if (funType.getInput(funIdx) != view.getElementType())
+ return op.emitError("op expected fun argument ")
+ << funIdx << " of the same type as elemental type "
+ << view.getElementType() << " of view " << idx;
+
+ if (idx >= nInputViews) {
+ auto resultIdx = idx - nInputViews;
+ if (funType.getResult(resultIdx) != view.getElementType())
+ return op.emitError("op expected fun result ")
+ << resultIdx << " of the same type as elemental type "
+ << view.getElementType() << " of view " << idx;
+ }
+ }
+ return success();
+}
+
+template <typename GenericOpType>
LogicalResult verifyGenericOp(GenericOpType op) {
auto nInputViews = op.getNumInputs();
auto nLoops = op.getNumLoops();
@@ -171,20 +240,14 @@
if (!region.empty()) {
if (region.getBlocks().size() != 1)
return op.emitError("op expected region with 1 block");
-
- auto &block = region.getBlocks().front();
- if (failed(verifyBlockArgs(op, block, nViews, nLoops, nInputViews))) {
+ if (failed(verifyBlockArgs(op, region.getBlocks().front())))
return failure();
- }
} else {
if (!funOp || !funOp.getType())
return op.emitError(
"op expected fun attribute to refer to a defined symbol");
- if (funType.getNumInputs() != nViews)
- return op.emitError("op expected fun arguments to match number of views");
- if (funType.getNumResults() != op.getNumOutputs())
- return op.emitError(
- "op expected fun results to match number of output views");
+ if (failed(verifyFuncArgs(op, funType)))
+ return failure();
}
SmallVector<AffineMap, 4> indexingMaps;
@@ -215,19 +278,6 @@
if (m.getNumResults() != view.getRank())
return op.emitError("op expected indexing_map #")
<< idx << " results to match view rank: " << view;
-
- if (funType) {
- if (funType.getInput(idx) != view.getElementType())
- return op.emitError("op expected fun argument ")
- << idx
- << " to match view element type: " << view.getElementType();
-
- if (idx >= nInputViews)
- if (funType.getResult(idx - nInputViews) != view.getElementType())
- return op.emitError("op expected fun result ")
- << idx << " to match output view element type: "
- << view.getElementType();
- }
}
auto concatMap = concatAffineMaps(indexingMaps);
@@ -718,6 +768,13 @@
res.push_back(genericOp.getIndexingMap(i));
}
return res;
+ } else if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
+ SmallVector<AffineMap, 4> res;
+ unsigned nViews = indexedGenericOp.getNumInputsAndOutputs();
+ res.reserve(nViews);
+ for (unsigned i = 0, e = nViews; i < e; ++i)
+ res.push_back(indexedGenericOp.getIndexingMap(i));
+ return res;
}
llvm_unreachable("Missing loopToOperandRangesMaps for op");
}
diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
index 6aace80..6e97a7a 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -474,10 +474,11 @@
MLIRContext *ctx) {
// TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
// attribute values such as kernel striding and dilation.
- patterns.insert<CopyTransposeConversion, LinalgOpConversion<CopyOp>,
- LinalgOpConversion<DotOp>, LinalgOpConversion<FillOp>,
- LinalgOpConversion<MatvecOp>, LinalgOpConversion<MatmulOp>,
- LinalgOpConversion<ConvOp>, LinalgOpConversion<GenericOp>>(
+ patterns.insert<CopyTransposeConversion, LinalgOpConversion<ConvOp>,
+ LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
+ LinalgOpConversion<FillOp>, LinalgOpConversion<GenericOp>,
+ LinalgOpConversion<IndexedGenericOp>,
+ LinalgOpConversion<MatmulOp>, LinalgOpConversion<MatvecOp>>(
ctx);
}
diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
index 058dc07..0bf4cea 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
@@ -244,14 +244,14 @@
SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
// 1.a. Emit std_load from input views.
- for (unsigned i = 0, e = nInputs; i < e; ++i) {
+ for (unsigned i = 0; i < nInputs; ++i) {
ValueHandleArray indexing(foldedAffineApplies(
b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
indexedValues[i] = std_load(genericOp.getInput(i), indexing);
}
// 1.b. Emit std_load from output views.
- for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+ for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(foldedAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
@@ -264,49 +264,138 @@
assert(callOp->getNumResults() == genericOp.getNumOutputs());
// 3. Emit std_store.
- for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+ for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(foldedAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
}
- } else {
- // TODO(ntv): When a region inliner exists, use it.
- // 2. Inline region, currently only works for a single basic block.
- BlockAndValueMapping map;
- auto &block = genericOp.region().front();
- for (auto it : llvm::zip(block.getArguments(), indexedValues))
+ return;
+ }
+ // TODO(ntv): When a region inliner exists, use it.
+ // 2. Inline region, currently only works for a single basic block.
+ BlockAndValueMapping map;
+ auto &block = genericOp.region().front();
+ for (auto it : llvm::zip(block.getArguments(), indexedValues))
+ map.map(std::get<0>(it), std::get<1>(it));
+ for (auto &op : block.without_terminator()) {
+ assert(op.getNumRegions() == 0);
+ auto *newOp = b.clone(op, map);
+ for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
map.map(std::get<0>(it), std::get<1>(it));
- for (auto &op : block) {
- // Skip terminator.
- if (&op == &block.back())
- continue;
- assert(op.getNumRegions() == 0);
- auto *newOp = b.clone(op, map);
- for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
- map.map(std::get<0>(it), std::get<1>(it));
- }
+ }
- // 3. Emit std_store.
- auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
- assert(yieldOp->getNumOperands() == nOutputs);
- for (unsigned i = 0, e = nOutputs; i < e; ++i) {
- ValueHandleArray indexing(foldedAffineApplies(
- b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
- std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
- indexing);
- }
+ // 3. Emit std_store.
+ auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
+ assert(yieldOp->getNumOperands() == nOutputs);
+ for (unsigned i = 0; i < nOutputs; ++i) {
+ ValueHandleArray indexing(foldedAffineApplies(
+ b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+ std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
+ indexing);
}
}
};
+// Emits the MLIR for the scalar part of the indexed generic op by:
+// 1. Emitting std_load and std_store ops for each input and output view in
+// order. This is achieved by applying the appropriate input or output map
+// to the enclosing induction variables.
+// 2. Emitting a call to `op.fun()` that takes as arguments the induction
+// variables and the scalars from point 1. above.
+// 3. Emitting std_store to store the results of 2. to the output views.
+//
+// An example output may resemble:
+//
+// ```
+// loop.for %i = %c0 to %0 step %c1 {
+// loop.for %j = %c0 to %1 step %c1 {
+// loop.for %k = %c0 to %4 step %c1 {
+// %11 = load %arg0[%i, %j] :
+// memref<?x?xf32, stride_specification>
+// %12 = load %arg1[%i, %j, %k] :
+// memref<?x?x?xf32, stride_specification>
+// %13 = load %arg2[%i, %k, %j] :
+// memref<?x?x?xf32, stride_specification>
+// %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
+// (index, index, index, f32, f32, f32) -> (f32, f32)
+// store %14#0, %arg1[%i, %j, %k] :
+// memref<?x?x?xf32, stride_specification>
+// store %14#1, %arg2[%i, %k, %j] :
+// memref<?x?x?xf32, stride_specification>
+// }
+// }
+// }
+// ```
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value *> allIvs,
- IndexedGenericOp genericOp,
+ IndexedGenericOp indexedGenericOp,
OperationFolder *folder) {
- // This is just a shim to make Linalg compile.
- // TODO(pifon): Implement lowering after IndexedGenericOp def is submitted.
+ auto b = ScopedContext::getBuilder();
+ auto loc = ScopedContext::getLocation();
+ using edsc::intrinsics::detail::ValueHandleArray;
+ unsigned nInputs = indexedGenericOp.getNumInputs();
+ unsigned nOutputs = indexedGenericOp.getNumOutputs();
+ unsigned nLoops = allIvs.size();
+ SmallVector<Value *, 4> indexedValues(nLoops + nInputs + nOutputs);
+
+ for (unsigned i = 0; i < nLoops; ++i) {
+ indexedValues[i] = allIvs[i];
+ }
+
+ // 1.a. Emit std_load from input views.
+ for (unsigned i = 0; i < nInputs; ++i) {
+ ValueHandleArray indexing(foldedAffineApplies(
+ b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs, folder));
+ indexedValues[nLoops + i] =
+ std_load(indexedGenericOp.getInput(i), indexing);
+ }
+
+ // 1.b. Emit std_load from output views.
+ for (unsigned i = 0; i < nOutputs; ++i) {
+ ValueHandleArray indexing(foldedAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+ indexedValues[nLoops + nInputs + i] =
+ std_load(indexedGenericOp.getOutput(i), indexing);
+ }
+
+ if (auto funcOp = indexedGenericOp.getFunction()) {
+ // 2. Emit call.
+ Operation *callOp = call(funcOp, indexedValues);
+ assert(callOp->getNumResults() == indexedGenericOp.getNumOutputs());
+
+ // 3. Emit std_store.
+ for (unsigned i = 0; i < nOutputs; ++i) {
+ ValueHandleArray indexing(foldedAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+ std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
+ indexing);
+ }
+ return;
+ }
+ // TODO(ntv): When a region inliner exists, use it.
+ // 2. Inline region, currently only works for a single basic block.
+ BlockAndValueMapping map;
+ auto &block = indexedGenericOp.region().front();
+ for (auto it : llvm::zip(block.getArguments(), indexedValues))
+ map.map(std::get<0>(it), std::get<1>(it));
+ for (auto &op : block.without_terminator()) {
+ assert(op.getNumRegions() == 0);
+ auto *newOp = b.clone(op, map);
+ for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
+ map.map(std::get<0>(it), std::get<1>(it));
+ }
+
+ // 3. Emit std_store.
+ auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
+ assert(yieldOp->getNumOperands() == nOutputs);
+ for (unsigned i = 0; i < nOutputs; ++i) {
+ ValueHandleArray indexing(foldedAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+ std_store(map.lookup(yieldOp->getOperand(i)),
+ indexedGenericOp.getOutput(i), indexing);
+ }
}
};