blob: 710c86f18ae6c1ee9bb4cc7a4ae5888ec1288bda [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass prepares for legalization to the TFLite dialect by
// converting Tensorlist operations in TensorFlow dialect into operations that
// can be legalized to TensorFlow Lite dialect with simple replacements. The
// newly created operations are in the TensorFlow dialect if the operation can
// be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op
// is used.
#include <climits>
#include <cstdint>
#include <utility>
#include "absl/container/inlined_vector.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/TypeRange.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/tensor_list.h"
#define DEBUG_TYPE "tf-tfl-legalization"
//===----------------------------------------------------------------------===//
// The actual LowerStaticTensorList Pass.
//
namespace mlir {
namespace {
/// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass
: public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerStaticTensorListPass)
LowerStaticTensorListPass() = default;
LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
explicit LowerStaticTensorListPass(bool allow_tensorlist_pass_through,
bool default_to_single_batch,
bool enable_dynamic_update_slice) {
this->allow_tensorlist_pass_through = allow_tensorlist_pass_through;
this->default_to_single_batch = default_to_single_batch;
this->enable_dynamic_update_slice = enable_dynamic_update_slice;
}
StringRef getArgument() const final {
// This is the argument used to refer to the pass in
// the textual format (on the commandline for example).
return "tfl-lower-static-tensor-list";
}
StringRef getDescription() const final {
// This is a brief description of the pass.
return "Lower TensorList ops within TensorFlow Lite dialect";
}
void runOnOperation() override;
Option<bool> allow_tensorlist_pass_through{
*this, "allow-tensorlist-pass-through",
llvm::cl::desc(
"When specified to true, if the tensorlist ops can't be properly "
"legalized by this pass, then the IR won't be changed so that "
"tensorlist ops can pass through (default false)"),
llvm::cl::init(false)};
Option<bool> default_to_single_batch{
*this, "default-to-single-batch",
llvm::cl::desc(
"When specified to true, if the tensorlist ops has unspecified batch "
"size, this pass will assume that the batch size is one to proceed "
"tensorlist op lowering (default true)"),
llvm::cl::init(false)};
Option<bool> enable_dynamic_update_slice{
*this, "enable-dynamic-update-slice",
llvm::cl::desc("When specified to true, lower TensorListSetItem with "
"DynamicUpdateSlice op (default false)"),
llvm::cl::init(false)};
};
Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
ArrayRef<int64_t> shape, int32_t val) {
RankedTensorType type =
RankedTensorType::get(shape, rewriter->getIntegerType(32));
DenseElementsAttr attr =
DenseElementsAttr::get(type, rewriter->getI32IntegerAttr(val));
return rewriter->create<arith::ConstantOp>(loc, type, attr);
}
Value CreateI64SplatConst(Location loc, PatternRewriter *rewriter,
ArrayRef<int64_t> shape, int64_t val) {
RankedTensorType type =
RankedTensorType::get(shape, rewriter->getIntegerType(64));
DenseElementsAttr attr =
DenseElementsAttr::get(type, rewriter->getI64IntegerAttr(val));
return rewriter->create<arith::ConstantOp>(loc, type, attr);
}
Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
Value shape_tensor, int32_t val) {
Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
return rewriter->create<TF::FillOp>(
loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
shape_tensor, scalar_val);
}
// Returns a new type by prepending the specified dimension to the shape of
// the given type if it is a ranked type.
Type PrependLeadingDimIfRanked(int64_t dim, Type type,
PatternRewriter *rewriter) {
Type dtype = getElementTypeOrSelf(type);
if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) {
llvm::SmallVector<int64_t, 4> shape = {dim};
shape.append(ty.getShape().begin(), ty.getShape().end());
return RankedTensorType::get(shape, dtype);
}
return type;
}
Type GetTensorTypeForTensorList(Type element_type, TF::VariantType handle_dtype,
PatternRewriter *rewriter) {
// If the variant type in the output handle has item shape available, use it
// to derive the output shape by setting unknown leading dimension.
// Otherwise, result type will be of unranked type.
if (handle_dtype.getSubtypes().empty()) {
return UnrankedTensorType::get(element_type);
}
return PrependLeadingDimIfRanked(-1, handle_dtype.getSubtypes()[0], rewriter);
}
// Gets the index of tensorlist arguments which size might get changed by the
// function.
llvm::SmallSet<int, 4> GetResizedTensorListIndexes(
func::FuncOp func, const llvm::SmallSet<int, 4> &tensor_list_args) {
// `indexes` stores the argument index of tensorlists which size may get
// updated in the function.
llvm::SmallSet<int, 4> indexes;
for (BlockArgument &arg : func.getArguments()) {
if (tensor_list_args.contains(arg.getArgNumber())) {
for (const mlir::OpOperand &use : arg.getUses()) {
mlir::Operation *op = use.getOwner();
// Currently we only check if the tensorlist argument is consumed by
// `TensorListPushBack` or `TensorListResize`, since those are the only
// length-mutating ops supported in this pass.
if (llvm::isa<TF::TensorListPushBackOp>(op) ||
llvm::isa<TF::TensorListResizeOp>(op)) {
indexes.insert(arg.getArgNumber());
}
}
}
}
return indexes;
}
// Creates a slice of the tensorlist `input_list`, starting from
// [start_index, 0, ...0], with size [size, -1, ...-1].
//
// Requires that `start_index` and `size` are scalar tensors and
// `item_position_shape` is a 1-D tensor with only one element equal to the rank
// of an item in the tensorlist.
TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
Value start_index, Value size,
Value item_rank, Type result_type,
PatternRewriter *rewriter) {
// Create the start position of slice. This is done by concatenating
// `start_index` and `partial_start_position` together.
IntegerType shape_dtype = rewriter->getIntegerType(32);
RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
Value partial_start_position =
CreateI32SplatTensor(loc, rewriter, item_rank, 0);
Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
loc, vector_type, start_index, scalar_zero);
auto start_position = rewriter->create<TF::ConcatOp>(
loc, position_type, scalar_zero,
ArrayRef<Value>({expanded_start_index, partial_start_position}));
// Create the slice size tensor. This is done by concatenating `size` and
// `partial_size`.
auto size_leading_dim =
rewriter->create<TF::ExpandDimsOp>(loc, vector_type, size, scalar_zero);
Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
auto slice_size = rewriter->create<TF::ConcatOp>(
loc, position_type, scalar_zero,
ArrayRef<Value>({size_leading_dim, partial_size}));
return rewriter->create<TF::SliceOp>(loc, result_type, input_list,
start_position, slice_size);
}
template <typename OpT>
class TensorListOpConverterBase : public OpConversionPattern<OpT> {
public:
explicit TensorListOpConverterBase<OpT>(MLIRContext *context,
bool allow_tensorlist_pass_through,
bool default_to_single_batch)
: OpConversionPattern<OpT>::OpConversionPattern(context),
allow_tensorlist_pass_through_(allow_tensorlist_pass_through),
default_to_single_batch_(default_to_single_batch) {}
protected:
// This flag will control the behavior of error emitting during rewrite:
// 1) If it's true, then patterns will only emit errors during debug or
// tracing mode. 2) If it's false, then patterns will emit standard errors
// when there is a rewrite failure.
bool allow_tensorlist_pass_through_;
// This flag will control the behavior of setting the batch size one when the
// given batch size is None in order to force to proceed the tensor list op
// lowerings.
bool default_to_single_batch_;
};
// Converts tf.Const containing variant of type TensorList to a tensor of
// primitive element types. Each of the individual tensor in the list is
// converted to an ElementsAttr and then those are packed together using
// tf.Pack op.
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::ConstOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Verify that the opaque elements attribute contains tensor of type variant
// and scalar shape. The variant type should hold a TensorList.
auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
if (!opaque_attr) return failure();
tensorflow::Tensor tensor;
if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
return failure();
if (tensor.dtype() != tensorflow::DT_VARIANT) return failure();
if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
return failure();
const tensorflow::TensorList *list =
tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
if (!list) return failure();
// Verify output type is variant and contains exactly one ranked subtypes.
auto variant_ty =
getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
if (!variant_ty) return failure();
ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
if (subtypes.size() != 1) return failure();
RankedTensorType list_element_ty =
subtypes.front().dyn_cast<RankedTensorType>();
if (!list_element_ty) return failure();
// Extract tensor elements for the TensorList and construct result type
// based on the number of elements and element shape.
const std::vector<tensorflow::Tensor> &tensors = list->tensors();
llvm::SmallVector<int64_t, 4> result_shape = {
static_cast<int64_t>(tensors.size())};
result_shape.append(list_element_ty.getShape().begin(),
list_element_ty.getShape().end());
auto result_ty =
RankedTensorType::get(result_shape, list_element_ty.getElementType());
// If the list is empty, directly create the final result instead of
// creating the tf.Pack op. tf.Pack op requires at least one operand.
if (tensors.empty()) {
tensorflow::Tensor tensor(list->element_dtype,
tensorflow::TensorShape(result_shape));
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
if (!attr_or.ok()) return failure();
rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
return success();
}
// Extract individual tensor list element and combine them using the tf.Pack
// op.
Location loc = op.getLoc();
llvm::SmallVector<Value, 4> values;
values.reserve(tensors.size());
for (const tensorflow::Tensor &tensor : tensors) {
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
if (!attr_or.ok()) return failure();
auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
values.push_back(value);
}
rewriter.replaceOpWithNewOp<TF::PackOp>(
op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
return success();
}
};
struct ConvertTensorListSetItem
: public OpConversionPattern<TF::TensorListSetItemOp> {
explicit ConvertTensorListSetItem(MLIRContext *context,
bool enable_dynamic_update_slice = false)
: OpConversionPattern<TF::TensorListSetItemOp>(context),
enable_dynamic_update_slice(enable_dynamic_update_slice) {}
LogicalResult matchAndRewrite(
TF::TensorListSetItemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (enable_dynamic_update_slice) {
return matchAndRewriteImplWithDynamicUpdateSlice(op, adaptor, rewriter);
} else {
return matchAndRewriteImplWithSliceAndConcat(op, adaptor, rewriter);
}
}
// This function rewrites the original op into a series of slice and concat op
// to produce the same result. It first slices the first `$index` rows. Then
// expands the dimension of the `$item`, followed by another slice of the
// remaining rows starting from `$index` + 1. Lastly it concatenates the
// three parts together.
// On a high level, it's doing something like:
// def : Pat<(TF_TensorListSetItemOp $input, $index, $item),
// (Concat
// concat_dim = 0,
// (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
LogicalResult matchAndRewriteImplWithSliceAndConcat(
TF::TensorListSetItemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value input = adaptor.getOperands()[0];
Value index = adaptor.getOperands()[1];
Value item = adaptor.getOperands()[2];
IntegerType shape_dtype = rewriter.getIntegerType(32);
auto item_rank = rewriter.create<TF::RankOp>(
loc, RankedTensorType::get({}, shape_dtype), item);
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
// Calculate `index` + 1, which is used to generate the start position for
// the second slice op.
auto suffix_start =
rewriter.create<TF::AddOp>(loc, index.getType(), index,
CreateI32SplatConst(loc, &rewriter, {}, 1));
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
// Create two slice ops.
Type element_type = input.getType().cast<TensorType>().getElementType();
UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
TF::SliceOp slice1 =
CreateSliceOpForTensorList(loc, /*input_list=*/input,
/*start_index=*/scalar_zero,
/*size=*/index,
/*item_rank=*/item_position_shape,
/*result_type=*/unranked_tensor, &rewriter);
TF::SliceOp slice2 =
CreateSliceOpForTensorList(loc, /*input_list=*/input,
/*start_index=*/suffix_start,
/*size=*/scalar_minus_one,
/*item_rank=*/item_position_shape,
/*result_type=*/unranked_tensor, &rewriter);
// Expand the dimension of item so that it will have the same rank with
// input.
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
op.getLoc(), unranked_tensor, item, scalar_zero);
// Concatenate three parts together to generate the final result.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
op, input.getType(), scalar_zero,
ArrayRef<Value>({slice1, expanded_item, slice2}));
return success();
}
// This function rewrites the original op into a XLA DynamicUpdateSlice op.
// |item| is expanded to have the same dimension as input_handle and
// |index| is expanded to [index, 0, 0, ...] as the indices to input_handle.
// On a high level, it's doing something like:
// def : Pat<(TensorListSetItem($input_handle, $index, $item)),
// (XlaDynamicUpdateSlice($input_handle, ExpandDims($item, 0),
// Concat(ExpandDims($index, 0), [0, 0, 0, ...])))>
LogicalResult matchAndRewriteImplWithDynamicUpdateSlice(
TF::TensorListSetItemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value input = adaptor.getOperands()[0];
Value index = adaptor.getOperands()[1];
Value item = adaptor.getOperands()[2];
IntegerType shape_dtype = rewriter.getIntegerType(32);
auto item_rank = rewriter.create<TF::RankOp>(
loc, RankedTensorType::get({}, shape_dtype), item);
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
// Concat(ExpandDims(index, 0), [0, 0, 0, ...])
RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
Value partial_index =
CreateI32SplatTensor(loc, &rewriter, item_position_shape, 0);
RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
auto expanded_index =
rewriter.create<TF::ExpandDimsOp>(loc, vector_type, index, scalar_zero);
auto start_position = rewriter.create<TF::ConcatOp>(
loc, position_type, scalar_zero,
ArrayRef<Value>({expanded_index, partial_index}));
// Expand the dimension of item so that it will have the same rank with
// input.
// ExpandDims(item, 0)
Type element_type = input.getType().cast<TensorType>().getElementType();
UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
op.getLoc(), unranked_tensor, item, scalar_zero);
// Update the element with XlaDynamicUpdateSliceOp.
rewriter.replaceOpWithNewOp<TF::XlaDynamicUpdateSliceOp>(
op, input.getType(), input, expanded_item, start_position);
return success();
}
bool enable_dynamic_update_slice;
};
// Rewrites op of the template type initializing a TensorList with a list of ops
// to generate an equivalent raw tensor. Derived classes are required to
// override GetNumElements method.
template <typename OpT>
struct ConvertTensorListInitOp : public TensorListOpConverterBase<OpT> {
using TensorListOpConverterBase<OpT>::TensorListOpConverterBase;
using TensorListOpConverterBase<OpT>::allow_tensorlist_pass_through_;
using TensorListOpConverterBase<OpT>::default_to_single_batch_;
// Create and return a 1-d tensor with exactly one element equal to the number
// of list elements to initialize the output tensor list with.
virtual Value GetNumElements(OpT op, ValueRange operands,
PatternRewriter *rewriter) const = 0;
// Rewrites the original op into `tf.fill`. The result tensor shape is
// [num_element, element_shape]. All the values in the result tensor will be
// initialized to 0.
LogicalResult matchAndRewrite(
OpT op, typename OpT::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dtype = op.element_dtype();
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
dtype.isInteger(32) || dtype.isInteger(64))) {
const char *error_info =
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
"transformation pass";
return allow_tensorlist_pass_through_
? rewriter.notifyMatchFailure(op, error_info)
: op.emitOpError(error_info);
}
Value element_shape = adaptor.getOperands()[0];
Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
// If the `element_shape` is a scalar, we try to acquire its shape by
// looking at the first `TensorListSetItemOp` writing to this tensor list.
// Here we assume that the element_shape won't be changed before calling
// the first `TensorListSetItemOp`.
if (auto shaped_type = element_shape.getType().dyn_cast<ShapedType>()) {
if (shaped_type.hasRank() && shaped_type.getRank() == 0) {
bool element_shape_acquired = false;
auto uses = op.getResult().getUses();
for (auto &use : llvm::make_early_inc_range(uses)) {
if (TF::TensorListSetItemOp set_op =
llvm::dyn_cast<TF::TensorListSetItemOp>(use.getOwner())) {
element_shape = rewriter.create<TF::ShapeOp>(
op.getLoc(), RankedTensorType::get({-1}, shape_dtype),
set_op.item());
element_shape_acquired = true;
} else if (TF::WhileOp while_op =
llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
// Tensorlist is passed into a while loop, check inside the body
// function.
auto inside_uses = while_op.body_function()
.getArgument(use.getOperandNumber())
.getUses();
for (auto &inside_use : llvm::make_early_inc_range(inside_uses)) {
if (TF::TensorListSetItemOp set_op =
llvm::dyn_cast<TF::TensorListSetItemOp>(
inside_use.getOwner())) {
if (auto shaped_type =
set_op.item().getType().dyn_cast<ShapedType>()) {
if (shaped_type.hasStaticShape()) {
RankedTensorType type = RankedTensorType::get(
{shaped_type.getRank()}, rewriter.getIntegerType(32));
SmallVector<Attribute, 4> shape_attr;
for (int64_t dim : shaped_type.getShape()) {
shape_attr.push_back(rewriter.getI32IntegerAttr(dim));
}
DenseElementsAttr attr =
DenseElementsAttr::get(type, shape_attr);
element_shape = rewriter.create<arith::ConstantOp>(
op.getLoc(), type, attr);
element_shape_acquired = true;
break;
}
}
}
}
}
if (element_shape_acquired) break;
}
if (!element_shape_acquired) {
const char *error_info =
"requires element_shape to be 1D tensor during TF Lite "
"transformation pass";
return allow_tensorlist_pass_through_
? rewriter.notifyMatchFailure(op, error_info)
: op.emitOpError(error_info);
}
}
}
DenseIntElementsAttr dense_elem_attr;
if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
// Note: It's technically unsafe to rewrite
// TensorListReserve(num_element, element_shape)
// to
// Fill(Concat(num_element, element_shape), 0)
// because element_shape may contain -1 to represent unknown dimension.
//
// In real world use cases (e.g. Keras RNN), `element_shape` is usually
// a constant, and the first dimension of `element_shape` is usually
// batch dimension. Currently TFLiteConverter always rewrite unknown
// batch dimension to 1, therefore we also rewrite unknown dimension in
// `element_shape` to 1 here.
//
// This workaround enables converting Keras RNN without specifying batch
// dimension. This isn't guaranteed to work, but it doesn't break any
// non-broken cases either (since it's already broken if `element_shape`
// contains -1).
// TODO(b/142096690): Support dynamic element shape and remove the
// workaround.
SmallVector<int32_t, 4> new_element_shape_values;
auto int_values = dense_elem_attr.getValues<APInt>();
for (auto it = int_values.begin(); it != int_values.end(); ++it) {
auto dim_value = (*it).getSExtValue();
if (it == int_values.begin() && dim_value == -1) {
if (!default_to_single_batch_) {
const char *error_info =
"requires element_shape to be static during TF Lite "
"transformation pass";
return allow_tensorlist_pass_through_
? rewriter.notifyMatchFailure(op, error_info)
: op.emitOpError(error_info);
}
dim_value = 1;
}
new_element_shape_values.push_back(dim_value);
}
auto attr = DenseIntElementsAttr::get(
element_shape.getType().cast<ShapedType>(), new_element_shape_values);
auto new_element_shape = rewriter.create<arith::ConstantOp>(
op.getLoc(), element_shape.getType(), attr);
element_shape = new_element_shape;
}
int64_t result_rank = -1; // -1 means unknown result rank.
Type element_dtype = op.element_dtype();
Type result_type = UnrankedTensorType::get(element_dtype);
Value leading_dim = GetNumElements(op, adaptor.getOperands(), &rewriter);
if (auto element_type =
op.element_type().template dyn_cast<RankedTensorType>()) {
result_rank = element_type.getRank() + 1;
int64_t leading_dim_v = -1;
ElementsAttr element_attr;
if (matchPattern(leading_dim, m_Constant(&element_attr))) {
leading_dim_v = element_attr.getValues<APInt>()[0].getSExtValue();
}
SmallVector<int64_t, 4> result_shape = {leading_dim_v};
ArrayRef<int64_t> shape = element_type.getShape();
result_shape.append(shape.begin(), shape.end());
result_type = RankedTensorType::get(result_shape, element_dtype);
}
// Create a 1-D RankedTensorType for result's shape. Number of elements in
// it is equal to the rank of the result, if known. Otherwise, the number of
// elements are unknown and represented with -1. In both cases, we can
// specify dimension using rank of the result.
Type shape_type = RankedTensorType::get({result_rank}, shape_dtype);
Location loc = op.getLoc();
// Add number of elements as the prefix to the element shape to get shape of
// the output tensor.
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
auto list_shape = rewriter.create<TF::ConcatOp>(
loc, shape_type, scalar_zero,
ArrayRef<Value>({leading_dim, element_shape}));
// Create a zero-initialized constant tensor that has the same type
// as specified by element_dtype.
RankedTensorType zero_type = RankedTensorType::get({}, element_dtype);
Attribute zero_attr = rewriter.getZeroAttr(zero_type);
auto zero = rewriter.create<arith::ConstantOp>(loc, zero_type, zero_attr);
rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
return success();
}
};
struct ConvertTensorListReserve
: public ConvertTensorListInitOp<TF::TensorListReserveOp> {
explicit ConvertTensorListReserve(MLIRContext *context,
bool allow_tensorlist_pass_through,
bool default_to_single_batch)
: ConvertTensorListInitOp(context, allow_tensorlist_pass_through,
default_to_single_batch) {}
Value GetNumElements(TF::TensorListReserveOp op, ValueRange operands,
PatternRewriter *rewriter) const override {
Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
Value num_elements = operands[1];
IntegerAttr attr;
if (matchPattern(num_elements, m_Constant(&attr))) {
return CreateI32SplatConst(op.getLoc(), rewriter, {1}, attr.getInt());
}
if (auto const_op = num_elements.getDefiningOp<TF::ConstOp>()) {
return CreateI32SplatConst(op->getLoc(), rewriter, {1},
(*const_op.value()
.cast<DenseElementsAttr>()
.getValues<APInt>()
.begin())
.getSExtValue());
}
return rewriter->create<TF::ExpandDimsOp>(
op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
scalar_zero);
}
};
// Note that we ignore the second operand `max_num_elements` as we don't have
// any restrictions on the number of elements we can support. So this may
// have a different behavior compared to TensorFlow in case of errors.
struct ConvertEmptyTensorList
: public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
explicit ConvertEmptyTensorList(MLIRContext *context,
bool allow_tensorlist_pass_through,
bool default_to_single_batch)
: ConvertTensorListInitOp(context, allow_tensorlist_pass_through,
default_to_single_batch) {}
Value GetNumElements(TF::EmptyTensorListOp op, ValueRange operands,
PatternRewriter *rewriter) const override {
return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0);
}
};
struct ConvertTensorListPushBack
: public OpConversionPattern<TF::TensorListPushBackOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::TensorListPushBackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input_handle = adaptor.getOperands()[0];
Value item = adaptor.getOperands()[1];
// Expand the shape of the item so that it will have rank same as the input
// tensor and it is compatible for the Concat Op.
Type expanded_item_type =
PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
Location loc = op.getLoc();
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
loc, expanded_item_type, item, scalar_zero);
Type elem_type = getElementTypeOrSelf(item);
auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
.cast<TF::VariantType>();
Type result_type =
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
// Concatenate tensor stored in the input handle with the expanded item to
// get a tensor equivalent to the TensorList generated by this op.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
op, result_type, scalar_zero,
ArrayRef<Value>({input_handle, expanded_item}));
return success();
}
};
// Rewrites `TensorListResize` op into a functional `If` op and several basic
// TF ops to match the op semantics of Tensorflow. Basically, it does:
// 1) If the requested size is smaller or equal than the input tensorlist's
// size, rewrite it to a Slice op so that only the first 'size' rows are
// returned. 2) If the requested size is larger than the input tensorlist's
// size. We need to create an additional tensorlist with 'size - input_size'
// elements, and append it to the end of the input tensorlist.
struct ConvertTensorListResize
: public OpConversionPattern<TF::TensorListResizeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::TensorListResizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input_handle = adaptor.getOperands()[0];
Value size = adaptor.getOperands()[1];
Location loc = op.getLoc();
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
// Compute the input tensorlist's length and store it in `input_size`.
IntegerType shape_dtype = rewriter.getIntegerType(32);
auto input_size = rewriter.create<TF::TensorListLengthOp>(
loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0));
// Infer result type of this op based on TF's shape inference result.
Type elem_type = getElementTypeOrSelf(input_handle);
auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
.cast<TF::VariantType>();
Type result_type =
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
// Compute the difference of `size` and `input_size`, and store it in
// `size_diff`, which is then consumed by `if_cond`.
auto size_diff = rewriter.create<TF::SubOp>(
loc, RankedTensorType::get({}, shape_dtype), size, input_size);
auto if_cond = rewriter.create<TF::GreaterOp>(
loc, RankedTensorType::get({}, rewriter.getI1Type()), size_diff,
scalar_zero);
// Build the argument/result types for if branch function.
auto input_shape = rewriter.create<TF::ShapeOp>(
loc, RankedTensorType::get({-1}, shape_dtype), input_handle);
Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
size_diff.getType(), size.getType()};
Type branch_result_type[] = {result_type};
auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
branch_result_type);
// Create functions in a higher scope before restoring the insertion point.
// Additionally, create the SymbolTable before further modifying the module.
auto original_point = rewriter.saveInsertionPoint();
rewriter.setInsertionPointAfter(op->getParentOfType<func::FuncOp>());
SymbolTable manager(op->getParentOfType<ModuleOp>());
// Constructs `then_branch`, which is executed when `if_cond` evaluates to
// true.
auto then_branch_op =
rewriter.create<func::FuncOp>(loc, "cond_true", func_type);
CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
&rewriter);
then_branch_op.setVisibility(func::FuncOp::Visibility::Private);
// Constructs `else_branch`, which is executed when `if_cond` evaluates to
// false.
auto else_branch_op =
rewriter.create<func::FuncOp>(loc, "cond_false", func_type);
CreateCondFalseBranch(loc, shape_dtype, result_type, else_branch_op,
&rewriter);
else_branch_op.setVisibility(func::FuncOp::Visibility::Private);
// Inserts the two blocks' names into the symbol table held by the module.
// Using SymbolTable will ensure that the inserted symbol names are
// unique.
manager.insert(then_branch_op);
manager.insert(else_branch_op);
rewriter.restoreInsertionPoint(original_point);
rewriter.replaceOpWithNewOp<TF::IfOp>(
op, result_type, if_cond,
/*input=*/
ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
/*then_branch=*/
mlir::SymbolRefAttr::get(then_branch_op),
/*else_branch=*/
mlir::SymbolRefAttr::get(else_branch_op),
/*is_stateless=*/rewriter.getBoolAttr(true));
return success();
}
private:
// When the input tensorlist's size is smaller than the requested size,
// then branch is executed.
// Create a new tensorlist of size 'size - input_size' and concat it
// with the input tensorlist.
void CreateCondTrueBranch(TF::TensorListResizeOp resize_op, Type shape_dtype,
Type result_type, func::FuncOp branch_func,
ConversionPatternRewriter *rewriter) const {
auto guard = OpBuilder::InsertionGuard(*rewriter);
auto inputs = branch_func.getFunctionType().getInputs();
Block *block = rewriter->createBlock(
&branch_func.getBody(), branch_func.begin(), inputs,
SmallVector<Location>(inputs.size(), branch_func.getLoc()));
auto input_shape = block->getArgument(1);
auto size_diff = block->getArgument(2);
auto input = block->getArgument(0);
Location loc = resize_op.getLoc();
// Get the element shape by slicing from index 1 in the input shape.
Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
auto elem_shape = rewriter->create<TF::SliceOp>(
loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
slice_size);
auto extended_part = rewriter->create<TF::TensorListReserveOp>(
loc, resize_op.output_handle().getType(), elem_shape, size_diff);
// `ConcatOp` expects non-variant-typed input. Insert a
// `TensorListStackOp` here to convert type from variant to non-variant.
// Note that we are using the same `result_type` for both the
// `TensorListStackOp` and `ConcatOp`, since the first dimension of the
// shape specified by `result_type` is -1.
auto stacked_extended_part = rewriter->create<TF::TensorListStackOp>(
loc, result_type, extended_part,
/*element_shape=*/CreateI32SplatConst(loc, rewriter, {}, -1),
/*num_elements=*/rewriter->getI32IntegerAttr(-1));
auto concat_op = rewriter->create<TF::ConcatOp>(
loc, result_type, scalar_zero,
ArrayRef<Value>({input, stacked_extended_part}));
rewriter->create<func::ReturnOp>(loc, ArrayRef<Value>({concat_op}));
}
void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type,
func::FuncOp branch_func,
ConversionPatternRewriter *rewriter) const {
// When the input tensorlist's size is larger or equal than the requested
// size, the else branch is executed.
// Slice the first 'size' rows from the input tensorlist.
auto guard = OpBuilder::InsertionGuard(*rewriter);
auto inputs = branch_func.getFunctionType().getInputs();
Block *block = rewriter->createBlock(
&branch_func.getBody(), branch_func.begin(), inputs,
SmallVector<Location>(inputs.size(), branch_func.getLoc()));
Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
auto input = block->getArgument(0);
auto size = block->getArgument(3);
// Subtract `input_rank` by 1 to get the item's rank, which is used as
// `partial_position_shape`.
auto input_rank = rewriter->create<TF::RankOp>(
loc, RankedTensorType::get({}, shape_dtype), input);
auto partial_position_shape = rewriter->create<TF::SubOp>(
loc, RankedTensorType::get({1}, shape_dtype), input_rank, vector_one);
auto slice_op =
CreateSliceOpForTensorList(loc, /*input_list=*/input,
/*start_index=*/scalar_zero, /*size=*/size,
/*item_rank=*/partial_position_shape,
/*result_type=*/result_type, rewriter);
rewriter->create<func::ReturnOp>(loc, ArrayRef<Value>({slice_op}));
}
};
struct ConvertTensorListGetItem
: public OpConversionPattern<TF::TensorListGetItemOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::TensorListGetItemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getOperands()[0];
Value index = adaptor.getOperands()[1];
rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
rewriter.getBoolAttr(true));
return success();
}
};
struct ConvertTensorListLength
: public OpConversionPattern<TF::TensorListLengthOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::TensorListLengthOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input_handle = adaptor.getOperands()[0];
BoolAttr true_attr = rewriter.getBoolAttr(true);
auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
/*use_32bit=*/true_attr);
rewriter.replaceOpWithNewOp<TF::GatherOp>(
op, op.getType(), shape, CreateI32SplatConst(loc, &rewriter, {}, 0),
/*validate_indices=*/true_attr);
return success();
}
};
struct ConvertTensorListStack
: public OpConversionPattern<TF::TensorListStackOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::TensorListStackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = adaptor.getOperands()[0];
Value element_shape = adaptor.getOperands()[1];
// If the `element_shape` is a known constant (which is defined when calling
// `tensor_list_stack`) and also valid (not scalar), we rewrite this op to a
// trivial Reshape op (that doesn't actually change the input's shape) and
// also populate the shape info to the op result. The shape of the
// tensorlist is inferred from `num_elements` and `element_shape`.
auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
DenseIntElementsAttr dense_elem_attr;
if ((ranked_type && ranked_type.getRank() == 0) ||
!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
// If no constant is spotted, just forward the operand.
rewriter.replaceOp(op, {input});
return success();
}
RankedTensorType shape_type =
RankedTensorType::get({-1}, rewriter.getIntegerType(32));
auto new_shape = rewriter.create<TF::ShapeOp>(loc, shape_type, input);
SmallVector<int64_t, 8> output_shape(/*Size=*/1, op.num_elements());
for (const auto &dim : dense_elem_attr.getValues<APInt>())
output_shape.push_back(dim.getSExtValue());
RankedTensorType result_type =
RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
rewriter.replaceOpWithNewOp<TF::ReshapeOp>(op, result_type, input,
new_shape);
return success();
}
};
// Converts `TensorListConcatV2` into Unpack and Concat. First we unpack
// the input tensorlist along the first dimension, which results in N (where N
// is the first dim's size) tensors (each with shape [element_shape]). Then
// we concatenate all those tensors along the first dimension.
// The pattern will be rejected if either `element_shape` is not constant, or
// the first dimension of `input` is not known.
struct ConvertTensorListConcatV2
: public TensorListOpConverterBase<TF::TensorListConcatV2Op> {
using TensorListOpConverterBase<
TF::TensorListConcatV2Op>::TensorListOpConverterBase;
using TensorListOpConverterBase<
TF::TensorListConcatV2Op>::allow_tensorlist_pass_through_;
LogicalResult matchAndRewrite(
TF::TensorListConcatV2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = adaptor.getOperands()[0];
Value element_shape = adaptor.getOperands()[1];
// Only match when `element_shape` is a constant.
DenseIntElementsAttr dense_elem_attr;
if (!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
const char *error_info = "requires element_shape to be a constant";
return allow_tensorlist_pass_through_
? rewriter.notifyMatchFailure(op, error_info)
: op.emitOpError(error_info);
}
llvm::SmallVector<int64_t, 4> output_shape;
for (const auto &dim : dense_elem_attr.getValues<APInt>()) {
output_shape.push_back(dim.getSExtValue());
}
// First unpack the input tensor along the first dimension.
Type input_element_type = getElementTypeOrSelf(input);
int64_t num_unpacked = 0;
if (auto type = input.getType().dyn_cast<RankedTensorType>()) {
if (type.getDimSize(0) > 0) {
num_unpacked = type.getDimSize(0);
} else {
const char *error_info =
"requires the first dimension of input tensor to have > 0 "
"dimension";
return allow_tensorlist_pass_through_
? rewriter.notifyMatchFailure(op, error_info)
: op.emitOpError(error_info);
}
}
llvm::SmallVector<Type, 1> unpack_output_type;
unpack_output_type.insert(
unpack_output_type.begin(), num_unpacked,
RankedTensorType::get(output_shape, input_element_type));
auto unpack = rewriter.create<TF::UnpackOp>(loc, unpack_output_type, input,
/*axis=*/0);
// Concatenate the unpacked tensors along the first dimension.
// Since we're concatenating along first dimension, change its dim size to
// -1.
output_shape[0] = -1;
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
auto concat = rewriter.create<TF::ConcatOp>(
loc, RankedTensorType::get(output_shape, input_element_type),
scalar_zero, unpack->getResults());
// `lengths` is only useful for computing gradient. For now we just return
// a placeholder tensor.
rewriter.replaceOp(
op, {concat.getResult(), CreateI64SplatConst(loc, &rewriter, {0}, 0)});
return success();
}
};
struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::IdentityOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getOperands()[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(
op, input.getType(), adaptor.getOperands(), op->getAttrs());
return success();
}
};
struct ConvertReturn : public OpConversionPattern<func::ReturnOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.updateRootInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
struct ConvertYield : public OpConversionPattern<TF::YieldOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.updateRootInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
// Returns an unranked tensor type with an element of the same type as `value`
// if `type` is a tensor of variant. Otherwise, returns `type` unmodified.
Type VariantToUnrankedTensorType(Type type, Value value) {
TF::VariantType variant_ty =
getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
if (!variant_ty) {
return type;
}
if (!variant_ty.getSubtypes().empty()) {
// Short-circut if the variant type has subtype info.
return UnrankedTensorType::get(
variant_ty.getSubtypes()[0].getElementType());
}
Type value_type = value.getType();
Type element_type;
variant_ty = value_type.dyn_cast<TF::VariantType>();
if (variant_ty && !variant_ty.getSubtypes().empty()) {
element_type = variant_ty.getSubtypes()[0].getElementType();
} else {
element_type = getElementTypeOrSelf(value_type);
}
return UnrankedTensorType::get(element_type);
}
// Returns true if we can deduce the type is tensorlist.
bool IsTensorListType(Type type, llvm::Optional<Value> value) {
TF::VariantType variant_ty =
getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
if (!variant_ty) {
return false;
}
// Check there is only one subtype contained in the variant type. Note that
// when `subtypes.size() == 1` does not always mean the type is actually
// a tensorlist. We probably need some form of data flow analysis.
if (variant_ty.getSubtypes().size() == 1) {
return true;
}
// If subtype info is not available, check if the value is used by any of
// the following TensorList operations.
if (!value.hasValue()) {
return false;
}
for (const mlir::OpOperand &use : value.getValue().getUses()) {
mlir::Operation *op = use.getOwner();
if (llvm::isa<TF::TensorListGetItemOp>(op) ||
llvm::isa<TF::TensorListLengthOp>(op) ||
llvm::isa<TF::TensorListPushBackOp>(op) ||
llvm::isa<TF::TensorListReserveOp>(op) ||
llvm::isa<TF::TensorListSetItemOp>(op) ||
llvm::isa<TF::TensorListStackOp>(op) ||
llvm::isa<TF::TensorListResizeOp>(op)) {
return true;
}
}
return false;
}
// Returns a set of integers that correspond to the tensorlist arguments in
// the function.
llvm::SmallSet<int, 4> GetTensorListArgumentsIndex(func::FuncOp func) {
llvm::SmallSet<int, 4> set;
for (const auto &arg_and_idx : llvm::enumerate(func.getArguments())) {
if (IsTensorListType(arg_and_idx.value().getType(), arg_and_idx.value())) {
set.insert(arg_and_idx.index());
}
}
return set;
}
// Returns a set of integers that correspond to the tensorlist results in the
// function.
llvm::SmallSet<int, 4> GetTensorListResultsIndex(func::FuncOp func) {
llvm::SmallSet<int, 4> set;
for (const auto &result_and_idx :
llvm::enumerate(func.getFunctionType().getResults())) {
if (IsTensorListType(result_and_idx.value(), llvm::None)) {
set.insert(result_and_idx.index());
}
}
return set;
}
// Updates the tensorlist types based on the input index. If the tensorlist's
// size isn't changed(which is indicated by `resized_tensor_list_index`), then
// we will use the original operand's type, otherwise update it with the
// unranked tensor type.
template <typename R>
void UpdateTensorListTypes(
const llvm::SmallSet<int, 4> &tensor_list_index,
const llvm::SmallSet<int, 4> &resized_tensor_list_index,
ArrayRef<Type> types, R &&range, ValueRange operands,
llvm::SmallVectorImpl<Type> *updated_types) {
int i = 0;
for (const auto it : llvm::zip(types, range, operands)) {
if (tensor_list_index.count(i)) {
// Only change the tensorlist's type to unranked tensor if it has been
// resized.
if (resized_tensor_list_index.count(i)) {
updated_types->push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
} else {
updated_types->push_back(std::get<2>(it).getType());
}
} else {
updated_types->push_back(std::get<0>(it));
}
++i;
}
}
// Updates the tensorlist types to unranked tensor types based on the input
// index.
template <typename R>
void ChangeVariantToUnrankedTensorType(
const llvm::SmallSet<int, 4> &tensor_list_index, ArrayRef<Type> types,
R &&range, llvm::SmallVectorImpl<Type> *updated_types) {
int i = 0;
for (const auto it : llvm::zip(types, range)) {
if (tensor_list_index.count(i)) {
updated_types->push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
} else {
updated_types->push_back(std::get<0>(it));
}
++i;
}
}
// Updates the specified function's type and region signature.
void UpdateFunctionAndRegionType(ConversionPatternRewriter &rewriter,
func::FuncOp func,
llvm::ArrayRef<Type> updated_argument_types,
llvm::ArrayRef<Type> updated_result_types) {
// Change `func`'s argument type to `unranked_argument_types`. If its
// return types contain a `DT_VARIANT`, change it to the unranked type
// derived from the corresponding argument.
rewriter.updateRootInPlace(func, [&] {
func.setType(FunctionType::get(func.getContext(), updated_argument_types,
updated_result_types));
});
Region &entry = func.getRegion();
TypeConverter::SignatureConversion signature_conversion(
entry.getNumArguments());
for (const BlockArgument &arg : entry.getArguments()) {
signature_conversion.addInputs(arg.getArgNumber(),
updated_argument_types[arg.getArgNumber()]);
}
rewriter.applySignatureConversion(&entry, signature_conversion);
}
// Changes the function type of `cond_func` and `body_func` for the given While
// op.
LogicalResult UpdateFunctionTypesForWhileOp(
ConversionPatternRewriter &rewriter, TF::WhileOp op, ValueRange operands,
const llvm::SmallSet<int, 4> &tensor_list_args,
const llvm::SmallSet<int, 4> &resized_tensor_lists) {
int func_index = 0;
for (func::FuncOp func : {op.cond_function(), op.body_function()}) {
++func_index;
if (!func) continue;
FunctionType func_type = func.getFunctionType();
int num_inputs = func_type.getNumInputs();
int num_results = func_type.getNumResults();
// For each argument type in function's arguments, change it to uranked
// tensor type if it's a variant type.
SmallVector<Type, 8> updated_argument_types;
updated_argument_types.reserve(num_inputs);
UpdateTensorListTypes<mlir::OperandRange>(
tensor_list_args, resized_tensor_lists, func_type.getInputs(),
op.getOperands(), operands, &updated_argument_types);
// Change all DT_VARIANT result types in function results to unranked tensor
// type with element type derived from the corresponding input operand. This
// is correct because while body's inputs and results have the same type.
SmallVector<Type, 8> updated_result_types;
updated_result_types.reserve(num_results);
if (func_index == 1) {
// We only update the result types for the body function.
for (Type ty : func_type.getResults()) {
updated_result_types.push_back(ty);
}
} else {
UpdateTensorListTypes<mlir::OperandRange>(
tensor_list_args, resized_tensor_lists, func_type.getResults(),
op.getOperands(), operands, &updated_result_types);
}
UpdateFunctionAndRegionType(rewriter, func, updated_argument_types,
updated_result_types);
}
return success();
}
// Changes the function type of `then_function` and `else_function` for the
// given If op.
LogicalResult UpdateFunctionTypesForIfOp(
ConversionPatternRewriter &rewriter, TF::IfOp op, ValueRange operands,
const llvm::SmallSet<int, 4> &tensor_list_args,
const llvm::SmallSet<int, 4> &resized_tensor_lists,
llvm::ArrayRef<Type> updated_result_types) {
for (func::FuncOp func : {op.else_function(), op.then_function()}) {
if (!func) continue;
FunctionType func_type = func.getFunctionType();
int num_inputs = func_type.getNumInputs();
// Update the argument types of the function. If it's a tensorlist and
// is not resized inside the function, we will use the corresponding
// operand's type, otherwise change its type to unranked tensor type.
SmallVector<Type, 8> updated_argument_types;
updated_argument_types.reserve(num_inputs);
UpdateTensorListTypes<mlir::OperandRange>(
tensor_list_args, resized_tensor_lists, func_type.getInputs(),
op.getOperands().drop_front(), operands.drop_front(),
&updated_argument_types);
UpdateFunctionAndRegionType(rewriter, func, updated_argument_types,
updated_result_types);
}
return success();
}
// Returns a `llvm::DenseMap` which maps from the index of tensorlist in the
// result, to the index of the same tensorlist in the arguments. For `If` op's
// branch functions, the results and arguments are not usually matched 1-1. This
// will let us konw which tensorlist result maps to which tensorlist in the
// arguments. Once we know this info it will help us decide the types of the
// result tensorlist based on the operand's of the `If` op.
llvm::DenseMap<int, int> MapTensorListResultToArgument(func::FuncOp func) {
// `map_fn` will trace upwards along the use-def chain of the ssa value. It
// starts from the last ssa value (returned by the function), and check its
// parent op iteratively. If the root ssa value appears in the function's
// argument list, it will return the index of the corresponding argument,
// otherwise it will return -1.
auto map_fn = [](Value value) -> int {
Value parent = value;
while (true) {
if (auto identity = parent.getDefiningOp<TF::IdentityOp>()) {
parent = identity.input();
} else if (auto set_item =
parent.getDefiningOp<TF::TensorListSetItemOp>()) {
parent = set_item.input_handle();
} else {
break;
}
}
if (auto block_arg = parent.dyn_cast<mlir::BlockArgument>()) {
return block_arg.getArgNumber();
}
// Returns -1 if we don't find which this result maps to.
return -1;
};
llvm::SmallVector<Value, 4> returns;
for (auto res : func.getBody().back().getTerminator()->getOperands()) {
returns.push_back(res);
}
llvm::DenseMap<int, int> result;
for (const auto &result_and_idx : llvm::enumerate(returns)) {
if (IsTensorListType(result_and_idx.value().getType(),
result_and_idx.value())) {
int arg_idx = map_fn(result_and_idx.value());
if (arg_idx != -1) {
result.insert({result_and_idx.index(), arg_idx});
}
}
}
return result;
}
// Updates the tensorlist result types for the `If` Op. If the tensorlist result
// maps to a specific argument (indicated by `tensor_list_map`), and also that
// tensorlist argument's shape isn't changed (indicated by
// `resized_tensor_list_index`), we will update this tensorlist's result type to
// the corresponding operand's type. In all other cases we change the
// tensorlist's type to unranked tensor type.
template <typename R>
void UpdateTensorListResultTypesForIf(
const llvm::SmallSet<int, 4> &tensor_list_index,
const llvm::SmallSet<int, 4> &resized_tensor_list_index,
const llvm::DenseMap<int, int> &tensor_list_map, ArrayRef<Type> types,
R &&range, ValueRange operands,
llvm::SmallVectorImpl<Type> *updated_types) {
int i = 0;
for (const auto it : llvm::zip(types, range)) {
if (!tensor_list_index.count(i)) {
updated_types->push_back(std::get<0>(it));
++i;
continue;
}
auto iter = tensor_list_map.find(i);
if (iter != tensor_list_map.end()) {
int arg_idx = iter->second;
if (!resized_tensor_list_index.count(arg_idx)) {
// If the mapped tensorlist argument's size isn't changed, we will
// use the corresponding `operand` type.
updated_types->push_back(operands[arg_idx].getType());
++i;
continue;
}
}
updated_types->push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
++i;
}
}
struct ConvertIf : public OpConversionPattern<TF::IfOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::IfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Find all Tensor List arugments.
auto tensor_list_args = GetTensorListArgumentsIndex(op.else_function());
auto tensor_list_results = GetTensorListResultsIndex(op.else_function());
auto tensor_list_map = MapTensorListResultToArgument(op.else_function());
llvm::SmallSet<int, 4> resized_tensor_lists =
GetResizedTensorListIndexes(op.else_function(), tensor_list_args);
llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumResults());
llvm::SmallVector<Type, 4> op_result_types;
for (Type ty : op.getResultTypes()) {
op_result_types.push_back(ty);
}
UpdateTensorListResultTypesForIf<mlir::ResultRange>(
tensor_list_results, resized_tensor_lists, tensor_list_map,
op_result_types, op->getResults(), adaptor.getOperands().drop_front(),
&result_types);
// Create a new if op with new operands and updated result types.
auto converted = rewriter.create<TF::IfOp>(
op.getLoc(), result_types, adaptor.getOperands(), op->getAttrs());
converted->removeAttr("T");
(void)UpdateFunctionTypesForIfOp(rewriter, converted, adaptor.getOperands(),
tensor_list_args, resized_tensor_lists,
result_types);
rewriter.replaceOp(op, converted.getResults());
return success();
}
};
struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::WhileOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Find all Tensor List arugments.
auto tensor_list_args = GetTensorListArgumentsIndex(op.body_function());
llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumOperands());
// Change all DT_VARIANT result types to unranked tensor type.
llvm::SmallVector<Type, 4> op_result_types;
for (Type ty : op.getResultTypes()) {
op_result_types.push_back(ty);
}
llvm::SmallSet<int, 4> resized_tensor_lists =
GetResizedTensorListIndexes(op.body_function(), tensor_list_args);
UpdateTensorListTypes<mlir::OperandRange>(
tensor_list_args, resized_tensor_lists, op_result_types,
op.getOperands(), adaptor.getOperands(), &result_types);
// Create a new while op with new operands and updated result types.
auto converted = rewriter.create<TF::WhileOp>(
op.getLoc(), result_types, adaptor.getOperands(), op->getAttrs());
converted->removeAttr("T");
(void)UpdateFunctionTypesForWhileOp(rewriter, converted,
adaptor.getOperands(), tensor_list_args,
resized_tensor_lists);
rewriter.replaceOp(op, converted.getResults());
return success();
}
};
struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::WhileRegionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumOperands());
// Change all DT_VARIANT result types to unranked tensor type.
for (auto it : llvm::zip(op.getResultTypes(), adaptor.getOperands()))
result_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
// Create a new while op with new operands and updated result types.
auto converted = rewriter.create<TF::WhileRegionOp>(
op.getLoc(), result_types, adaptor.getOperands(), op->getAttrs());
// Inline the regions from the old while into the new one, and apply
// signature conversion to inlined region.
for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) {
Region &old_region = *std::get<0>(it);
Region &new_region = *std::get<1>(it);
Block &entry = old_region.front();
// Build signature conversion for the region.
TypeConverter::SignatureConversion signature_conversion(
adaptor.getOperands().size());
for (auto it : llvm::zip(entry.getArguments(), adaptor.getOperands())) {
BlockArgument arg = std::get<0>(it);
signature_conversion.addInputs(
arg.getArgNumber(),
VariantToUnrankedTensorType(arg.getType(), std::get<1>(it)));
}
rewriter.inlineRegionBefore(old_region, new_region, new_region.end());
rewriter.applySignatureConversion(&new_region, signature_conversion);
}
rewriter.replaceOp(op, converted.getResults());
return success();
}
};
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
void LowerStaticTensorListPass::runOnOperation() {
auto *context = &getContext();
// TensorFlow operations that doesn't have operands and results of type
// variant are legal. Here, we don't distinguish between variants encoding
// TensorList or some other type as that information is not available here.
// Partial legalization is used below to still allow ops with variant types
// still.
auto is_legal = [](Operation *op) {
auto is_not_variant = [](Type ty) {
return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
};
return llvm::all_of(op->getOperandTypes(), is_not_variant) &&
llvm::all_of(op->getResultTypes(), is_not_variant);
};
ConversionTarget target(*context);
target.addDynamicallyLegalDialect<TF::TensorFlowDialect>(is_legal);
target.addIllegalOp<TF::EmptyTensorListOp, TF::TensorListFromTensorOp,
TF::TensorListGetItemOp, TF::TensorListLengthOp,
TF::TensorListPushBackOp, TF::TensorListReserveOp,
TF::TensorListSetItemOp, TF::TensorListStackOp,
TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
// TODO(hinsu): Use TFLite constant op for constants.
target.addLegalOp<arith::ConstantOp>();
target.addLegalOp<func::FuncOp>();
target.addDynamicallyLegalOp<func::ReturnOp>(is_legal);
target.addDynamicallyLegalOp<TF::YieldOp>(is_legal);
target.addLegalOp<TFL::CustomOp>();
// Register fused LSTM/RNN ops as legal.
target.addLegalOp<TFL::LSTMOp>();
target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();
target.addLegalOp<TFL::UnidirectionalSequenceRNNOp>();
target.addLegalOp<TFL::BidirectionalSequenceLSTMOp>();
RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
patterns.add<ConvertConst, ConvertIdentity, ConvertTensorListGetItem,
ConvertTensorListLength, ConvertTensorListPushBack,
ConvertTensorListStack, ConvertTensorListResize, ConvertWhile,
ConvertWhileRegion, ConvertIf, ConvertReturn, ConvertYield>(
context);
patterns.add<ConvertTensorListSetItem>(context, enable_dynamic_update_slice);
patterns.add<ConvertEmptyTensorList, ConvertTensorListConcatV2,
ConvertTensorListReserve>(context, allow_tensorlist_pass_through,
default_to_single_batch);
ModuleOp module = getOperation();
if (!allow_tensorlist_pass_through) {
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
module.emitError(
"Lowering tensor list ops is failed. Please consider using Select TF "
"ops and disabling `_experimental_lower_tensor_list_ops` flag in the "
"TFLite converter object. For example, "
"converter.target_spec.supported_ops = "
"[tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]\\n "
"converter._experimental_lower_tensor_list_ops = False");
signalPassFailure();
}
} else {
// If `allow_tensorlist_pass_through` is set to true, if legalization fails
// we should not leak the diagnostic info outside this pass. Hence we use
// a `StatusScopedDiagnosticHandler` here to capture diagnostics generated
// within this pass.
StatusScopedDiagnosticHandler handler(context);
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
auto _ = handler.ConsumeStatus();
}
}
}
} // namespace
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
/// pass.
std::unique_ptr<OperationPass<ModuleOp>> TFL::CreateLowerStaticTensorListPass(
bool allow_tensorlist_pass_through, bool default_to_single_batch,
bool enable_dynamic_update_slice) {
return std::make_unique<LowerStaticTensorListPass>(
allow_tensorlist_pass_through, default_to_single_batch,
enable_dynamic_update_slice);
}
static PassRegistration<LowerStaticTensorListPass> pass;
} // namespace mlir