/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering the MHLO dialect to the memref
// dialect via partial bufferization.

#include <functional>
#include <memory>
#include <utility>

#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace mhlo {
namespace {
using bufferization::AnalysisState;
using bufferization::BufferizableOpInterface;
using bufferization::BufferizationOptions;
using bufferization::BufferRelation;
using bufferization::replaceOpWithNewBufferizedOp;
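
// Bufferization model for mhlo.custom_call: operands are only read and are
// never aliased by results. bufferize() lowers the op to lmhlo.custom_call,
// whose arguments are the operand buffers followed by newly allocated result
// buffers.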
struct CustomCallOpInterface
: public BufferizableOpInterface::ExternalModel<CustomCallOpInterface,
mhlo::CustomCallOp> {
bool bufferizesToMemoryRead(Operation *, OpOperand &,
const AnalysisState &) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
const AnalysisState &) const {
return false; // Arguments are read-only.
}
SmallVector<OpResult> getAliasingOpResult(Operation *, OpOperand &,
const AnalysisState &) const {
return {};
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto customCallOp = cast<mhlo::CustomCallOp>(op);
// Bufferize arguments.
SmallVector<Value> bufferArgs;
for (OpOperand &operand : customCallOp->getOpOperands()) {
if (!operand.get().getType().isa<TensorType>()) return failure();
FailureOr<Value> operandBuffer =
getBuffer(rewriter, operand.get(), options);
if (failed(operandBuffer)) return failure();
bufferArgs.push_back(*operandBuffer);
}
// Allocate outputs.
for (OpResult result : customCallOp->getOpResults()) {
      auto tensorType = result.getType().dyn_cast<RankedTensorType>();
      if (!tensorType) return failure();
// TODO(springerm): Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
FailureOr<Value> tensorAlloc =
bufferization::allocateTensorForShapedValue(
rewriter, op->getLoc(), result,
analysisState.isTensorYielded(result), options);
if (failed(tensorAlloc)) return failure();
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
Value resultBuffer = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, *tensorAlloc);
bufferArgs.push_back(resultBuffer);
}
auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
op->getLoc(), llvm::None, bufferArgs, op->getAttrs());
    // lmhlo.custom_call uses an operand segment sizes attribute to
    // distinguish input arguments from output arguments.
lhloOp->setAttr(
lhloOp.getOperandSegmentSizeAttr(),
rewriter.getI32VectorAttr({static_cast<int32_t>(op->getNumOperands()),
static_cast<int32_t>(op->getNumResults())}));
bufferization::replaceOpWithBufferizedValues(
rewriter, op, makeArrayRef(bufferArgs).slice(op->getNumOperands()));
return success();
}
};
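
// Bufferization model for mhlo.reshape. The op neither reads nor writes; its
// result buffer aliases the operand buffer (BufferRelation::Equivalent). Only
// reshapes of unranked operands are bufferized here, via a memref.cast to the
// ranked result type; other cases are left untouched by this model.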
struct ReshapeOpInterface
: public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
mhlo::ReshapeOp> {
bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
const AnalysisState & /*state*/) const {
return false;
}
bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
const AnalysisState & /*state*/) const {
return false;
}
SmallVector<OpResult> getAliasingOpResult(
Operation *op, OpOperand & /*opOperand*/,
const AnalysisState & /*state*/) const {
return {op->getResult(0)};
}
BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
const AnalysisState & /*state*/) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto reshapeOp = cast<mhlo::ReshapeOp>(op);
auto unrankedOperandType =
reshapeOp.operand().getType().dyn_cast<UnrankedTensorType>();
    if (!unrankedOperandType) return success();
// The buffer still has the old (pre-reshape) type.
FailureOr<Value> operandBuffer =
getBuffer(rewriter, reshapeOp.operand(), options);
if (failed(operandBuffer)) return failure();
auto resultType = reshapeOp.getType().cast<RankedTensorType>();
auto destType =
MemRefType::get(resultType.getShape(), resultType.getElementType());
replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, destType,
*operandBuffer);
return success();
}
};
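
// Bufferization model for mhlo.dynamic_reshape, which lowers to
// memref.reshape. Because memref.reshape requires an identity layout, the
// operand is first copied into a fresh buffer if its layout map is
// non-identity.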
struct DynamicReshapeOpInterface
: public BufferizableOpInterface::ExternalModel<DynamicReshapeOpInterface,
mhlo::DynamicReshapeOp> {
bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
const AnalysisState & /*state*/) const {
return false;
}
bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
const AnalysisState & /*state*/) const {
return false;
}
SmallVector<OpResult> getAliasingOpResult(
Operation *op, OpOperand & /*opOperand*/,
const AnalysisState & /*state*/) const {
return {op->getResult(0)};
}
BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
const AnalysisState & /*state*/) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto reshapeOp = cast<mhlo::DynamicReshapeOp>(op);
// The buffer still has the old (pre-reshape) type.
FailureOr<Value> operandBuffer =
getBuffer(rewriter, reshapeOp.operand(), options);
FailureOr<Value> outputShapeBuffer =
getBuffer(rewriter, reshapeOp.output_shape(), options);
if (failed(operandBuffer) || failed(outputShapeBuffer)) return failure();
ShapedType resultType;
TensorType opResultType = reshapeOp.getType();
if (auto rankedType = opResultType.dyn_cast<RankedTensorType>()) {
resultType =
MemRefType::get(rankedType.getShape(), rankedType.getElementType());
} else if (auto unrankedType =
opResultType.dyn_cast<UnrankedTensorType>()) {
resultType = UnrankedMemRefType::get(unrankedType.getElementType(), 0);
}
auto operand = *operandBuffer;
// If the operand has a non-identity affine map, we will have to add a copy.
auto bufferType = operandBuffer->getType().dyn_cast<MemRefType>();
if (bufferType && !bufferType.getLayout().isIdentity()) {
// TODO(springerm): Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
FailureOr<Value> tensorAlloc =
bufferization::allocateTensorForShapedValue(
rewriter, op->getLoc(), *operandBuffer,
analysisState.isTensorYielded(reshapeOp.getResult()), options);
if (failed(tensorAlloc)) return failure();
auto memrefType =
MemRefType::get(bufferType.getShape(), bufferType.getElementType());
operand = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, *tensorAlloc);
}
bufferization::replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
rewriter, op, resultType, operand, *outputShapeBuffer);
return success();
}
};
// Inserts a memref.reinterpret_cast op that changes the layout of the operand
// buffer for broadcasting: wherever a size-1 dimension is expanded, the
// target dimension size is used together with a stride of 0.
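// For example (a sketch only, with abbreviated types): broadcasting a
// memref<?xf32> into a two-dimensional result yields roughly
//   %cast = memref.reinterpret_cast %operand to offset: [0],
//       sizes: [%size0, %size1], strides: [%stride0, %stride1]
//       : memref<?xf32> to memref<?x?xf32, #map>
// where a stride is 0 whenever the corresponding operand dimension is
// expanded from size 1.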
FailureOr<Value> insertDynamicMemrefCastOp(
mhlo::DynamicBroadcastInDimOp op, Value operand, RewriterBase &rewriter,
const BufferizationOptions &options) {
auto loc = op.getLoc();
auto operandType = operand.getType().cast<MemRefType>();
auto operandShape = operandType.getShape();
auto operandRank = operandType.getRank();
auto resultType = op.getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  // Compute strides as a reversed scan product over the operand shape,
  // working from the minor to the major dimension (e.g. shape [2, 3, 4]
  // yields strides [12, 4, 1]). Additionally, save the operand dimension
  // sizes as Values for use in the next loop.
SmallVector<Value, 2> operandStrides(operandRank, one);
SmallVector<Value, 2> operandSizes(operandRank, one);
Value strideSoFar = one;
for (int i = operandRank - 1; i >= 0; --i) {
Value operandDimSize =
ShapedType::isDynamic(operandShape[i])
? rewriter.create<memref::DimOp>(loc, operand, i).getResult()
: rewriter.create<arith::ConstantIndexOp>(loc, operandShape[i])
.getResult();
operandSizes[i] = operandDimSize;
operandStrides[i] = strideSoFar;
if (i > 0) {
strideSoFar =
rewriter.create<arith::MulIOp>(loc, strideSoFar, operandDimSize);
}
}
SmallVector<OpFoldResult, 2> sizes, strides;
sizes.reserve(resultRank);
strides.reserve(resultRank);
DenseMap<int, int> outputToInputDim;
for (const auto &dim : llvm::enumerate(op.broadcast_dimensions())) {
outputToInputDim[dim.value().getSExtValue()] = dim.index();
}
  FailureOr<Value> outputDimsBuffer =
      getBuffer(rewriter, op.output_dimensions(), options);
  if (failed(outputDimsBuffer)) return failure();
  for (int i = 0; i < resultRank; ++i) {
    Value iVal = rewriter.create<arith::ConstantIndexOp>(loc, i);
    Value resultDimSize =
        rewriter.create<memref::LoadOp>(loc, *outputDimsBuffer, iVal);
if (!resultDimSize.getType().isIndex()) {
resultDimSize = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), resultDimSize);
}
if (resultType.isDynamicDim(i)) {
sizes.push_back(resultDimSize);
} else {
sizes.push_back(rewriter.getIndexAttr(resultType.getDimSize(i)));
}
auto it = outputToInputDim.find(i);
    // If the rank of the output is greater than the rank of the input, i.e.
    // this output dimension has no entry in the inverse broadcast_dimensions
    // map, set the stride to 0 to emulate padding the shape with leading 1s
    // and the corresponding expansion.
if (it == outputToInputDim.end()) {
strides.push_back(zero);
continue;
}
    // There can be two cases:
    // 1) Operand dim == result dim => expansion is not needed
    //    => stride := flattened buffer stride
    // 2) Operand dim < result dim => expansion is needed => stride := 0.
int dim = it->second;
Value isExpansion = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, operandSizes[dim], resultDimSize);
Value select = rewriter.create<mlir::arith::SelectOp>(
loc, isExpansion, zero, operandStrides[dim]);
strides.push_back(select);
}
// Type-erased memref type with static rank and dynamic strides.
SmallVector<int64_t, 2> dynamicLayout(resultRank,
ShapedType::kDynamicStrideOrOffset);
auto typeErasedMemrefType = MemRefType::get(
resultType.getShape(), operandType.getElementType(),
makeStridedLinearLayoutMap(dynamicLayout,
/*offset=*/0, rewriter.getContext()));
auto transformedOperand = rewriter.create<memref::ReinterpretCastOp>(
loc, typeErasedMemrefType, operand,
/*offset=*/rewriter.getI64IntegerAttr(0), sizes, strides);
return transformedOperand.getResult();
}
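
// Bufferization model for mhlo.dynamic_broadcast_in_dim. The operand is read,
// and the result is produced by the memref.reinterpret_cast built above, so
// no buffer equivalence with the operand can be guaranteed
// (BufferRelation::None).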
struct DynamicBroadcastInDimOpInterface
: public BufferizableOpInterface::ExternalModel<
DynamicBroadcastInDimOpInterface, mhlo::DynamicBroadcastInDimOp> {
bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/,
const AnalysisState & /*state*/) const {
return true;
}
bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/,
const AnalysisState & /*state*/) const {
return false;
}
SmallVector<OpResult> getAliasingOpResult(
Operation *op, OpOperand & /*opOperand*/,
const AnalysisState & /*state*/) const {
return {op->getResult(0)};
}
BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/,
const AnalysisState & /*state*/) const {
// The op may allocate.
return BufferRelation::None;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto broadcastInDimOp = cast<mhlo::DynamicBroadcastInDimOp>(op);
auto resultType = broadcastInDimOp.getType().dyn_cast<RankedTensorType>();
if (!resultType) return success();
    // The buffer still has the old (pre-broadcast) type.
FailureOr<Value> operandBuffer =
getBuffer(rewriter, broadcastInDimOp.operand(), options);
if (failed(operandBuffer)) return failure();
FailureOr<Value> result = insertDynamicMemrefCastOp(
broadcastInDimOp, *operandBuffer, rewriter, options);
if (failed(result)) return failure();
bufferization::replaceOpWithBufferizedValues(rewriter, op, *result);
return success();
}
};
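
// Partial bufferization pass: only MHLO ops are rewritten to memref (and
// lmhlo) form; any remaining tensor ops are handled by later bufferization
// passes.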
struct HloLegalizeToMemrefPass
: public HloLegalizeToMemrefPassBase<HloLegalizeToMemrefPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
mhlo::MhloDialect, lmhlo::LmhloDialect>();
registerBufferizableOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
bufferization::BufferizationOptions options =
bufferization::getPartialBufferizationOptions();
options.opFilter.allowDialect<mhlo::MhloDialect>();
if (failed(bufferizeOp(getOperation(), options))) signalPassFailure();
}
};
}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToMemrefPass() {
return std::make_unique<HloLegalizeToMemrefPass>();
}
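
// Attaches the BufferizableOpInterface external models defined above to the
// corresponding MHLO ops.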
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, MhloDialect * /*dialect*/) {
CustomCallOp::attachInterface<CustomCallOpInterface>(*ctx);
ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
DynamicReshapeOp::attachInterface<DynamicReshapeOpInterface>(*ctx);
DynamicBroadcastInDimOp::attachInterface<DynamicBroadcastInDimOpInterface>(
*ctx);
});
}

}  // namespace mhlo
}  // namespace mlir