// blob: 54f611fee4266f2fe387c84f254213f7e1b0b78d [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for translating mixed IR to buffer form.
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
// Bufferizes `arith.constant` ops that produce a statically shaped tensor of
// rank 0 or 1.
//
// The constant is materialized as a buffer of the matching MemRef type, and
// every element is written with an explicit `memref.store`:
//   * rank 0: the buffer is created with `memref.alloc` and receives a single
//     store.
//   * rank 1: the buffer is created with `memref.alloca` and receives one store
//     per element. For splat constants the scalar is created once and reused.
//
// The pattern fails to match (leaving the op untouched) when the result is not
// a ranked tensor, has a dynamic shape, has rank > 1, or when the constant's
// value is not a `DenseElementsAttr`.
struct BufferizeConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::ConstantOp op, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter &rewriter) const final {
    // We only need to bufferize tensor constants. The null check has to come
    // before any query on `result_type`: `dyn_cast` yields a null type for
    // scalars and unranked tensors, and calling `getRank()` on it is invalid.
    auto result_type = op.getType().dyn_cast<RankedTensorType>();
    if (!result_type || !result_type.hasStaticShape()) return failure();
    int64_t result_rank = result_type.getRank();
    if (result_rank > 1) return failure();

    // Bail out before creating any IR if the value cannot be enumerated
    // element by element.
    auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>();
    if (!elements_attr) return failure();

    Location loc = op.getLoc();
    auto element_type = result_type.getElementType();
    auto memref_type = MemRefType::get(result_type.getShape(), element_type);

    // arith.constant doesn't handle scalar complex types.
    // TODO(kramerb): Should this use materializeConstant instead?
    auto make_constant = [&](Attribute attr, Type type) -> Value {
      if (complex::ConstantOp::isBuildableWith(attr, type))
        return rewriter.create<complex::ConstantOp>(loc, type,
                                                    attr.cast<ArrayAttr>());
      return rewriter.create<arith::ConstantOp>(loc, attr);
    };

    // Rank 0: a single element stored without indices.
    if (result_rank == 0) {
      Value buffer = rewriter.create<memref::AllocOp>(loc, memref_type);
      Value constant =
          make_constant(elements_attr.getValues<Attribute>()[0], element_type);
      rewriter.create<memref::StoreOp>(loc, constant, buffer);
      rewriter.replaceOp(op, {buffer});
      return success();
    }

    // Rank 1: store every element at its index.
    Value buffer = rewriter.create<memref::AllocaOp>(loc, memref_type);

    // For splats, create the scalar once and store the same SSA value at every
    // position instead of emitting one constant per element.
    bool all_same_elems = elements_attr.isSplat();
    Value value;
    if (all_same_elems)
      value = make_constant(elements_attr.getSplatValue<mlir::Attribute>(),
                            element_type);

    for (auto &en : llvm::enumerate(elements_attr.getValues<Attribute>())) {
      if (!all_same_elems) value = make_constant(en.value(), element_type);
      Value index = rewriter.create<arith::ConstantIndexOp>(loc, en.index());
      rewriter.create<memref::StoreOp>(loc, value, buffer, index);
    }

    rewriter.replaceOp(op, {buffer});
    return success();
  }
};
// Lowers `chlo.minimum_broadcast_shapes` to explicit `arith`/`scf`/`memref`
// IR operating on the already bufferized operands.
//
// Each operand is read as a 1-D memref of `index` extents (its length is
// queried with `memref.dim` at dimension 0 and its elements with
// `memref.load`). For every operand the pattern produces a new 1-D
// `memref<?xindex>` in which adjacent dimensions that can be combined have been
// merged into a single extent and leading 1's have been removed. If the emitted
// code detects at runtime that the shapes cannot be broadcast, the original
// operand buffers are selected as results instead.
struct BufferizeAndConvertMinimumBroadcastShapesOp
    : public OpConversionPattern<chlo::MinimumBroadcastShapesOp> {
  using OpConversionPattern<
      chlo::MinimumBroadcastShapesOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      chlo::MinimumBroadcastShapesOp broadcast_shapes_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = broadcast_shapes_op.getLoc();
    ImplicitLocOpBuilder lb(loc, rewriter);
    Value zero = lb.create<arith::ConstantIndexOp>(0);
    SmallVector<Value> shapes = adaptor.shapes();
    // Number of shape operands; also the number of results.
    size_t k = shapes.size();
    SmallVector<Value> ranks;
    ranks.reserve(k);

    // Determine the maximum rank of the operands.
    Value max_rank;
    for (size_t i = 0; i < k; ++i) {
      Value rank = lb.create<memref::DimOp>(loc, shapes[i], zero);
      ranks.push_back(rank);
      if (i) {
        Value rank_is_greater = lb.create<arith::CmpIOp>(
            arith::CmpIPredicate::ugt, ranks[i], max_rank);
        max_rank =
            lb.create<arith::SelectOp>(rank_is_greater, ranks[i], max_rank);
      } else {
        max_rank = ranks[0];
      }
    }

    // Allocate buffers for the return values and initialize them with 1's.
    SmallVector<Value> result_shapes;
    result_shapes.reserve(k);
    auto result_type =
        MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
    Value one = lb.create<arith::ConstantIndexOp>(1);
    for (size_t i = 0; i < k; ++i) {
      // We assume the buffer will be small, so we allocate it on the stack.
      // TODO(b/181654096): Replace AllocaOp with AllocOp.
      auto result = lb.create<memref::AllocaOp>(result_type, ranks[i]);
      lb.create<scf::ForOp>(zero, ranks[i], one, llvm::None,
                            [&one, &result](OpBuilder &b, Location l, Value idx,
                                            ValueRange /*vr*/) {
                              b.create<memref::StoreOp>(l, one, result, idx);
                              b.create<scf::YieldOp>(l, llvm::None);
                            });
      result_shapes.push_back(result);
    }

    // Iterate through the dimensions and determine which adjacent dimensions
    // can be combined. Keep a running product of the dimensions that can be
    // combined as iteration variable (initialized to 1), and the current
    // dimension offset in the result shapes. We iterate through the shapes
    // backward, because the broadcasting semantics mean that the last
    // dimensions of each shape (the least significant ones) are matched
    // together.
    Value two = lb.create<arith::ConstantIndexOp>(2);
    Value max_rank_plus_two = lb.create<arith::AddIOp>(loc, max_rank, two);
    Value constant_false =
        lb.create<arith::ConstantOp>(lb.getI1Type(), lb.getBoolAttr(false));

    // Loop-carried values, in order:
    //   [0, k)  per-shape "does not need broadcasting" flag,
    //   [k]     running product,
    //   [k + 1] current dimension offset,
    //   [k + 2] whether an invalid broadcast has been seen.
    SmallVector<Value> init_values;
    init_values.reserve(k + 3);
    // Initially, all values are marked as not broadcasted.
    for (size_t i = 0; i < k; ++i) {
      init_values.push_back(constant_false);
    }
    // The running product is initially 1.
    init_values.push_back(one);
    // The current dimension offset is initially 0.
    init_values.push_back(zero);
    // Whether the broadcasting is invalid.
    init_values.push_back(constant_false);

    // Iterate from 1 to max_rank + 1 (inclusive). This iteration variable is
    // used as an offset from the end of each shape vector. We iterate until
    // max_rank + 1 to handle the case that we have a running_product > 1 left
    // when we have processed all dimensions of the largest shape.
    auto main_loop = lb.create<scf::ForOp>(
        one, max_rank_plus_two, one, init_values,
        [&](OpBuilder &b, Location l, Value v, ValueRange vr) {
          // 'same_size' should track what the size of the dimension is to which
          // the 1-sized dimensions are broadcasted. If all of the dimensions
          // are 1, it will stay 1.
          Value same_size = one;
          // 'result_dimensions' stores the current dimension with an offset of
          // 'leading_ones' to make it easier to check whether we are in-bounds
          // with respect to the "real" shape with leading 1's removed.
          SmallVector<Value> result_dimensions;
          result_dimensions.reserve(k);
          // 'no_broadcasting' stores boolean flags that encode whether the
          // corresponding shape does not need broadcasting at the current
          // position. It is also reused as the list of yielded values, hence
          // the extra capacity for the three trailing loop-carried values.
          SmallVector<Value> no_broadcasting;
          no_broadcasting.reserve(k + 3);
          // The first k loop carried values are the previous broadcasting
          // state.
          auto prev_no_broadcasting = vr.take_front(k);

          // This loop checks which shapes need broadcasting at the current
          // dimension. A shape needs broadcasting if it is indexed out of
          // bounds, or its current dimension size is 1.
          Value current_dimension_has_invalid_broadcast = constant_false;
          for (size_t i = 0; i < k; ++i) {
            // Determine the size of the current dimension. If the dimension is
            // out of bounds, we choose the value 'one'.
            Value is_out_of_bounds = b.create<arith::CmpIOp>(
                l, arith::CmpIPredicate::ult, ranks[i], v);
            Value dimension = b.create<arith::SubIOp>(l, ranks[i], v);
            result_dimensions.push_back(dimension);
            Value current_size =
                b.create<scf::IfOp>(
                     l, TypeRange{b.getIndexType()}, is_out_of_bounds,
                     [&](OpBuilder &b, Location l) {
                       b.create<scf::YieldOp>(l, one);
                     },
                     [&](OpBuilder &b, Location l) {
                       // Using IfOp instead of SelectOp makes sure that we
                       // don't try to load if the dimension is out of bounds.
                       Value size =
                           b.create<memref::LoadOp>(l, shapes[i], dimension);
                       b.create<scf::YieldOp>(l, size);
                     })
                    .getResult(0);
            // Compute whether the current dimension does require broadcasting.
            Value current_size_is_not_one = b.create<arith::CmpIOp>(
                l, arith::CmpIPredicate::ne, current_size, one);
            no_broadcasting.push_back(current_size_is_not_one);
            Value new_same_size = b.create<arith::SelectOp>(
                l, current_size_is_not_one, current_size, same_size);
            Value same_size_was_not_one = b.create<arith::CmpIOp>(
                l, arith::CmpIPredicate::ne, same_size, one);
            Value is_different_size = b.create<arith::CmpIOp>(
                l, arith::CmpIPredicate::ne, same_size, new_same_size);
            // The broadcast is invalid if the size of the current dimension
            // is not equal to the expected size, unless the expected size was
            // still the initial value 1.
            Value is_invalid = b.create<arith::AndIOp>(l, same_size_was_not_one,
                                                       is_different_size);
            current_dimension_has_invalid_broadcast = b.create<arith::OrIOp>(
                l, current_dimension_has_invalid_broadcast, is_invalid);
            same_size = new_same_size;
          }

          // Check whether we have at least one shape that has a different
          // status regarding whether it needs broadcasting at the current
          // dimension versus whether it needs broadcasting at the previous
          // dimension.
          Value same_size_is_one = b.create<arith::CmpIOp>(
              l, arith::CmpIPredicate::eq, same_size, one);
          Value different_broadcasting_set = constant_false;
          for (size_t i = 0; i < k; ++i) {
            // If all dimensions are 1, we preserve the status whether a shape
            // needs broadcasting or not, because in that case the dimension can
            // just be ignored.
            no_broadcasting[i] = b.create<arith::SelectOp>(
                l, same_size_is_one, prev_no_broadcasting[i],
                no_broadcasting[i]);
            // Compare whether the current shape changes its status regarding
            // whether it needs broadcasting at the current dimension.
            Value broadcasting_is_different = b.create<arith::CmpIOp>(
                l, arith::CmpIPredicate::ne, prev_no_broadcasting[i],
                no_broadcasting[i]);
            different_broadcasting_set = b.create<arith::OrIOp>(
                l, different_broadcasting_set, broadcasting_is_different);
          }

          Value running_product = vr[k];
          Value current_dimension_offset = vr[k + 1];

          // We need to stop combining dimensions if the set of shapes which
          // need broadcasting at the current dimension changes compared to the
          // set of shapes needing broadcasting at the previous dimension.
          Value is_last_iteration = b.create<arith::CmpIOp>(
              l, arith::CmpIPredicate::sgt, v, max_rank);
          Value stop_combining_dimensions = b.create<arith::OrIOp>(
              l, is_last_iteration, different_broadcasting_set);
          // Yields (new running product, new dimension offset).
          auto if_stop_combining_dimensions = b.create<scf::IfOp>(
              l, TypeRange{b.getIndexType(), b.getIndexType()},
              stop_combining_dimensions,
              [&](OpBuilder &b, Location l) {
                // If the running product is not 1, add one dimension of size
                // 'running_product' to each shape that didn't need
                // broadcasting, otherwise add a 1 dimension if it was
                // previously indexed in-bounds.
                Value running_product_not_one = b.create<arith::CmpIOp>(
                    l, arith::CmpIPredicate::ne, running_product, one);
                Value new_dimension_offset =
                    b.create<scf::IfOp>(
                         l, TypeRange{b.getIndexType()},
                         running_product_not_one,
                         [&](OpBuilder &b, Location l) {
                           Value new_dimension_offset = b.create<arith::AddIOp>(
                               l, current_dimension_offset, one);
                           Value minus_one =
                               lb.create<arith::ConstantIndexOp>(-1);
                           for (size_t i = 0; i < k; ++i) {
                             // 'result_dimensions[i]' is 'ranks[i] - v', so
                             // comparing it against -1 tests whether the
                             // previous offset 'v - 1' was within the rank
                             // of shape i.
                             Value was_in_bounds = b.create<arith::CmpIOp>(
                                 l, arith::CmpIPredicate::sge,
                                 result_dimensions[i], minus_one);
                             Value should_store_dimension =
                                 b.create<arith::OrIOp>(
                                     l, was_in_bounds, prev_no_broadcasting[i]);
                             b.create<scf::IfOp>(
                                 l, should_store_dimension,
                                 [&](OpBuilder &b, Location l) {
                                   Value output_dimension =
                                       b.create<arith::SubIOp>(
                                           l, ranks[i], new_dimension_offset);
                                   // If the shape needed broadcasting at the
                                   // previous dimension, we set the output size
                                   // to 1, otherwise to 'running_product'.
                                   Value output_size =
                                       b.create<arith::SelectOp>(
                                           l, prev_no_broadcasting[i],
                                           running_product, one);
                                   b.create<memref::StoreOp>(l, output_size,
                                                             result_shapes[i],
                                                             output_dimension);
                                   b.create<scf::YieldOp>(l, llvm::None);
                                 });
                           }
                           b.create<scf::YieldOp>(l, new_dimension_offset);
                         },
                         [&](OpBuilder &b, Location l) {
                           b.create<scf::YieldOp>(l, current_dimension_offset);
                         })
                        .getResult(0);
                // A new group starts at the current dimension, so the running
                // product restarts from 'same_size'.
                b.create<scf::YieldOp>(
                    l, ValueRange{same_size, new_dimension_offset});
              },
              [&](OpBuilder &b, Location l) {
                // Keep combining: fold the current size into the product and
                // leave the offset unchanged.
                Value new_running_product =
                    b.create<arith::MulIOp>(l, running_product, same_size);
                b.create<scf::YieldOp>(l, ValueRange{new_running_product,
                                                     current_dimension_offset});
              });

          // Add the remaining results.
          no_broadcasting.push_back(if_stop_combining_dimensions.getResult(0));
          no_broadcasting.push_back(if_stop_combining_dimensions.getResult(1));
          Value is_invalid = vr.back();
          is_invalid = b.create<arith::OrIOp>(
              l, is_invalid, current_dimension_has_invalid_broadcast);
          no_broadcasting.push_back(is_invalid);
          b.create<scf::YieldOp>(l, no_broadcasting);
        });

    // The last loop result is the accumulated "invalid broadcast" flag. When it
    // is set, the unmodified operand is selected instead of the reduced shape.
    Value is_invalid = main_loop.getResults().back();
    for (size_t i = 0; i < k; ++i) {
      result_shapes[i] =
          RemoveLeadingOnesFrom1DMemref(lb, result_shapes[i], ranks[i]);
      result_shapes[i] =
          lb.create<arith::SelectOp>(is_invalid, shapes[i], result_shapes[i]);
    }
    rewriter.replaceOp(broadcast_shapes_op, result_shapes);
    return success();
  }

 private:
  // Emits a loop over the first `rank` entries of `extent_memref` and returns
  // the number of leading entries that are equal to 1.
  Value CountLeadingOnes(ImplicitLocOpBuilder &lb, Value extent_memref,
                         Value rank) const {
    // Count leading 1's. Use two iteration variables for that: one with a
    // boolean flag for whether every size so far was 1, one with the number of
    // leading 1's.
    Value constant_true =
        lb.create<arith::ConstantOp>(lb.getI1Type(), lb.getBoolAttr(true));
    Value zero = lb.create<arith::ConstantIndexOp>(0);
    Value one = lb.create<arith::ConstantIndexOp>(1);
    auto leading_ones_loop = lb.create<scf::ForOp>(
        zero, rank, one, ValueRange{constant_true, zero},
        [&](OpBuilder &b, Location l, Value idx, ValueRange vr) {
          auto size = b.create<memref::LoadOp>(l, extent_memref, idx);
          auto is_equal_to_one =
              b.create<arith::CmpIOp>(l, arith::CmpIPredicate::eq, size, one);
          auto all_ones =
              b.create<arith::AndIOp>(l, vr.front(), is_equal_to_one);
          auto increased_value = b.create<arith::AddIOp>(l, vr.back(), one);
          // The count only advances while every entry seen so far was 1.
          auto number_of_leading_ones = b.create<arith::SelectOp>(
              l, all_ones, increased_value, vr.back());
          b.create<scf::YieldOp>(l,
                                 ValueRange{all_ones, number_of_leading_ones});
        });
    return leading_ones_loop.getResults()[1];
  }

  // Emits IR that copies `extent_memref` (of length `rank`) without its
  // leading 1 entries into a newly allocated 1-D index memref and returns that
  // memref.
  Value RemoveLeadingOnesFrom1DMemref(ImplicitLocOpBuilder &lb,
                                      Value extent_memref, Value rank) const {
    Value leading_ones = CountLeadingOnes(lb, extent_memref, rank);
    Value new_rank = lb.create<arith::SubIOp>(rank, leading_ones);
    auto result_type =
        MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
    // We cannot use SubView here to return a MemRef with 'leading_ones' as
    // offset, because that also changes the size, so the result type would need
    // to have an affine map to change the layout. This is incompatible to our
    // other MemRef types without affine map. So instead we just allocate
    // another buffer of the desired size and copy the elements over. We assume
    // the buffer will be small, so we allocate it on the stack.
    // TODO(b/181654096): Replace AllocaOp with AllocOp.
    Value result = lb.create<memref::AllocaOp>(result_type, new_rank);
    Value zero = lb.create<arith::ConstantIndexOp>(0);
    Value one = lb.create<arith::ConstantIndexOp>(1);
    lb.create<scf::ForOp>(
        zero, new_rank, one, llvm::None,
        [&](OpBuilder &b, Location l, Value idx, ValueRange /*vr*/) {
          Value idx_with_offset = b.create<arith::AddIOp>(l, idx, leading_ones);
          auto size =
              b.create<memref::LoadOp>(l, extent_memref, idx_with_offset);
          b.create<memref::StoreOp>(l, size, result, idx);
          b.create<scf::YieldOp>(l, llvm::None);
        });
    return result;
  }
};
// Re-creates `tf_framework.jit_execute` with its result types mapped through
// the pattern's type converter. The converted operands supplied by the adaptor
// and the attributes of the original op are carried over unchanged. The pattern
// fails if any result type cannot be converted.
struct BufferizeJITExecuteOp
    : public OpConversionPattern<tf_framework::JITExecuteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      tf_framework::JITExecuteOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Translate the result types first; nothing is rewritten on failure.
    SmallVector<Type, 2> converted_result_types;
    LogicalResult conversion_status = getTypeConverter()->convertTypes(
        op.getResultTypes(), converted_result_types);
    if (failed(conversion_status)) return failure();

    rewriter.replaceOpWithNewOp<tf_framework::JITExecuteOp>(
        op, converted_result_types, adaptor.getOperands(), op->getAttrs());
    return success();
  }
};
} // namespace
// Adds the bufferization patterns defined in this file to `patterns`. Every
// pattern is constructed with `*converter` as its type converter and `context`
// as its MLIR context.
void populateExtraBufferizePatterns(
    MLIRContext *context, bufferization::BufferizeTypeConverter *converter,
    RewritePatternSet *patterns) {
  bufferization::BufferizeTypeConverter &type_converter = *converter;
  // clang-format off
  patterns->add<
      BufferizeAndConvertMinimumBroadcastShapesOp,
      BufferizeConstantOp,
      BufferizeJITExecuteOp
  >(type_converter, context);
  // clang-format on
}
} // namespace transforms
} // namespace kernel_gen
} // namespace mlir