blob: 5ce31b1d814c5df21215aa435063e9676a7f6481 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements passes that translate mixed tensor-based IR to buffer
// (memref) form. Besides MHLO, they handle ops from the Standard, Arithmetic,
// Tensor, Linalg, Shape and SCF dialects, among others.
#include <cstdint>
#include <memory>
#include <utility>
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" // from @llvm-project
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" // from @llvm-project
#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project
#include "mlir/Dialect/Tensor/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
/// Creates a `bufferization.to_tensor` op that wraps the single memref value
/// in `inputs` into a tensor of type `type`. It is registered as both the
/// argument and the source materialization of `CustomBufferizeTypeConverter`.
static Value materializeToTensor(OpBuilder& builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  Value buffer = inputs.front();
  assert(buffer.getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, buffer);
}
// TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
/// Type converter used by the dialect-conversion based bufferization passes in
/// this file. On top of `bufferization::BufferizeTypeConverter` it:
///   * keeps every other type unchanged (catch-all identity conversion),
///   * converts ranked tensors to memrefs of the same shape and element type
///     with the default layout,
///   * converts unranked tensors to unranked memrefs in memory space 0,
/// and registers materializations between tensor and memref values.
///
/// NOTE: MLIR tries type conversions in reverse registration order, so the
/// tensor conversions must stay registered after the catch-all identity
/// conversion. Do not reorder the `addConversion` calls.
class CustomBufferizeTypeConverter
    : public bufferization::BufferizeTypeConverter {
 public:
  CustomBufferizeTypeConverter() {
    // Keep all types unchanged.
    addConversion([](Type type) { return type; });
    // Convert RankedTensorType to MemRefType.
    addConversion([](RankedTensorType type) -> Type {
      return MemRefType::get(type.getShape(), type.getElementType());
    });
    // Convert UnrankedTensorType to UnrankedMemRefType.
    addConversion([](UnrankedTensorType type) -> Type {
      return UnrankedMemRefType::get(type.getElementType(), 0);
    });
    // memref -> tensor: wrap the buffer with `bufferization.to_tensor`.
    addArgumentMaterialization(materializeToTensor);
    addSourceMaterialization(materializeToTensor);
    // tensor -> memref: unwrap with `bufferization.to_memref`. A memref input
    // is forwarded unchanged (see below), so the returned value's type may
    // differ from `type`.
    addTargetMaterialization([](OpBuilder& builder, BaseMemRefType type,
                                ValueRange inputs, Location loc) -> Value {
      assert(inputs.size() == 1);
      // Target materialization is invoked if the new operand type does not
      // match the expected type. A special case is when the new operand type is
      // a memref with a specified layout, i.e. non-empty affine map.
      // TODO(pifon) : Change how target materialization is invoked in dialect
      // conversion.
      if (auto memref_type = inputs[0].getType().dyn_cast<MemRefType>()) {
        assert(!memref_type.getLayout().isIdentity());
        return inputs[0];
      }
      assert(inputs[0].getType().isa<TensorType>());
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    });
  }
};
/// Module pass that bufferizes MHLO ops, function signatures, calls, branches,
/// returns, linalg ops and shape/scf structural ops. It uses partial dialect
/// conversion with `CustomBufferizeTypeConverter`.
///
/// `tensor.extract_slice`/`tensor.insert_slice` ops, linalg ops and vector
/// transfer ops whose immediate parent is a `linalg.tiled_loop` are treated as
/// legal, so this pass leaves them untouched.
struct ComputeOpAndFuncBufferizePass
    : public ComputeOpAndFuncBufferizePassBase<ComputeOpAndFuncBufferizePass> {
  // TODO(b/173201243): Move to tablegen.
  // Dialects that must be loaded before the pass runs (presumably because the
  // patterns and materializations below create ops from them).
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<linalg::LinalgDialect, bufferization::BufferizationDialect,
                    lmhlo::LmhloDialect, memref::MemRefDialect>();
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    auto& context = getContext();
    ConversionTarget target(context);
    // Ops of these dialects are legal unless an op-specific rule below
    // overrides that.
    target.addLegalDialect<
        arith::ArithmeticDialect, complex::ComplexDialect, lmhlo::LmhloDialect,
        AffineDialect, vector::VectorDialect, memref::MemRefDialect,
        StandardOpsDialect, tensor::TensorDialect, math::MathDialect>();
    target.addLegalOp<UnrealizedConversionCastOp, linalg::TiledLoopOp>();
    // Every MHLO op has to be converted away.
    target.addIllegalDialect<mhlo::MhloDialect>();
    // Slice ops are legal only when their immediate parent is a tiled loop;
    // everywhere else they are illegal.
    target.addDynamicallyLegalOp<tensor::ExtractSliceOp, tensor::InsertSliceOp>(
        [&](Operation* op) {
          return mlir::isa<linalg::TiledLoopOp>(op->getParentOp());
        });
    CustomBufferizeTypeConverter converter;
    mhlo::RemoveSignTypeConverter remove_sign_converter;
    // Configure bufferize pattern for functions and lhlo.
    mhlo::populateHLOToMemrefConversionPattern(
        &converter, &remove_sign_converter, &patterns,
        /*enforce_identity_map=*/[](Operation* op) {
          // Force identity maps for several ops which don't support memrefs
          // with affine_maps.
          return llvm::any_of(op->getUsers(), [](Operation* user) {
            return isa<mlir::ReturnOp, mhlo::DynamicReshapeOp, tensor::CastOp,
                       tensor::CollapseShapeOp, tensor::ExpandShapeOp,
                       linalg::TiledLoopOp>(user);
          });
        });
    // Rewrite function signatures, calls, branches and returns with the
    // converter.
    populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
                                                             converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    // Configure legality and structural patterns.
    bufferization::populateBufferizeMaterializationLegality(target);
    linalg::populateLinalgBufferizePatterns(converter, patterns);
    populateShapeStructuralTypeConversionsAndLegality(converter, patterns,
                                                      target);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    // TODO(herhut): Move this legality configuration to bufferize itself?
    // A function is legal once its input/result types and the block argument
    // types of its body are legal for the converter.
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      auto inputs = op.getType().getInputs();
      auto results = op.getType().getResults();
      return converter.isLegal(inputs) && converter.isLegal(results) &&
             converter.isLegal(&op.getBody());
    });
    // Calls and returns are legal once their operand/result types are legal.
    auto isLegalOp = [&](Operation* op) { return converter.isLegal(op); };
    target.addDynamicallyLegalOp<CallOp, ReturnOp>(isLegalOp);
    // Linalg ops and vector transfers whose immediate parent is a tiled loop
    // are kept as-is, even if they still have tensor types.
    auto isLegalOrInsideTiledLoop = [&](Operation* op) {
      return converter.isLegal(op) ||
             mlir::isa<linalg::TiledLoopOp>(op->getParentOp());
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
        isLegalOrInsideTiledLoop);
    target
        .addDynamicallyLegalOp<vector::TransferWriteOp, vector::TransferReadOp>(
            isLegalOrInsideTiledLoop);
    // Partial conversion: ops with unknown legality may remain, but the pass
    // fails if any (dynamically) illegal op is left after conversion.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
/// Function pass that bufferizes `linalg.tiled_loop` ops with partial dialect
/// conversion, using `CustomBufferizeTypeConverter`. Function signatures are
/// not rewritten here. `tensor.extract_slice`/`tensor.insert_slice` are
/// unconditionally illegal. Linalg ops, vector transfer ops, calls and inline
/// asm are legal only once their operand/result types are legal for the
/// converter.
struct TiledLoopBufferizePass
    : public TiledLoopBufferizePassBase<TiledLoopBufferizePass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    // The converter's materializations build `bufferization.to_tensor` and
    // `bufferization.to_memref` ops, so the bufferization dialect must be
    // loaded even when this pass runs on its own.
    registry.insert<bufferization::BufferizationDialect,
                    memref::MemRefDialect>();
  }
  void runOnOperation() override {
    MLIRContext& context = getContext();
    RewritePatternSet patterns(&context);
    ConversionTarget target(context);
    target.addLegalDialect<
        arith::ArithmeticDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, lmhlo::LmhloDialect, AffineDialect,
        vector::VectorDialect, memref::MemRefDialect, StandardOpsDialect,
        tensor::TensorDialect, math::MathDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    target.addIllegalDialect<mhlo::MhloDialect>();
    // Slice ops must be bufferized by this pass wherever they appear.
    target.addIllegalOp<tensor::ExtractSliceOp, tensor::InsertSliceOp>();

    CustomBufferizeTypeConverter converter;
    // Configure bufferize patterns.
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    bufferization::populateBufferizeMaterializationLegality(target);
    populateTiledLoopBufferizePattern(&context, &converter, &patterns);
    populateShapeStructuralTypeConversionsAndLegality(converter, patterns,
                                                      target);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);

    // Configure legality: these ops are legal once their operand and result
    // types are legal for the converter.
    auto isLegalOp = [&](Operation* op) { return converter.isLegal(op); };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOp);
    target
        .addDynamicallyLegalOp<CallOp, LLVM::InlineAsmOp,
                               vector::TransferWriteOp, vector::TransferReadOp>(
            isLegalOp);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
/// Module pass that finishes bufferization in two phases:
///   1. BufferizableOpInterface-based bufferization (greedy pattern rewrite)
///      for the Arithmetic, Standard and Tensor dialects.
///   2. Dialect-conversion based bufferization (full conversion) for
///      everything that is left.
struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
  // TODO(b/173201243): Move to tablegen.
  // Loads the dialects this pass may need and attaches the
  // BufferizableOpInterface external models for the tensor and arith dialects,
  // which phase 1 relies on.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
                    shape::ShapeDialect, tensor::TensorDialect,
                    tf_framework::TFFrameworkDialect, lmhlo::LmhloDialect,
                    arith::ArithmeticDialect>();
    tensor::registerBufferizableOpInterfaceExternalModels(registry);
    arith::registerBufferizableOpInterfaceExternalModels(registry);
  }

  // Default alignment_ specified in passes.td
  FinalBufferizePass() = default;
  // Overrides the `alignment_` option with the given value.
  explicit FinalBufferizePass(uint64_t alignment) { alignment_ = alignment; }

  void runOnOperation() override {
    // Phase 1. This could be switched to One-Shot Bufferize in the future.
    if (failed(runInterfaceBasedBufferization())) {
      signalPassFailure();
      return;
    }
    // Phase 2. This will disappear eventually once all bufferization is done
    // via BufferizableOpInterface.
    if (failed(runDialectConversionBasedBufferization())) signalPassFailure();
  }

 private:
  /// Bufferizes ops of the filtered dialects through BufferizableOpInterface,
  /// using `alignment_` as the buffer alignment.
  LogicalResult runInterfaceBasedBufferization() {
    bufferization::BufferizationOptions bufferize_options =
        bufferization::getPartialBufferizationOptions();
    bufferize_options.bufferAlignment = alignment_;
    // TODO(springerm): Add dialects to this filter as more and more dialects
    // will be migrated to BufferizableOpInterface-based bufferization.
    bufferize_options.addToDialectFilter<arith::ArithmeticDialect,
                                         StandardOpsDialect,
                                         tensor::TensorDialect>();
    // The state must outlive the greedy rewrite that uses its patterns.
    bufferization::AlwaysCopyBufferizationState state(bufferize_options);
    RewritePatternSet interface_patterns(&getContext());
    bufferization::populateBufferizationPattern(state, interface_patterns);
    return applyPatternsAndFoldGreedily(getOperation(),
                                        std::move(interface_patterns));
  }

  /// Bufferizes the remaining IR with a full dialect conversion. The
  /// conversion fails if any MHLO op or any of the explicitly illegal ops
  /// below survives.
  LogicalResult runDialectConversionBasedBufferization() {
    MLIRContext* ctx = &getContext();

    ConversionTarget target(*ctx);
    target.addLegalDialect<
        arith::ArithmeticDialect, bufferization::BufferizationDialect,
        cf::ControlFlowDialect, complex::ComplexDialect, memref::MemRefDialect,
        StandardOpsDialect, scf::SCFDialect, tensor::TensorDialect,
        tf_framework::TFFrameworkDialect, AffineDialect, shape::ShapeDialect,
        lmhlo::LmhloDialect, linalg::LinalgDialect, math::MathDialect,
        vector::VectorDialect>();
    target.addLegalOp<FuncOp, ModuleOp>();
    target.addIllegalDialect<mhlo::MhloDialect>();
    target.addIllegalOp<tensor::GenerateOp, tensor::ExtractOp,
                        tensor::FromElementsOp, tensor::CastOp, tensor::DimOp,
                        tensor::RankOp, chlo::MinimumBroadcastShapesOp,
                        bufferization::ToTensorOp, bufferization::ToMemrefOp,
                        tensor::ExpandShapeOp, tensor::CollapseShapeOp>();

    bufferization::BufferizeTypeConverter type_converter;
    // These ops are legal only once all their operand and result types are
    // legal for the type converter.
    auto has_legal_types = [&type_converter](Operation* op) {
      return type_converter.isLegal(op->getOperandTypes()) &&
             type_converter.isLegal(op->getResultTypes());
    };
    target.addDynamicallyLegalOp<ConstantOp, arith::ConstantOp,
                                 arith::IndexCastOp, arith::SelectOp,
                                 tf_framework::JITExecuteOp>(has_legal_types);

    RewritePatternSet conversion_patterns(ctx);
    linalg::populateLinalgBufferizePatterns(type_converter,
                                            conversion_patterns);
    populateEliminateBufferizeMaterializationsPatterns(type_converter,
                                                       conversion_patterns);
    populateExtraBufferizePatterns(ctx, &type_converter, &conversion_patterns);
    populateShapeStructuralTypeConversionsAndLegality(
        type_converter, conversion_patterns, target);
    scf::populateSCFStructuralTypeConversionsAndLegality(
        type_converter, conversion_patterns, target);
    return applyFullConversion(getOperation(), target,
                               std::move(conversion_patterns));
  }
};
} // namespace
/// Returns a new module-level pass that bufferizes compute ops and function
/// signatures.
std::unique_ptr<OperationPass<ModuleOp>> CreateComputeOpAndFuncBufferizePass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ComputeOpAndFuncBufferizePass>();
  return pass;
}
/// Returns a new function-level pass that bufferizes `linalg.tiled_loop` ops.
std::unique_ptr<OperationPass<FuncOp>> CreateTiledLoopBufferizePass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<TiledLoopBufferizePass>();
  return pass;
}
/// Returns a new final bufferization pass that keeps the default `alignment_`
/// declared in passes.td.
std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<FinalBufferizePass>();
  return pass;
}
/// Returns a new final bufferization pass whose `alignment_` option is set to
/// `alignment`. That option is used as the buffer alignment of the
/// interface-based bufferization phase.
std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass(
    uint64_t alignment) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<FinalBufferizePass>(alignment);
  return pass;
}
} // namespace transforms
} // namespace kernel_gen
} // namespace mlir