blob: b2cff6085f3b99fdb1abee7637ec64c7947135d7 [file] [log] [blame]
//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
#include "mlir/Linalg/Passes.h"
#include "mlir/Linalg/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace llvm;
static llvm::cl::OptionCategory clOptionsCategory("linalg options");
static llvm::cl::list<unsigned>
clTileSizes("linalg-tile-sizes",
llvm::cl::desc("Tile sizes by which to tile linalg operations"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::cat(clOptionsCategory));
namespace {
class PerFunctionState {
public:
PerFunctionState(Function &f) : f(f) {}
Value *getOrCreate(int64_t v) {
auto it = map.find(v);
if (it != map.end())
return it->second;
FuncBuilder builder(f);
edsc::ScopedContext s(builder, f.getLoc());
return map.insert(std::make_pair(v, edsc::intrinsics::constant_index(v)))
.first->getSecond();
}
private:
Function &f;
SmallDenseMap<int64_t, Value *> map;
};
} // namespace
// Folding eagerly is necessary to abide by affine.for static step requirement.
// We must propagate constants on the steps as aggressively as possible.
// Returns nullptr if folding is not trivially feasible.
static Value *tryFold(AffineMap map, ArrayRef<Value *> operands,
PerFunctionState &state) {
assert(map.getNumResults() == 1 && "single result map expected");
auto expr = map.getResult(0);
if (auto dim = expr.dyn_cast<AffineDimExpr>())
return operands[dim.getPosition()];
if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
return operands[map.getNumDims() + sym.getPosition()];
if (auto cst = expr.dyn_cast<AffineConstantExpr>())
return state.getOrCreate(cst.getValue());
return nullptr;
}
static Value *emitOrFoldComposedAffineApply(FuncBuilder *b, Location loc,
AffineMap map,
ArrayRef<Value *> operandsRef,
PerFunctionState &state) {
SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
fullyComposeAffineMapAndOperands(&map, &operands);
if (auto *v = tryFold(map, operands, state))
return v;
return b->create<AffineApplyOp>(loc, map, operands);
}
static SmallVector<Value *, 4> applyMapToRangePart(FuncBuilder *b, Location loc,
AffineMap map,
ArrayRef<Value *> ranges,
RangePart part,
PerFunctionState &state) {
SmallVector<Value *, 4> rangeParts(ranges.size());
transform(llvm::make_range(ranges.begin(), ranges.end()), rangeParts.begin(),
[&](Value *range) { return extractRangePart(range, part); });
SmallVector<Value *, 4> res;
res.reserve(map.getNumResults());
unsigned numDims = map.getNumDims();
for (auto expr : map.getResults()) {
AffineMap map = AffineMap::get(numDims, 0, expr, {});
res.push_back(
emitOrFoldComposedAffineApply(b, loc, map, rangeParts, state));
}
return res;
}
static bool isZero(Value *v) {
return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
}
/// Returns a map that can be used to filter the zero values out of tileSizes.
/// For example, if tileSizes contains `{v1, 0, v2}`, the returned map is:
///
/// ```{.mlir}
/// (d0, d1, d2) -> (d0, d2)
/// ```
static AffineMap nonZeroMap(ArrayRef<Value *> tileSizes) {
SmallVector<AffineExpr, 4> exprs;
for (auto en : llvm::enumerate(tileSizes))
if (!isZero(en.value()))
exprs.push_back(getAffineDimExpr(en.index(), en.value()->getContext()));
assert(!exprs.empty() &&
"unexpected zero-only tile sizes, should have been handled earlier");
return AffineMap::get(tileSizes.size(), 0, exprs, {});
}
// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
// one entry per surrounding loop. It uses zero as the convention that a
// particular loop is not tiled. This convention simplifies implementations by
// avoiding affine map manipulations.
// The returned ranges correspond to the loop ranges, in the proper order, that
// are tiled and for which new loops will be created.
static SmallVector<Value *, 4>
makeTiledLoopRanges(FuncBuilder *b, Location loc, AffineMap map,
ArrayRef<Value *> allOpRanges, ArrayRef<Value *> tileSizes,
PerFunctionState &state) {
assert(tileSizes.size() == map.getNumResults());
// Tile sizes are in loop order by construction, apply `map` to
// get mins/maxes/steps in loop order.
auto mins =
applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Min, state);
auto maxes =
applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Max, state);
auto steps =
applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Step, state);
SmallVector<Value *, 4> sizes(tileSizes.begin(), tileSizes.end());
// Traverse the tile sizes, which are in loop order, erase zeros everywhere.
for (int idx = mins.size() - 1; idx >= 0; --idx) {
if (isZero(tileSizes[idx])) {
mins.erase(mins.begin() + idx);
maxes.erase(maxes.begin() + idx);
steps.erase(steps.begin() + idx);
sizes.erase(sizes.begin() + idx);
}
}
// Create a new range with the applied tile sizes.
SmallVector<Value *, 4> res;
for (unsigned idx = 0, e = steps.size(); idx < e; ++idx) {
auto *step = steps[idx];
auto *tileSize = sizes[idx];
// clang-format off
// Steps must be constant for now to abide by affine.for semantics.
auto *newStep =
state.getOrCreate(
cast<ConstantIndexOp>(step->getDefiningOp()).getValue() *
cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue());
res.push_back(b->create<RangeOp>(loc, mins[idx], maxes[idx], newStep));
// clang-format on
}
return res;
}
static SmallVector<Value *, 4> makeTiledViews(FuncBuilder *b, Location loc,
Operation *op,
ArrayRef<Value *> ivs,
ArrayRef<Value *> tileSizes,
PerFunctionState &state) {
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
llvm::make_range(tileSizes.begin(), tileSizes.end()),
[](Value *v) { return !isZero(v); })) &&
"expected as many ivs as non-zero sizes");
auto *context = op->getContext();
SmallVector<Value *, 4> res;
res.reserve(op->getNumOperands());
for (unsigned i = 0, ei = op->getNumOperands(); i < ei; ++i) {
auto *viewDefiningOp = op->getOperand(i)->getDefiningOp();
assert(viewDefiningOp && "Need operations to extract ranges from views");
auto ranges = getRanges(viewDefiningOp);
// E.g. for A in A(i, k) * B(k, j) -> C(i, j) returns the map:
// (i, j, k) -> (i, k)
auto map = loopToOperandRangesMaps(op)[i];
if (!map) {
assert(ranges.empty() && "scalar should have empty ranges");
res.push_back(op->getOperand(i));
continue;
}
assert(ranges.size() == map.getNumResults());
// E.g. for {0, 0, v2} returns the map:
// (i, j, k) -> (k)
auto nzMap = nonZeroMap(tileSizes);
SmallVector<Value *, 4> newRanges;
newRanges.reserve(ranges.size());
for (unsigned j = 0, ej = ranges.size(); j < ej; ++j) {
// Loop position for the range dimension.
// E.g. for A in A(i, k) * B(k, j) -> C(i, j) and map: (i, j, k) -> (i, k)
// and for j == 1 (i.e. result `k`)
// returns loopPos = 2 (i.e. `k` on the map domain).
auto pos = map.getResult(j).template cast<AffineDimExpr>().getPosition();
if (isZero(tileSizes[pos])) {
newRanges.push_back(ranges[j]);
continue;
}
auto it = llvm::find_if(nzMap.getResults(), [pos, context](AffineExpr e) {
return e == getAffineDimExpr(pos, context);
});
assert(it != nzMap.getResults().end() &&
"position does not correspond to a valid induction variable");
unsigned pos2 = it - nzMap.getResults().begin();
using edsc::op::operator+;
using range = ValueBuilder<RangeOp>;
using range_intersect = ValueBuilder<RangeIntersectOp>;
ScopedContext scope(*b, loc);
ValueHandle iv(ivs[pos2]), step(tileSizes[pos]);
auto min = ValueHandle(extractRangePart(ranges[j], RangePart::Min));
// zero case is important enough to fold away by special-casing.
auto newMin = isZero(min) ? iv : min + iv;
Value *r = range_intersect(ranges[j], range(newMin, newMin + step, step));
newRanges.push_back(r);
}
res.push_back(createOrReturnView(b, loc, viewDefiningOp, newRanges));
}
return res;
}
static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
PerFunctionState &state) {
// Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of
// adjusting affine maps to account for missing dimensions.
assert(op.getNumParallelLoops() + op.getNumReductionLoops() +
op.getNumWindowLoops() ==
tileSizes.size() &&
"expected matching number of tile sizes and loops");
FuncBuilder builder(op.getOperation());
ScopedContext scope(builder, op.getLoc());
auto loopRanges = makeTiledLoopRanges(
scope.getBuilder(), scope.getLocation(),
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op))),
getRanges(op.getOperation()), tileSizes, state);
SmallVector<IndexHandle, 4> ivs(loopRanges.size());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
LoopNestRangeBuilder(pivs, loopRanges)({[&op, &tileSizes, &ivs, &state]() {
auto *b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
// If/when the assertion below becomes false, we will have to templatize
// `makeTiledViews`.
assert(op.getNumInputsAndOutputs() == op.getOperation()->getNumOperands());
auto views =
makeTiledViews(b, loc, op.getOperation(), ivValues, tileSizes, state);
op.create(*b, loc, views);
/// NestedBuilders expect handles, we thus return an IndexHandle.
return IndexHandle();
}()});
return success();
}
static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<int64_t> tileSizes,
PerFunctionState &state) {
if (tileSizes.empty())
return failure();
// The following uses the convention that "tiling by zero" skips tiling a
// particular dimension. This convention is significantly simpler to handle
// instead of adjusting affine maps to account for missing dimensions.
auto nLoops = op.getNumParallelLoops() + op.getNumReductionLoops() +
op.getNumWindowLoops();
tileSizes = tileSizes.take_front(nLoops);
// If only 0 tilings are left, then return.
if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0; }))
return failure();
// Materialize concrete tile size values to pass the generic tiling function.
SmallVector<Value *, 8> tileSizeValues;
tileSizeValues.reserve(tileSizes.size());
for (auto ts : tileSizes)
tileSizeValues.push_back(state.getOrCreate(ts));
// Pad tile sizes with zero values to enforce our convention.
if (tileSizeValues.size() < nLoops) {
for (unsigned i = tileSizeValues.size(); i < nLoops; ++i)
tileSizeValues.push_back(state.getOrCreate(0));
}
return tileLinalgOp(op, tileSizeValues, state);
}
// TODO(ntv) expose as a primitive for other passes.
static LogicalResult tileLinalgOp(Operation *op, ArrayRef<int64_t> tileSizes,
PerFunctionState &state) {
if (auto linalgOp = dyn_cast<LinalgOp>(op))
return tileLinalgOp(linalgOp, tileSizes, state);
return failure();
}
static void tileLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
PerFunctionState state(f);
f.walk([tileSizes, &state](Operation *op) {
if (succeeded(tileLinalgOp(op, tileSizes, state)))
op->erase();
});
}
namespace {
struct LinalgTilingPass : public FunctionPass<LinalgTilingPass> {
LinalgTilingPass();
LinalgTilingPass(ArrayRef<int64_t> sizes);
void runOnFunction() { tileLinalgOps(getFunction(), tileSizes); }
SmallVector<int64_t, 8> tileSizes;
};
} // namespace
LinalgTilingPass::LinalgTilingPass()
: tileSizes(clTileSizes.begin(), clTileSizes.end()) {}
LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes)
: LinalgTilingPass() {
if (!sizes.empty())
this->tileSizes.assign(sizes.begin(), sizes.end());
}
FunctionPassBase *
mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
return new LinalgTilingPass(tileSizes);
}
static PassRegistration<LinalgTilingPass>
pass("linalg-tile", "Tile operations in the linalg dialect");