/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass prepares for legalization to the TFLite dialect by
// converting TensorList operations in the TensorFlow dialect into operations
// that can be legalized to the TensorFlow Lite dialect with simple
// replacements. The newly created operations are in the TensorFlow dialect if
// the operation can be represented using a TensorFlow op. Otherwise, a
// TensorFlow Lite dialect op is used.
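//
// For example, tf.TensorListReserve is lowered to a tf.Fill of zeros, and
// tf.TensorListSetItem is lowered to a combination of tf.Slice, tf.ExpandDims
// and tf.Concat (see the patterns below).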
#include <climits>
#include <cstdint>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#define DEBUG_TYPE "tf-tfl-legalization"
//===----------------------------------------------------------------------===//
// The actual LowerStaticTensorList Pass.
//===----------------------------------------------------------------------===//
namespace mlir {
namespace {
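// A PatternRewriter bound to the context of the function being rewritten.
// `createOperation` simply forwards to OpBuilder, since the patterns below are
// driven manually from RewriteFunction rather than by a rewrite driver.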
class TensorListPatternRewriter : public PatternRewriter {
public:
explicit TensorListPatternRewriter(FuncOp fn)
: PatternRewriter(fn.getContext()) {}
Operation *createOperation(const OperationState &state) override {
return OpBuilder::createOperation(state);
}
};
/// Lower TensorList ops in functions for subsequent legalization.
// TODO(haoliang): Use DialectConversion infra to simplify the rewriting
// process.
struct LowerStaticTensorListPass
: public ModulePass<LowerStaticTensorListPass> {
void runOnModule() override;
// Apply type and op changes within a function.
LogicalResult RewriteFunction(FuncOp func,
TensorListPatternRewriter *rewriter);
// Changes the function type of `cond_func` and `body_func`, and the result
// type of the `WhileOp`.
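// For example, if an operand of the while op has type tensor<5x2xf32> after
// tensor list lowering, the corresponding function argument (and any
// DT_VARIANT result) is rewritten to tensor<*xf32>.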
LogicalResult UpdateWhileFunctionType(TF::WhileOp *while_op);
};
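// Creates a splat i32 constant of the given static `shape` with all elements
// set to `val`, e.g. shape {} yields a scalar and shape {1} a 1-element
// vector.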
Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
ArrayRef<int64_t> shape, int32_t val) {
auto type = rewriter->getTensorType(shape, rewriter->getIntegerType(32));
auto attr = DenseElementsAttr::get(type, rewriter->getI32IntegerAttr(val));
return rewriter->create<ConstantOp>(op->getLoc(), type, attr);
}
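// Creates a 1-D i32 tensor whose length is given at runtime by `shape_tensor`
// and whose elements are all `val`, by filling a dynamically shaped tensor
// with a scalar splat via tf.Fill.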
Value *CreateI32SplatTensor(Operation *op, PatternRewriter *rewriter,
Value *shape_tensor, int32_t val) {
auto scalar_val = CreateI32SplatConst(op, rewriter, {}, val);
return rewriter->create<TF::FillOp>(
op->getLoc(), rewriter->getTensorType({-1}, rewriter->getIntegerType(32)),
shape_tensor, scalar_val);
}
struct ConvertTFTensorListSetItem : public RewritePattern {
explicit ConvertTFTensorListSetItem(MLIRContext *context)
: RewritePattern(TF::TensorListSetItemOp::getOperationName(), 1,
context) {}
// This function rewrites the original op into a series of slice and concat
// ops that produce the same result. It first slices the first `$index` rows,
// then expands the dimension of `$item`, followed by another slice of the
// remaining rows starting from `$index` + 1. Lastly, it concatenates the
// three parts together.
// At a high level, it does something like:
//   def : Pat<(TF_TensorListSetItemOp $input, $index, $item),
//             (Concat concat_dim = 0,
//                     (Slice $input, [0, 0, ...],
//                            (Concat (ExpandDims $index, expand_dim = 0),
//                                    [-1, -1, ...])),
//                     (ExpandDims $item, expand_dim = 0),
//                     (Slice $input, [$index + 1, 0, 0, ...],
//                            [-1, -1, ...]))>;
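//
// For instance (an illustrative sketch; actual IR types and syntax will
// differ), setting item [9, 9] at index 1 of a [3, 2] list yields:
//   slice1 = Slice(input, begin=[0, 0], size=[1, -1])   // rows before index
//   item'  = ExpandDims(item, 0)                        // shape [1, 2]
//   slice2 = Slice(input, begin=[2, 0], size=[-1, -1])  // rows after index
//   result = Concat(axis=0, slice1, item', slice2)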
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TF::TensorListSetItemOp tf_op = cast<TF::TensorListSetItemOp>(op);
auto input = tf_op.input_handle();
auto shape_dtype = rewriter.getIntegerType(32);
auto input_rank = rewriter.create<TF::RankOp>(
op->getLoc(), rewriter.getTensorType({}, shape_dtype), input);
auto item = tf_op.item();
auto item_rank = rewriter.create<TF::RankOp>(
op->getLoc(), rewriter.getTensorType({}, shape_dtype), item);
// Prepare the start position for the first slice op, which is [0, 0, ..,
// 0].
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
auto position_shape = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), input_rank,
scalar_zero);
// Fill the first start position tensor with all 0s.
auto first_start_position =
CreateI32SplatTensor(op, &rewriter, position_shape, 0);
// Prepare the start position for the second slice op, which is
// [index + 1, 0, 0 .. 0].
// Calculate the first dimension, which is index + 1.
auto index = tf_op.index();
auto vector_type = rewriter.getTensorType({1}, shape_dtype);
auto begin = rewriter.create<TF::AddOp>(
op->getLoc(), rewriter.getTensorType(shape_dtype), index,
CreateI32SplatConst(op, &rewriter, {1}, 1));
// The first dimension `begin` is followed by `item_rank` zeros.
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), item_rank,
scalar_zero);
auto partial_second_start_position =
CreateI32SplatTensor(op, &rewriter, item_position_shape, 0);
auto position_type = first_start_position->getType();
// Concatenate `begin` with the remaining 0s.
auto second_start_position = rewriter.create<TF::ConcatOp>(
op->getLoc(), position_type, scalar_zero,
ArrayRef<Value *>({begin, partial_second_start_position}),
rewriter.getI64IntegerAttr(2));
// Create the size parameter for the first slice op, which is [index, -1,
// -1, .., -1].
auto size1_leading_dim = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), vector_type, index, scalar_zero);
auto partial_size1 =
CreateI32SplatTensor(op, &rewriter, item_position_shape, -1);
auto size1 = rewriter.create<TF::ConcatOp>(
op->getLoc(), position_type, scalar_zero,
ArrayRef<Value *>({size1_leading_dim, partial_size1}),
rewriter.getI64IntegerAttr(2));
// Create the size parameter for the second slice, which is [-1, -1, ..,
// -1].
auto size2 = CreateI32SplatTensor(op, &rewriter, position_shape, -1);
// Create two slice ops.
auto element_type = input->getType().cast<TensorType>().getElementType();
auto unranked_tensor = rewriter.getTensorType(element_type);
auto slice1 = rewriter.create<TF::SliceOp>(
op->getLoc(), unranked_tensor, input, first_start_position, size1);
auto slice2 = rewriter.create<TF::SliceOp>(
op->getLoc(), unranked_tensor, input, second_start_position, size2);
// Expand the dimension of `item` so that it has the same rank as the
// input.
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), unranked_tensor, item, scalar_zero);
// Concatenate three parts together to generate the final result.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
op, input->getType(), scalar_zero,
ArrayRef<Value *>({slice1, expanded_item, slice2}),
rewriter.getI64IntegerAttr(3));
return matchSuccess();
}
};
struct ConvertTFTensorListReserve : public RewritePattern {
explicit ConvertTFTensorListReserve(MLIRContext *context)
: RewritePattern(TF::TensorListReserveOp::getOperationName(), 1,
context) {}
// Rewrites the original op into `tf.fill`. The result tensor shape is
// [num_elements, element_shape]. All values in the result tensor are
// initialized to 0.
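//
// For example (a sketch; exact printed IR will differ), reserving 5 elements
// of shape [2, 3] becomes roughly:
//   list_shape = Concat(0, ExpandDims(num_elements, 0), element_shape)
//              = [5, 2, 3]
//   result     = Fill(list_shape, 0)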
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TF::TensorListReserveOp tf_op = cast<TF::TensorListReserveOp>(op);
auto element_shape = tf_op.element_shape();
auto shape_dtype = getElementTypeOrSelf(element_shape->getType());
auto num_elements = tf_op.num_elements();
Type element_dtype = tf_op.element_dtype();
int64_t result_rank = -1; // -1 means unknown result rank.
Type result_type = rewriter.getTensorType(element_dtype);
if (auto element_type = tf_op.element_type().dyn_cast<RankedTensorType>()) {
result_rank = element_type.getRank() + 1;
// If element type is ranked, then result type will have unknown leading
// dimension and element shape for the following dimensions.
//
// Note: leading dim is not inferred here even if num_elements input is a
// constant.
SmallVector<int64_t, 4> result_shape = {-1};
ArrayRef<int64_t> shape = element_type.getShape();
result_shape.append(shape.begin(), shape.end());
result_type = rewriter.getTensorType(result_shape, element_dtype);
}
// The output shape of the result tensor is `num_elements` prepended to
// `element_shape`.
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
auto leading_dim = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements,
scalar_zero);
// Create a 1-D RankedTensorType for the result's shape. The number of
// elements in it equals the rank of the result, if known. Otherwise, the
// number of elements is unknown and represented with -1. In both cases, we
// can specify the dimension using the rank of the result.
Type shape_type = rewriter.getTensorType({result_rank}, shape_dtype);
auto list_shape = rewriter.create<TF::ConcatOp>(
op->getLoc(), shape_type, scalar_zero,
ArrayRef<Value *>({leading_dim, element_shape}),
rewriter.getI64IntegerAttr(2));
// Create a zero-valued scalar constant tensor whose element type is
// `element_dtype`.
auto zero_type = rewriter.getTensorType({}, element_dtype);
auto zero_attr = rewriter.getZeroAttr(zero_type);
auto zero = rewriter.create<ConstantOp>(op->getLoc(), zero_type, zero_attr);
rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
return matchSuccess();
}
};
} // namespace
namespace TFL {
namespace {
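// Pulls in the DRR-generated patterns (ConvertTFTensorListFromTensor,
// ConvertTFTensorListGetItem and ConvertTFTensorListStack) that are applied
// manually in RewriteFunction below.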
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
} // namespace
} // namespace TFL
LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
TF::WhileOp *while_op) {
SmallVector<Type, 8> unranked_argument_types;
for (const auto &operand : while_op->getOperands()) {
unranked_argument_types.push_back(
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
}
auto *context = &getContext();
auto module = getModule();
FuncOp cond_func = module.lookupSymbol<FuncOp>(while_op->cond());
FuncOp body_func = module.lookupSymbol<FuncOp>(while_op->body());
if (cond_func) {
// Change `cond_func`'s argument types to `unranked_argument_types`.
cond_func.setType(FunctionType::get(
unranked_argument_types, cond_func.getType().getResults(), context));
// Change the argument types of the first block.
Block &cond_first_bb = cond_func.front();
for (int i = 0; i < cond_first_bb.getNumArguments(); ++i) {
cond_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
}
}
if (body_func) {
SmallVector<Type, 8> updated_result_types;
for (int i = 0; i < body_func.getType().getNumResults(); ++i) {
auto result_type = body_func.getType().getResult(i);
if (getElementTypeOrSelf(result_type).isa<TF::VariantType>()) {
// For variant type, use the corresponding unranked type.
updated_result_types.push_back(unranked_argument_types[i]);
} else {
updated_result_types.push_back(result_type);
}
}
// Change `body_func`'s argument types to `unranked_argument_types`. If its
// return types contain a `DT_VARIANT`, change them to the unranked types
// derived from the corresponding arguments.
body_func.setType(FunctionType::get(unranked_argument_types,
updated_result_types, context));
// Change the argument types of the first block.
Block &body_first_bb = body_func.front();
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
body_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
}
}
for (int i = 0; i < while_op->getNumOperands(); ++i) {
auto operand = while_op->getOperand(i);
auto result = while_op->getResult(i);
if (getElementTypeOrSelf(result->getType()).isa<TF::VariantType>()) {
// If the result type is a DT_VARIANT, change it to the unranked tensor
// type derived from the corresponding operand.
result->setType(
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
}
}
return success();
}
LogicalResult LowerStaticTensorListPass::RewriteFunction(
FuncOp func, TensorListPatternRewriter *rewriter) {
auto *context = &getContext();
for (Block &block : func) {
// Buffer the op pointers inside the current block into a vector, since the
// block iterator might be invalidated if we rewrite ops while iterating.
std::vector<Operation *> ops_in_block;
for (Operation &op : block) {
ops_in_block.push_back(&op);
}
for (Operation *op : ops_in_block) {
if (auto tf_op = llvm::dyn_cast<TF::TensorListFromTensorOp>(op)) {
auto c = TFL::ConvertTFTensorListFromTensor(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() ||
tf_op.element_dtype().isF64() ||
tf_op.element_dtype().isInteger(8) ||
tf_op.element_dtype().isInteger(16) ||
tf_op.element_dtype().isInteger(32) ||
tf_op.element_dtype().isInteger(64))) {
return tf_op.emitError(
"requires element_dtype to be 8-bit/16-bit/32-bit/64-bit integer "
"or 16-bit/32-bit/64-bit "
"float type during TF Lite transformation pass");
}
auto c = ConvertTFTensorListReserve(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
auto c = TFL::ConvertTFTensorListGetItem(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
auto c = ConvertTFTensorListSetItem(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListStackOp>(op)) {
auto c = TFL::ConvertTFTensorListStack(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
UpdateWhileFunctionType(&tf_op);
} else if (auto tf_op = llvm::dyn_cast<TF::IdentityOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
tf_op.getResult()->setType(tf_op.getOperand()->getType());
}
}
}
return success();
}
void LowerStaticTensorListPass::runOnModule() {
// TODO(haoliang): currently we process the `main` function first, and the
// remaining functions may be processed in arbitrary order. This is
// potentially a problem when a function that takes a `DT_VARIANT` is
// processed before the function that produces it. We need to carefully
// order the functions to be processed.
std::vector<FuncOp> funcs_in_module;
for (auto func : getModule().getOps<FuncOp>()) {
// Always place the main function first in the list.
if (func.getName() == "main") {
funcs_in_module.insert(funcs_in_module.begin(), func);
} else {
funcs_in_module.push_back(func);
}
}
for (auto func : funcs_in_module) {
TensorListPatternRewriter rewriter(func);
if (failed(RewriteFunction(func, &rewriter))) {
signalPassFailure();
return;
}
}
}
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
/// pass.
ModulePassBase *TFL::CreateLowerStaticTensorListPass() {
return new LowerStaticTensorListPass();
}
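// Registers the pass so it can be exercised standalone through an opt-style
// driver, e.g. (assuming the tf-opt tool built from this repository):
//   tf-opt -tfl-lower-static-tensor-list input.mlir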
static PassRegistration<LowerStaticTensorListPass> pass(
"tfl-lower-static-tensor-list",
"Lower TensorList ops within TensorFlow Lite dialect");
} // namespace mlir