/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <utility>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace mhlo {
namespace {
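// Rewrites a reduction whose declared reduction dimensions are all of extent
// 1 into an equivalent tensor.collapse_shape op. Schematically (shapes are
// illustrative only), a reduction of tensor<1x5x1xf32> across dimensions
// [0, 2] becomes
//   tensor.collapse_shape %arg [[0, 1, 2]]
//       : tensor<1x5x1xf32> into tensor<5xf32>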
LogicalResult TryLowerToCollapseShape(
ReduceOp op, RankedTensorType arg_ty, Value arg,
SmallVector<int64_t>& ordered_reduction_dims, PatternRewriter& rewriter) {
  // This only works for trivial reductions where all declared reduction
  // dimensions are of extent 1.
if (!llvm::all_of(ordered_reduction_dims,
[&](int64_t i) { return arg_ty.getDimSize(i) == 1; })) {
return failure();
}
int64_t arg_rank = arg_ty.getRank();
int64_t num_reduction_dims = ordered_reduction_dims.size();
int64_t j = 0;
auto is_declared_as_reduction_dim = [&](int64_t i) {
if (j < num_reduction_dims && ordered_reduction_dims[j] == i) {
j++;
return true;
}
return false;
};
// Build reassociation indices.
SmallVector<ReassociationIndices, 4> reassociation;
int64_t i_begin = 0;
int64_t i = 0;
while (i < arg_rank && is_declared_as_reduction_dim(i)) i++;
while (i < arg_rank) {
i++;
while (i < arg_rank && is_declared_as_reduction_dim(i)) i++;
reassociation.push_back(llvm::to_vector(llvm::seq(i_begin, i)));
i_begin = i;
}
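  // E.g. (illustrative): for an argument of rank 4 with reduction dimensions
  // {0, 2}, this builds the reassociation [[0, 1, 2], [3]]: each group holds
  // exactly one non-reduction dimension plus the adjacent extent-1 reduction
  // dimensions that are collapsed into it.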
// Lower reduction op to collapse shape op.
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, arg, reassociation);
return success();
}
enum class DimensionKind {
kParallel,
kReduction,
kDegenerate,
};
struct DimensionGroup {
DimensionKind kind;
int64_t begin;
int64_t end;
int64_t size() { return end - begin; }
};
// Groups consecutive dimensions of a reduction argument by their kind, i.e.
// whether they are reduction or parallel dimensions. Dimensions of extent 1
// can be considered to be of either kind.
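// E.g. (illustrative): for an argument of type tensor<10x1x2x3x4xf32> with
// reduction dimensions {2, 3}, the resulting groups are
//   {kParallel, 0, 2}, {kReduction, 2, 4}, {kParallel, 4, 5},
// where the degenerate dimension 1 is absorbed into the leading parallel
// group.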
void GroupDimensions(RankedTensorType arg_ty,
SmallVector<int64_t> ordered_reduction_dims,
SmallVector<DimensionGroup>& groups) {
int64_t arg_rank = arg_ty.getRank();
int64_t num_reduction_dims = ordered_reduction_dims.size();
int64_t j = 0;
for (int64_t i = 0; i < arg_rank; ++i) {
// Check if the i-th dimension is one of the declared reduction dimensions.
bool is_declared_as_reduction_dim = false;
if (j < num_reduction_dims && i == ordered_reduction_dims[j]) {
is_declared_as_reduction_dim = true;
j++;
}
// Use the declared dimension kind unless the dimension is of extent 1, in
// which case we can consider it either kind. We exploit this to form
// maximal dimension groups.
DimensionKind kind = is_declared_as_reduction_dim
? DimensionKind::kReduction
: DimensionKind::kParallel;
if (arg_ty.getDimSize(i) == 1) kind = DimensionKind::kDegenerate;
    // Start a new dimension group if the dimension kind conflicts with the
    // trailing kind.
if (groups.empty() || (groups.back().kind != kind &&
groups.back().kind != DimensionKind::kDegenerate &&
kind != DimensionKind::kDegenerate)) {
groups.push_back({kind, i, i});
}
// Include dimension in trailing group and concretize dimension kind if
// necessary.
if (groups.back().kind == DimensionKind::kDegenerate)
groups.back().kind = kind;
groups.back().end++;
}
}
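// Rewrites a multi-dimensional reduction as an equivalent 1D or 2D reduction
// by collapsing adjacent dimensions of the same kind and, if needed,
// transposing so that all reduction dimensions form a single contiguous
// group. Schematically (shapes are illustrative only), a reduction of
// tensor<10x11x12x13xf32> across dimensions [1, 2] with
// prefer_columns_reductions = true is lowered to
//   collapse  : tensor<10x11x12x13xf32> -> tensor<10x132x13xf32>
//   transpose : tensor<10x132x13xf32>   -> tensor<132x10x13xf32>
//   collapse  : tensor<132x10x13xf32>   -> tensor<132x130xf32>
//   reduce    : tensor<132x130xf32>     -> tensor<130xf32>   (column reduction)
//   expand    : tensor<130xf32>         -> tensor<10x13xf32>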
LogicalResult TryLowerTo1DOr2DReduction(
ReduceOp op, RankedTensorType arg_ty, Value arg,
SmallVector<int64_t>& ordered_reduction_dims,
bool prefer_columns_reductions, PatternRewriter& rewriter) {
// Group the argument dimensions by their kind.
SmallVector<DimensionGroup> dim_groups;
GroupDimensions(arg_ty, ordered_reduction_dims, dim_groups);
// Do not (re-)apply if the dimensions are already fully collapsed.
if (dim_groups.size() <= 2 &&
llvm::all_of(dim_groups, [](auto g) { return g.size() == 1; })) {
return failure();
}
// Determine whether or not a dynamic reshape is needed for the final result.
int64_t num_dyn_parallel_dims = 0;
for (auto group : dim_groups) {
if (group.kind != DimensionKind::kParallel) continue;
for (int64_t i = group.begin; i < group.end; i++) {
if (arg_ty.isDynamicDim(i)) num_dyn_parallel_dims++;
}
}
bool requires_dynamic_reshape = num_dyn_parallel_dims > 1;
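  // With at most one dynamic parallel dimension, the final shape can be
  // restored with a static tensor.expand_shape; with several dynamic parallel
  // dimensions, their individual extents cannot be recovered from the
  // collapsed result, so an mhlo.dynamic_reshape on the reified result shape
  // is materialized instead (see below).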
// Reify the result shape early so that the pattern can fail without altering
// the IR.
Optional<Value> result_shape;
if (requires_dynamic_reshape) {
llvm::SmallVector<Value, 1> reified_shapes;
if (failed(llvm::cast<InferShapedTypeOpInterface>(op.getOperation())
.reifyReturnTypeShapes(rewriter, op->getOperands(),
reified_shapes))) {
return failure();
}
assert(reified_shapes.size() == 1 && "expect exactly one shape");
result_shape = reified_shapes.front();
}
  // Collapse dimension groups so that adjacent dimensions of the intermediate
  // result are of different kinds.
Value interm_result = arg;
auto loc = op.getLoc();
bool requires_collapse =
llvm::any_of(dim_groups, [&](auto g) { return g.size() > 1; });
if (requires_collapse) {
auto reassociation =
llvm::to_vector(llvm::map_range(dim_groups, [&](auto g) {
return llvm::to_vector<2>(llvm::seq<int64_t>(g.begin, g.end));
}));
interm_result = rewriter.create<tensor::CollapseShapeOp>(loc, interm_result,
reassociation);
}
  // If required, transpose the intermediate result so that the dimension
  // kinds form two partitions, which can then be collapsed to a 2D
  // intermediate result.
bool requires_transpose = dim_groups.size() > 2;
if (requires_transpose) {
// Materialize transpose.
DimensionKind leading_dim_kind = prefer_columns_reductions
? DimensionKind::kReduction
: DimensionKind::kParallel;
DimensionKind trailing_dim_kind = prefer_columns_reductions
? DimensionKind::kParallel
: DimensionKind::kReduction;
SmallVector<int64_t> perm;
for (int i = 0; i < dim_groups.size(); i++) {
if (dim_groups[i].kind == leading_dim_kind) perm.push_back(i);
}
int64_t num_leading_dims = perm.size();
for (int i = 0; i < dim_groups.size(); i++) {
if (dim_groups[i].kind == trailing_dim_kind) perm.push_back(i);
}
auto perm_attr = rewriter.getI64TensorAttr(perm);
interm_result = rewriter.create<TransposeOp>(loc, interm_result, perm_attr)
->getResults()
.front();
    // Collapse the intermediate result to rank 2.
SmallVector<ReassociationIndices, 2> reassociation = {
llvm::to_vector<2>(llvm::seq<int64_t>(0, num_leading_dims)),
llvm::to_vector<2>(llvm::seq<int64_t>(num_leading_dims, perm.size()))};
interm_result = rewriter.create<tensor::CollapseShapeOp>(loc, interm_result,
reassociation);
}
// Materialize inner 1D or 2D reduction.
bool leading_reduction =
requires_transpose ? prefer_columns_reductions
: dim_groups.front().kind == DimensionKind::kReduction;
int64_t reduction_dim = leading_reduction ? 0 : 1;
auto reduction_dim_attr = rewriter.getI64VectorAttr({reduction_dim});
Value init_val = op.init_values().front();
auto reduction_op = rewriter.create<ReduceOp>(loc, interm_result, init_val,
reduction_dim_attr);
rewriter.inlineRegionBefore(op.body(), reduction_op.body(),
reduction_op.body().begin());
interm_result = reduction_op->getResults().front();
// Restore the expected shape by dynamic reshape, if required.
auto result_ty = op->getResultTypes().front().cast<RankedTensorType>();
if (requires_dynamic_reshape) {
assert(result_shape && "expect to have reified the result shape");
interm_result = rewriter.create<DynamicReshapeOp>(
loc, result_ty, interm_result, *result_shape);
}
  // Otherwise, restore the expected shape by shape expansion, if required.
int64_t result_rank = result_ty.getRank();
int64_t interm_result_rank =
interm_result.getType().cast<RankedTensorType>().getRank();
bool requires_expand =
!requires_dynamic_reshape && result_rank != interm_result_rank;
if (requires_expand) {
assert(interm_result_rank <= 1 &&
"expect intermediate result to be of rank 0 or 1 before expansion");
SmallVector<ReassociationIndices, 1> reassociation;
bool is_scalar_expansion = interm_result_rank == 0;
if (!is_scalar_expansion)
reassociation = {llvm::to_vector(llvm::seq<int64_t>(0, result_rank))};
interm_result = rewriter.create<tensor::ExpandShapeOp>(
loc, result_ty, interm_result, reassociation);
}
rewriter.replaceOp(op, interm_result);
return success();
}
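// Pattern that rewrites an mhlo.reduce op over a single argument either into
// a pure tensor.collapse_shape (if all reduction dimensions are of extent 1)
// or into an equivalent 1D or 2D reduction surrounded by collapse, transpose,
// and expand/reshape ops.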
struct GroupReductionDimensionsPattern : public OpRewritePattern<ReduceOp> {
GroupReductionDimensionsPattern(MLIRContext* ctx,
bool prefer_columns_reductions)
: OpRewritePattern<ReduceOp>(ctx, /*benefit=*/1),
prefer_columns_reductions(prefer_columns_reductions) {}
LogicalResult matchAndRewrite(ReduceOp op,
PatternRewriter& rewriter) const override {
    // Only apply to reductions over a single argument.
if (op.inputs().size() != 1 || op.init_values().size() != 1)
return failure();
Value arg = op.inputs().front();
auto arg_ty = arg.getType().cast<RankedTensorType>();
    // Sort the reduction dimensions; sortedness is not an invariant of the op.
SmallVector<int64_t> ordered_reduction_dims =
llvm::to_vector<4>(llvm::map_range(op.dimensions(), [](auto d) {
return static_cast<int64_t>(d.getLimitedValue());
}));
std::sort(ordered_reduction_dims.begin(), ordered_reduction_dims.end());
// If all reduction dimensions are known to be of extent 1 then we can
// express the reduction through an equivalent collapsing op.
if (succeeded(TryLowerToCollapseShape(op, arg_ty, arg,
ordered_reduction_dims, rewriter))) {
return success();
}
// Otherwise, try lowering the reduction to an equivalent 1D or 2D
// reduction, and insert transposes if needed.
if (succeeded(
TryLowerTo1DOr2DReduction(op, arg_ty, arg, ordered_reduction_dims,
prefer_columns_reductions, rewriter))) {
return success();
}
return failure();
}
bool prefer_columns_reductions;
};
struct GroupReductionDimensionsPass
: public GroupReductionDimensionsPassBase<GroupReductionDimensionsPass> {
explicit GroupReductionDimensionsPass(bool prefer_columns_reductions)
: GroupReductionDimensionsPassBase<
GroupReductionDimensionsPass>::GroupReductionDimensionsPassBase() {
prefer_columns_reductions_ = prefer_columns_reductions;
}
void runOnOperation() override {
MLIRContext* ctx = &getContext();
RewritePatternSet patterns(ctx);
populateGroupReductionDimensionsPatterns(ctx, &patterns,
prefer_columns_reductions_);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
void populateGroupReductionDimensionsPatterns(MLIRContext* context,
RewritePatternSet* patterns,
bool prefer_columns_reductions) {
patterns->add<GroupReductionDimensionsPattern>(context,
prefer_columns_reductions);
}
std::unique_ptr<OperationPass<func::FuncOp>> createGroupReductionDimensionsPass(
bool prefer_columns_reductions) {
return std::make_unique<GroupReductionDimensionsPass>(
prefer_columns_reductions);
}
} // namespace mhlo
} // namespace mlir