blob: 886e416db7dd322d3def8f7d4e6a7142d363b435 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for translating mixed IR to buffer form.
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Transforms/rewriters.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace {
struct BufferizeConstantOp : public OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
arith::ConstantOp op, OpAdaptor /*adaptor*/,
ConversionPatternRewriter &rewriter) const final {
// We only need to bufferize tensor constants.
Location loc = op.getLoc();
auto resultType = op.getType().dyn_cast<RankedTensorType>();
int64_t resultRank = resultType.getRank();
if (!resultType || !resultType.hasStaticShape() || resultRank > 1)
return failure();
auto elementType = resultType.getElementType();
auto memrefType = MemRefType::get(resultType.getShape(), elementType);
auto elementsAttr = op.getValue().cast<DenseElementsAttr>();
// arith.constant doesn't handle scalar complex types.
// TODO(kramerb): Should this use materializeConstant instead?
auto makeConstant = [&](Attribute attr, Type type) -> Value {
if (complex::ConstantOp::isBuildableWith(attr, type))
return rewriter.create<complex::ConstantOp>(loc, type,
attr.cast<ArrayAttr>());
return rewriter.create<arith::ConstantOp>(loc, attr);
};
if (resultRank == 0) {
Value buffer = rewriter.create<memref::AllocOp>(loc, memrefType);
Value constant =
makeConstant(elementsAttr.getValues<Attribute>()[0], elementType);
rewriter.create<memref::StoreOp>(loc, constant, buffer);
rewriter.replaceOp(op, {buffer});
return success();
}
Value buffer = rewriter.create<memref::AllocaOp>(loc, memrefType);
bool allSameElems = elementsAttr.isSplat();
Value value;
if (allSameElems)
value = makeConstant(elementsAttr.getSplatValue<mlir::Attribute>(),
elementType);
for (auto &en : llvm::enumerate(elementsAttr.getValues<Attribute>())) {
if (!allSameElems) value = makeConstant(en.value(), elementType);
Value index = rewriter.create<arith::ConstantIndexOp>(loc, en.index());
rewriter.create<memref::StoreOp>(loc, value, buffer, index);
}
rewriter.replaceOp(op, {buffer});
return success();
}
};
struct BufferizeAndConvertMinimumBroadcastShapesOp
: public OpConversionPattern<chlo::MinimumBroadcastShapesOp> {
using OpConversionPattern<
chlo::MinimumBroadcastShapesOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
chlo::MinimumBroadcastShapesOp broadcastShapesOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = broadcastShapesOp.getLoc();
ImplicitLocOpBuilder lb(loc, rewriter);
Value zero = lb.create<arith::ConstantIndexOp>(0);
SmallVector<Value> shapes = adaptor.shapes();
size_t k = shapes.size();
SmallVector<Value> ranks;
ranks.reserve(k);
// Determine the maximum rank of the operands.
Value max_rank;
for (size_t i = 0; i < k; ++i) {
Value rank = lb.create<memref::DimOp>(loc, shapes[i], zero);
ranks.push_back(rank);
if (i) {
Value rankIsGreater = lb.create<arith::CmpIOp>(
arith::CmpIPredicate::ugt, ranks[i], max_rank);
max_rank =
lb.create<arith::SelectOp>(rankIsGreater, ranks[i], max_rank);
} else {
max_rank = ranks[0];
}
}
// Allocate buffers for the return values and initialize them with 1's.
SmallVector<Value> resultShapes;
resultShapes.reserve(k);
auto resultType =
MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
Value one = lb.create<arith::ConstantIndexOp>(1);
for (size_t i = 0; i < k; ++i) {
// We assume the buffer will be small, so we allocate it on the stack.
// TODO(b/181654096): Replace AllocaOp with AllocOp.
auto result = lb.create<memref::AllocaOp>(resultType, ranks[i]);
lb.create<scf::ForOp>(zero, ranks[i], one, llvm::None,
[&one, &result](OpBuilder &b, Location l, Value idx,
ValueRange /*vr*/) {
b.create<memref::StoreOp>(l, one, result, idx);
b.create<scf::YieldOp>(l, llvm::None);
});
resultShapes.push_back(result);
}
// Iterate through the dimensions and determine which adjacent dimensions
// can be combined. Keep a running product of the dimensions that can be
// combined as iteration variable (initialized to 1), and the current
// dimension offset in the result shapes. We iterate through the shapes
// backward, because the broadcasting semantics mean that the last
// dimensions of each shape (the least significant ones) are matched
// together.
Value two = lb.create<arith::ConstantIndexOp>(2);
Value maxRankPlusTwo = lb.create<arith::AddIOp>(loc, max_rank, two);
Value constantFalse =
lb.create<arith::ConstantOp>(lb.getI1Type(), lb.getBoolAttr(false));
SmallVector<Value> initValues;
initValues.reserve(k + 3);
// Initially, all values are marked as not broadcasted.
for (int i = 0; i < k; ++i) {
initValues.push_back(constantFalse);
}
// The running product is initially 1.
initValues.push_back(one);
// The current dimension offset is initially 0.
initValues.push_back(zero);
// Whether the broadcasting is invalid.
initValues.push_back(constantFalse);
// Iterate from 1 to max_rank + 1 (inclusive). This iteration variable is
// used as an offset from the end of each shape vector. We iterate until
// max_rank + 1 to handle the case that we have a running_product > 1 left
// when we have processed all dimensions of the largest shape.
auto mainLoop = lb.create<scf::ForOp>(
one, maxRankPlusTwo, one, initValues,
[&](OpBuilder &b, Location l, Value v, ValueRange vr) {
// 'same_size' should track what the size of the dimension is to which
// the 1-sized dimensions are broadcasted. If all of the dimensions
// are 1, it will stay 1.
Value same_size = one;
// 'result_dimensions' stores the current dimension with an offset of
// 'leading_ones' to make it easier to check whether we are in-bounds
// with respect to the "real" shape with leading 1's removed.
SmallVector<Value> result_dimensions;
result_dimensions.reserve(k);
// 'no_broadcasting' stores boolean flags that encode whether the
// corresponding shape does not need broadcasting at the current
// position.
SmallVector<Value> noBroadcasting;
noBroadcasting.reserve(k + 3);
// The first k loop carried values are the previous broadcasting
// state.
auto prevNoBroadcasting = vr.take_front(k);
// This loop checks which shapes need broadcasting at the current
// dimension. A shape needs broadcasting if it is indexed out of
// bounds, or its current dimension size is 1.
Value currentDimensionHasInvalidBroadcast = constantFalse;
for (size_t i = 0; i < k; ++i) {
// Determine the size of the current dimension. If the dimension is
// out of bounds, we choose the value 'one'.
Value isOutOfBounds = b.create<arith::CmpIOp>(
l, arith::CmpIPredicate::ult, ranks[i], v);
Value dimension = b.create<arith::SubIOp>(l, ranks[i], v);
result_dimensions.push_back(dimension);
Value currentSize =
b.create<scf::IfOp>(
l, TypeRange{b.getIndexType()}, isOutOfBounds,
[&](OpBuilder &b, Location l) {
b.create<scf::YieldOp>(l, one);
},
[&](OpBuilder &b, Location l) {
// Using IfOp instead of SelectOp makes sure that we
// don't try to load if the dimension is out of bounds.
Value size =
b.create<memref::LoadOp>(l, shapes[i], dimension);
b.create<scf::YieldOp>(l, size);
})
.getResult(0);
// Compute whether the current dimension does require broadcasting.
Value currentSizeIsNotOne = b.create<arith::CmpIOp>(
l, arith::CmpIPredicate::ne, currentSize, one);
noBroadcasting.push_back(currentSizeIsNotOne);
Value newSameSize = b.create<arith::SelectOp>(
l, currentSizeIsNotOne, currentSize, same_size);
Value sameSizeWasNotOne = b.create<arith::CmpIOp>(
l, arith::CmpIPredicate::ne, same_size, one);
Value isDifferentSize = b.create<arith::CmpIOp>(
l, arith::CmpIPredicate::ne, same_size, newSameSize);
// The broadcast is invalid if the size of the current dimension
// is not equal to the expected size, unless the expected size was
// still the initial value 1.
Value isInvalid =
b.create<arith::AndIOp>(l, sameSizeWasNotOne, isDifferentSize);
currentDimensionHasInvalidBroadcast = b.create<arith::OrIOp>(
l, currentDimensionHasInvalidBroadcast, isInvalid);
same_size = newSameSize;
}
// Check whether we have at least one shape that has a different
// status regarding whether it needs broadcasting at the current
// dimension versus whether it needs broadcasting at the previous
// dimension.
Value sameSizeIsOne = b.create<arith::CmpIOp>(
l, arith::CmpIPredicate::eq, same_size, one);
Value differentBroadcastingSet = constantFalse;
for (size_t i = 0; i < k; ++i) {
// If all dimensions are 1, we preserve the status whether a shape
// needs broadcasting or not, because in that case the dimension can
// just be ignored.
noBroadcasting[i] = b.create<arith::SelectOp>(
l, sameSizeIsOne, prevNoBroadcasting[i], noBroadcasting[i]);
// Compare whether the current shape changes its status regarding
// whether it needs broadcasting at the current dimension.
Value broadcastingIsDifferent = b.create<arith::CmpIOp>(
l, arith::CmpIPredicate::ne, prevNoBroadcasting[i],
noBroadcasting[i]);
differentBroadcastingSet = b.create<arith::OrIOp>(
l, differentBroadcastingSet, broadcastingIsDifferent);
}
Value running_product = vr[k];
Value current_dimension_offset = vr[k + 1];
// We need to stop combining dimensions if the set of shapes which
// need broadcasting at the current dimension changes compared to the
// set of shapes needing broadcasting at the previous dimension.
Value isLastIteration = b.create<arith::CmpIOp>(
l, arith::CmpIPredicate::sgt, v, max_rank);
Value stopCombiningDimensions = b.create<arith::OrIOp>(
l, isLastIteration, differentBroadcastingSet);
auto ifStopCombiningDimensions = b.create<scf::IfOp>(
l, TypeRange{b.getIndexType(), b.getIndexType()},
stopCombiningDimensions,
[&](OpBuilder &b, Location l) {
// If the running product is not 1, add one dimension of size
// 'running_product' to each shape that didn't need
// broadcasting, otherwise add a 1 dimension if it was
// previously indexed in-bounds.
Value runningProductNotOne = b.create<arith::CmpIOp>(
l, arith::CmpIPredicate::ne, running_product, one);
Value newDimensionOffset =
b.create<scf::IfOp>(
l, TypeRange{b.getIndexType()}, runningProductNotOne,
[&](OpBuilder &b, Location l) {
Value new_dimension_offset = b.create<arith::AddIOp>(
l, current_dimension_offset, one);
Value minusOne =
lb.create<arith::ConstantIndexOp>(-1);
for (size_t i = 0; i < k; ++i) {
Value wasInBounds = b.create<arith::CmpIOp>(
l, arith::CmpIPredicate::sge,
result_dimensions[i], minusOne);
Value shouldStoreDimension =
b.create<arith::OrIOp>(l, wasInBounds,
prevNoBroadcasting[i]);
b.create<scf::IfOp>(
l, shouldStoreDimension,
[&](OpBuilder &b, Location l) {
Value outputDimension =
b.create<arith::SubIOp>(
l, ranks[i], new_dimension_offset);
// If the shape needed broadcasting at the
// previous dimension, we set the output size
// to 1, otherwise to 'running_product'.
Value outputSize = b.create<arith::SelectOp>(
l, prevNoBroadcasting[i],
running_product, one);
b.create<memref::StoreOp>(l, outputSize,
resultShapes[i],
outputDimension);
b.create<scf::YieldOp>(l, llvm::None);
});
}
b.create<scf::YieldOp>(l, new_dimension_offset);
},
[&](OpBuilder &b, Location l) {
b.create<scf::YieldOp>(l, current_dimension_offset);
})
.getResult(0);
b.create<scf::YieldOp>(
l, ValueRange{same_size, newDimensionOffset});
},
[&](OpBuilder &b, Location l) {
Value newRunningProduct =
b.create<arith::MulIOp>(l, running_product, same_size);
b.create<scf::YieldOp>(
l, ValueRange{newRunningProduct, current_dimension_offset});
});
// Add the remaining results.
noBroadcasting.push_back(ifStopCombiningDimensions.getResult(0));
noBroadcasting.push_back(ifStopCombiningDimensions.getResult(1));
Value isInvalid = vr.back();
isInvalid = b.create<arith::OrIOp>(
l, isInvalid, currentDimensionHasInvalidBroadcast);
noBroadcasting.push_back(isInvalid);
b.create<scf::YieldOp>(l, noBroadcasting);
});
Value isInvalid = mainLoop.getResults().back();
for (size_t i = 0; i < k; ++i) {
resultShapes[i] =
removeLeadingOnesFrom1DMemref(lb, resultShapes[i], ranks[i]);
resultShapes[i] =
lb.create<arith::SelectOp>(isInvalid, shapes[i], resultShapes[i]);
}
rewriter.replaceOp(broadcastShapesOp, resultShapes);
return success();
}
private:
Value countLeadingOnes(ImplicitLocOpBuilder &lb, Value extentMemref,
Value rank) const {
// Count leading 1's. Use two iteration variables for that: one with a
// boolean flag for whether every size so far was 1, one with the number of
// leading 1's.
Value constantTrue =
lb.create<arith::ConstantOp>(lb.getI1Type(), lb.getBoolAttr(true));
Value zero = lb.create<arith::ConstantIndexOp>(0);
Value one = lb.create<arith::ConstantIndexOp>(1);
auto leadingOnesLoop = lb.create<scf::ForOp>(
zero, rank, one, ValueRange{constantTrue, zero},
[&](OpBuilder &b, Location l, Value idx, ValueRange vr) {
auto size = b.create<memref::LoadOp>(l, extentMemref, idx);
auto isEqualToOne =
b.create<arith::CmpIOp>(l, arith::CmpIPredicate::eq, size, one);
auto allOnes = b.create<arith::AndIOp>(l, vr.front(), isEqualToOne);
auto increasedValue = b.create<arith::AddIOp>(l, vr.back(), one);
auto numberOfLeadingOnes =
b.create<arith::SelectOp>(l, allOnes, increasedValue, vr.back());
b.create<scf::YieldOp>(l, ValueRange{allOnes, numberOfLeadingOnes});
});
return leadingOnesLoop.getResults()[1];
}
Value removeLeadingOnesFrom1DMemref(ImplicitLocOpBuilder &lb,
Value extentMemref, Value rank) const {
Value leadingOnes = countLeadingOnes(lb, extentMemref, rank);
Value newRank = lb.create<arith::SubIOp>(rank, leadingOnes);
auto resultType =
MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
// We cannot use SubView here to return a MemRef with 'leading_ones' as
// offset, because that also changes the size, so the result type would need
// to have an affine map to change the layout. This is incompatible to our
// other MemRef types without affine map. So instead we just allocate
// another buffer of the desired size and copy the elements over. We assume
// the buffer will be small, so we allocate it on the stack.
// TODO(b/181654096): Replace AllocaOp with AllocOp.
Value result = lb.create<memref::AllocaOp>(resultType, newRank);
Value zero = lb.create<arith::ConstantIndexOp>(0);
Value one = lb.create<arith::ConstantIndexOp>(1);
lb.create<scf::ForOp>(
zero, newRank, one, llvm::None,
[&](OpBuilder &b, Location l, Value idx, ValueRange /*vr*/) {
Value idxWithOffset = b.create<arith::AddIOp>(l, idx, leadingOnes);
auto size = b.create<memref::LoadOp>(l, extentMemref, idxWithOffset);
b.create<memref::StoreOp>(l, size, result, idx);
b.create<scf::YieldOp>(l, llvm::None);
});
return result;
}
};
} // namespace
void populateExtraBufferizePatterns(
MLIRContext *context, bufferization::BufferizeTypeConverter *converter,
RewritePatternSet *patterns) {
// clang-format off
patterns->add<
BufferizeAndConvertMinimumBroadcastShapesOp,
BufferizeConstantOp
>(*converter, context);
// clang-format on
}
} // namespace mlir