/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sys/types.h>

#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Analysis/shape_component_analysis.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
namespace tensorflow {
namespace {
using llvm::ArrayRef;
using llvm::SmallVector;
using mlir::AffineExpr;
using mlir::AffineMap;
using mlir::failure;
using mlir::Location;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpBuilder;
using mlir::OperationPass;
using mlir::RankedTensorType;
using mlir::ShapeComponentAnalysis;
using mlir::success;
using mlir::TypeRange;
using mlir::Value;
using mlir::ValueRange;
using mlir::arith::ConstantIndexOp;
using mlir::arith::ConstantOp;
using mlir::arith::IndexCastOp;
using mlir::func::FuncOp;
namespace linalg = mlir::linalg;
namespace mhlo = mlir::mhlo;
namespace shape = mlir::shape;
namespace tensor = mlir::tensor;
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
// -------------------------------------------------------------------------- //
// Replace shape.broadcast with the broadcasted shape when it can be computed
// statically.
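//
// An illustrative rewrite (a sketch only; the operands and shapes are assumed
// for the example, not taken from a specific test): if symbolic shape
// analysis proves that %rhs already equals the broadcasted shape, then
//
//   %0 = shape.broadcast %lhs, %rhs
//       : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
//
// is replaced by %rhs directly, leaving no shape computation at runtime.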
class BroadcastOpLowering final
: public mlir::OpRewritePattern<shape::BroadcastOp> {
public:
explicit BroadcastOpLowering(MLIRContext* ctx) : OpRewritePattern(ctx) {}
LogicalResult matchAndRewrite(shape::BroadcastOp op,
mlir::PatternRewriter& rewriter) const override;
};
// Returns a shape tensor if the shapes can be broadcasted to a known shape.
// The result is either one of the input shapes or a newly assembled mix of
// their elements.
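// For example (an assumed pair of symbolic shapes, for illustration only):
// broadcasting [s0, 1] with [1, s1] yields the known result [s0, s1], which
// is rematerialized by extracting s0 from the first shape and s1 from the
// second, and assembling them with tensor.from_elements.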
llvm::Optional<Value> simplifyBroadcast(ShapeComponentAnalysis& analysis,
ValueRange shapes, Location loc,
OpBuilder* builder) {
// First find the input shape with the largest rank.
SmallVector<ArrayRef<ShapeComponentAnalysis::SymbolicExpr>> shapes_found;
size_t maxRank = 0;
for (const auto &shape : llvm::enumerate(shapes)) {
auto found_shape = analysis.GetValueInfo(shape.value());
if (!found_shape) return {};
shapes_found.push_back(*found_shape);
maxRank = std::max(maxRank, found_shape->size());
}
if (maxRank == 0) {
return Value(builder->create<tensor::FromElementsOp>(
loc, shapes[0].getType(), SmallVector<Value>()));
}
SmallVector<const ShapeComponentAnalysis::SymbolicExpr*> joined_dimensions(
maxRank);
SmallVector<std::pair<Value, int64_t>> shape_and_rank_for_dim(maxRank);
for (const auto &shape : llvm::enumerate(shapes_found)) {
for (const auto &dim : llvm::enumerate(llvm::reverse(shape.value()))) {
      // Dimensions of size 1 don't contribute to the final result.
if (dim.value().isConstant(1)) continue;
      // Any dimension that is not of size 1 will be present in the result.
      // Remember which shape it came from.
auto index = maxRank - dim.index() - 1;
if (!joined_dimensions[index]) {
joined_dimensions[index] = &dim.value();
shape_and_rank_for_dim[index] =
std::make_pair(shapes[shape.index()], shape.value().size());
continue;
}
// Bail if the dimensions are neither equal nor 1.
if (*joined_dimensions[index] != dim.value()) return {};
}
}
// If the output is the same as one of the inputs just return that.
if (llvm::is_splat(shape_and_rank_for_dim) &&
shape_and_rank_for_dim[0].first) {
return shape_and_rank_for_dim[0].first;
}
// Otherwise rematerialize the shape from the pieces we have.
SmallVector<Value> elements;
  for (size_t i = 0; i != maxRank; ++i) {
    // Size-1 dimensions were filtered out above; recreate the constant 1.
if (!shape_and_rank_for_dim[i].first) {
auto one = builder->getIntegerAttr(
shapes[0].getType().cast<RankedTensorType>().getElementType(), 1);
elements.push_back(builder->create<ConstantOp>(loc, one));
continue;
}
    // Extract the dimension from one of the shapes, accounting for the right
    // alignment (reverse indexing) performed by broadcast.
    Value index = builder->create<ConstantIndexOp>(
        loc, static_cast<int64_t>(i) + shape_and_rank_for_dim[i].second -
                 static_cast<int64_t>(maxRank));
elements.push_back(builder->create<tensor::ExtractOp>(
loc, shape_and_rank_for_dim[i].first, index));
}
return Value(builder->create<tensor::FromElementsOp>(loc, elements));
}
LogicalResult BroadcastOpLowering::matchAndRewrite(
shape::BroadcastOp op, mlir::PatternRewriter& rewriter) const {
ShapeComponentAnalysis shape_component_analysis;
auto new_broadcast = simplifyBroadcast(
shape_component_analysis, op.getShapes(), op.getLoc(), &rewriter);
if (!new_broadcast) return failure();
rewriter.replaceOp(op, {*new_broadcast});
return success();
}
// -------------------------------------------------------------------------- //
// Rewrite an mhlo.dynamic_broadcast_in_dim operation into a linalg.generic
// operation if the indexing map for the operand can be inferred from the
// symbolic shapes.
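//
// A sketch of the rewrite (illustrative IR; the types, attributes, and
// symbolic shapes are assumptions for the example):
//
//   %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape)
//          {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
//          : (tensor<?x1xf32>, tensor<2xindex>) -> tensor<?x?xf32>
//
// becomes a linalg.generic over an init_tensor of the output shape, with the
// operand indexing map (d0, d1) -> (d0, 0), the identity output map
// (d0, d1) -> (d0, d1), and a body that simply yields the operand element.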
class DynamicBroadcastInDimOpLowering
: public mlir::OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
public:
using Base = OpRewritePattern<mhlo::DynamicBroadcastInDimOp>;
explicit DynamicBroadcastInDimOpLowering(MLIRContext* ctx);
LogicalResult matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,
mlir::PatternRewriter& rewriter) const override;
};
DynamicBroadcastInDimOpLowering::DynamicBroadcastInDimOpLowering(
MLIRContext* ctx)
: Base(ctx) {}
// Returns the operand indexing map for broadcasting `from` to `to_shape` if
// every dimension is statically known to either always expand or never
// expand; returns llvm::None otherwise.
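// For example (assumed symbolic shapes, for illustration): broadcasting an
// operand of shape [s0, 1] to an output of shape [s0, s1] yields the operand
// indexing map (d0, d1) -> (d0, 0): the size-1 dimension is treated as always
// expanding (mapped to constant 0) and the s0 dimension never expands.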
llvm::Optional<AffineMap> isNonExpandingBroadcast(
ShapeComponentAnalysis& analysis, Value from, Value to_shape) {
auto in_shape = analysis.GetShapeInfo(from);
auto out_shape = analysis.GetValueInfo(to_shape);
if (!in_shape || !out_shape) return {};
SmallVector<AffineExpr> input_map_exprs;
  size_t rank = out_shape->size();
  if (rank == 0) return {};
  MLIRContext* ctx = (*out_shape)[0].expr.getContext();
size_t d = 0;
auto affine_zero = getAffineConstantExpr(0, ctx);
for (auto zip :
llvm::zip(llvm::reverse(*in_shape), llvm::reverse(*out_shape))) {
const auto& in = std::get<0>(zip);
const auto& out = std::get<1>(zip);
bool extend = in.isConstant(1) && !out.isConstant(1);
input_map_exprs.push_back(extend ? affine_zero
: getAffineDimExpr(rank - d - 1, ctx));
++d;
// Bail if this is neither a known expansion nor a known non-expansion.
if (!extend && in != out) return {};
}
// Any leading dimensions will be expanded.
input_map_exprs.resize(in_shape->size(), affine_zero);
std::reverse(input_map_exprs.begin(), input_map_exprs.end());
return AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, input_map_exprs, ctx);
}
LogicalResult DynamicBroadcastInDimOpLowering::matchAndRewrite(
mhlo::DynamicBroadcastInDimOp op, mlir::PatternRewriter& rewriter) const {
MLIRContext* ctx = getContext();
auto in_type = op.operand().getType().dyn_cast<RankedTensorType>();
auto out_type = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!in_type || !out_type) return failure();
  // Check that the broadcast is right-aligned (numpy style), i.e. the operand
  // dimensions are broadcast to match the innermost dimensions of the output.
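  // For example, a rank-2 operand broadcast into a rank-4 result must use
  // broadcast_dimensions = [2, 3].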
auto bcast_dims = op.broadcast_dimensions().getValues<int64_t>();
auto expected_bcast_dims = llvm::seq<int64_t>(
out_type.getRank() - in_type.getRank(), out_type.getRank());
if (!llvm::equal(bcast_dims, expected_bcast_dims)) return failure();
ShapeComponentAnalysis shape_component_analysis;
auto input_map = isNonExpandingBroadcast(
shape_component_analysis, op.operand(), op.output_dimensions());
if (!input_map) return failure();
// Resolve dynamic output dimensions for the `linalg.init_tensor` operation.
SmallVector<Value> output_dyn_dimensions;
Location loc = op.getLoc();
int64_t rank = out_type.getRank();
  for (int64_t d = 0; d < rank; ++d) {
int64_t output_dim = out_type.getShape()[d];
// Skip static output dimensions, they will be resolved from the shape.
if (output_dim >= 0) continue;
// Resolve the dynamic size of the output dimension.
Value output_dyn_dim = rewriter.create<tensor::ExtractOp>(
loc, op.output_dimensions(),
ValueRange{rewriter.create<ConstantIndexOp>(loc, d)});
    // The output dimensions tensor may hold i32 or i64 elements; cast them to
    // index.
if (!output_dyn_dim.getType().isIndex())
output_dyn_dim = rewriter.create<IndexCastOp>(
loc, rewriter.getIndexType(), output_dyn_dim);
output_dyn_dimensions.push_back(output_dyn_dim);
}
  // Create a linalg.init_tensor operation to initialize the output.
Value init = rewriter.create<linalg::InitTensorOp>(loc, output_dyn_dimensions,
out_type.getShape(),
out_type.getElementType());
// Output indexing map is an identity with `rank` number of loops.
AffineMap output_map = AffineMap::getMultiDimIdentityMap(rank, ctx);
// All iterators are parallel.
SmallVector<llvm::StringRef> iterator_types(rank, "parallel");
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, /*resultTensorTypes=*/TypeRange{init.getType()},
/*inputs=*/ValueRange{op.operand()},
/*outputs=*/ValueRange{init},
/*indexingMaps=*/llvm::makeArrayRef({*input_map, output_map}),
/*iteratorTypes=*/iterator_types,
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nested_builder.create<linalg::YieldOp>(nested_loc, args[0]);
});
return success();
}
// -------------------------------------------------------------------------- //
// Optimize function based on the symbolic shape attributes.
// -------------------------------------------------------------------------- //
struct SymbolicShapeOptimizationPass
: public SymbolicShapeOptimizationBase<SymbolicShapeOptimizationPass> {
SymbolicShapeOptimizationPass() = default;
explicit SymbolicShapeOptimizationPass(bool constraints_only) {
this->optimize_only_constraints = constraints_only;
}
void runOnOperation() override {
MLIRContext* ctx = &getContext();
mlir::RewritePatternSet patterns(ctx);
// Rewrite shape.broadcast based on the symbolic shapes.
patterns.add<BroadcastOpLowering>(ctx);
// Rewrite broadcasts based on the symbolic shapes if enabled.
if (!optimize_only_constraints)
patterns.add<DynamicBroadcastInDimOpLowering>(ctx);
    // Add shape dialect canonicalization patterns to fold shape operations
    // after constraints have been replaced with constant witnesses.
for (auto op : ctx->getRegisteredOperations()) {
if (llvm::isa<shape::ShapeDialect>(op.getDialect()))
op.getCanonicalizationPatterns(patterns, ctx);
}
if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
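// Creates a pass that optimizes the function based on its symbolic shape
// attributes. When `constraints_only` is set, the
// mhlo.dynamic_broadcast_in_dim lowering is skipped and only the shape-level
// rewrites run.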
std::unique_ptr<OperationPass<FuncOp>> CreateSymbolicShapeOptimizationPass(
bool constraints_only) {
return std::make_unique<SymbolicShapeOptimizationPass>(constraints_only);
}
} // namespace tensorflow