| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // This transformation pass converts operations in the TensorFlow dialect into |
| // operations that are legal in the TensorFlow Lite dialect. Operations that |
| // can be legalized to the TensorFlow Lite dialect with simple replacements |
| // belong in this pass; operations that may create extra ops should go in the |
| // PrepareTF pass, which should be run before this pass. That way, any |
| // constant folding opportunities created by the extra ops can be exploited by |
| // the constant folding support for the TensorFlow ops. |
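| // |
| // For example (illustrative; the actual rules live in the generated patterns |
| // included below), a simple one-to-one legalization is |
| //   %0 = "tf.Relu"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> |
| // becoming |
| //   %0 = "tfl.relu"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> |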
| |
| #include <climits> |
| #include <complex> |
| #include <cstdint> |
| #include <limits> |
| #include <utility> |
| |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/Hashing.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include "llvm/Support/Threading.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project |
| #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project |
| #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Diagnostics.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "mlir/IR/PatternMatch.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "mlir/Transforms/DialectConversion.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" |
| #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" |
| #include "tensorflow/compiler/mlir/lite/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" |
| #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" |
| #include "tensorflow/compiler/mlir/lite/utils/validators.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" |
| #include "tensorflow/compiler/xla/status.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/tensor_shape.pb.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/protobuf/error_codes.pb.h" |
| |
| namespace mlir { |
| namespace TFL { |
| |
| //===----------------------------------------------------------------------===// |
| // The actual LegalizeTF Pass. |
| namespace { |
| |
| constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm"; |
| constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn"; |
| constexpr char kTfLiteInputIndices[] = "_tflite_input_indices"; |
| |
| // Legalize operations in functions. |
| class LegalizeTF : public PassWrapper<LegalizeTF, OperationPass<func::FuncOp>> { |
| void getDependentDialects(DialectRegistry& registry) const override { |
| registry.insert<quant::QuantizationDialect, TFL::TensorFlowLiteDialect>(); |
| } |
| |
| public: |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeTF) |
| |
| LegalizeTF() = default; |
| LegalizeTF(const LegalizeTF&) {} |
| explicit LegalizeTF(bool run_tfl_runtime_verification, |
| bool preserve_assert_op) { |
| run_tfl_runtime_verification_ = run_tfl_runtime_verification; |
| preserve_assert_op_ = preserve_assert_op; |
| } |
| |
| StringRef getArgument() const final { |
| // This is the argument used to refer to the pass in |
| // the textual format (on the commandline for example). |
| return "tfl-legalize-tf"; |
| } |
| StringRef getDescription() const final { |
| // This is a brief description of the pass. |
| return "Legalize from TensorFlow to TensorFlow Lite dialect"; |
| } |
| |
| /// Performs the lowering to TFLite dialect. |
| void runOnOperation() override; |
| |
| private: |
| Option<bool> run_tfl_runtime_verification_{ |
| *this, "run-tfl-runtime-verification", |
| llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)}; |
| Option<bool> preserve_assert_op_{ |
| *this, "preserve-assert-op", |
| llvm::cl::desc("Preserve AssertOp during tfl legalization."), |
| llvm::cl::init(false)}; |
| }; |
| |
| // Returns true if every operand of `op` has a static shape and all operands |
| // share the same shape. |
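| // For example (illustrative): two tensor<2x3xf32> operands return true; any |
| // tensor<?x3xf32> operand, or a mix of tensor<2x3xf32> and tensor<3x2xf32>, |
| // returns false. |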
| bool HasSameStaticShapes(Operation* op) { |
| auto values = op->getOperands(); |
| int index = 0; |
| ArrayRef<int64_t> shape; |
| for (Value value : values) { |
| auto shaped_type = value.getType().dyn_cast<ShapedType>(); |
| if (!shaped_type || !shaped_type.hasStaticShape()) { |
| return false; |
| } |
| if (index == 0) { |
| shape = shaped_type.getShape(); |
| } else { |
| if (shape != shaped_type.getShape()) { |
| return false; |
| } |
| } |
| ++index; |
| } |
| return true; |
| } |
| |
| // Utility that casts 'val' to an i32 tensor by inserting a TF::CastOp. |
| Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { |
| IntegerType new_ele_type = rewriter.getIntegerType(32); |
| if (auto shaped_type = val.getType().dyn_cast<RankedTensorType>()) { |
| ShapedType new_type = |
| RankedTensorType::get(shaped_type.getShape(), new_ele_type); |
| return rewriter.createOrFold<TF::CastOp>(loc, new_type, val, |
| rewriter.getBoolAttr(false)); |
| } |
| return rewriter.createOrFold<TF::CastOp>( |
| loc, UnrankedTensorType::get(new_ele_type), val, |
| rewriter.getBoolAttr(false)); |
| } |
| |
| // Returns the shape of an operand or result as a shape tensor, supporting |
| // both static and dynamic shapes. |
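| // For example (illustrative): a tensor<2x3xf32> input yields a tf.Const |
| // holding dense<[2, 3]> : tensor<2xi64>, while a tensor<2x?xf32> input |
| // yields a 64-bit tf.Shape op on the value. |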
| Value GetShape(Value input, Location loc, PatternRewriter& rewriter) { |
| auto shaped_type = input.getType().cast<ShapedType>(); |
| if (shaped_type.hasStaticShape()) { |
| auto static_shape = shaped_type.getShape(); |
| auto static_shape_type = |
| RankedTensorType::get(static_shape.size(), rewriter.getIntegerType(64)); |
| auto static_shape_attr = |
| mlir::DenseIntElementsAttr::get(static_shape_type, static_shape); |
| return rewriter.create<TF::ConstOp>(loc, static_shape_attr).output(); |
| } |
| |
| // If the shape is not static, create a new ShapeOp. |
| BoolAttr false_attr = rewriter.getBoolAttr(false); |
| return rewriter |
| .create<TF::ShapeOp>(loc, input, |
| /*use_32bit=*/false_attr) |
| .output(); |
| } |
| |
| mlir::TFL::MirrorPaddingType GetTFLMirrorPaddingFromString( |
| mlir::StringAttr padding) { |
| return llvm::StringSwitch<mlir::TFL::MirrorPaddingType>(padding.getValue()) |
| .Case("REFLECT", mlir::TFL::MirrorPaddingType::REFLECT) |
| .Case("SYMMETRIC", mlir::TFL::MirrorPaddingType::SYMMETRIC); |
| } |
| |
| #include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc" |
| |
| #define DECL_CONVERT_OP(tf_op) \ |
| struct ConvertTF##tf_op##Op : public RewritePattern { \ |
| explicit ConvertTF##tf_op##Op(MLIRContext* context) \ |
| : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \ |
| LogicalResult matchAndRewrite(Operation* op, \ |
| PatternRewriter& rewriter) const override; \ |
| } |
| |
| // TODO(antiagainst): Define this pattern in a table-driven manner once variadic |
| // operands are properly supported in declarative rewrite rule specification. |
| |
| DECL_CONVERT_OP(Assert); |
| DECL_CONVERT_OP(ConcatV2); |
| DECL_CONVERT_OP(MatMul); |
| DECL_CONVERT_OP(MatrixDiagV2); |
| DECL_CONVERT_OP(MatrixDiagV3); |
| DECL_CONVERT_OP(Pack); |
| DECL_CONVERT_OP(Split); |
| DECL_CONVERT_OP(SplitV); |
| DECL_CONVERT_OP(Unpack); |
| DECL_CONVERT_OP(Conv3D); |
| DECL_CONVERT_OP(Conv3DBackpropInputV2); |
| |
| #undef DECL_CONVERT_OP |
| |
| // Converts any IntegerAttr to an IntegerAttr of i32 type. |
| // The value is unchanged in the new attribute, but if it is out of the range |
| // of i32, the function returns a failure. |
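| // For example, an i64 IntegerAttr holding 3 converts to an i32 IntegerAttr |
| // holding 3, while an i64 IntegerAttr holding 1 << 40 fails. |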
| LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) { |
| if (attr.getType().isInteger(/*width=*/32)) { |
| *attr_i32 = attr; |
| return success(); |
| } |
| |
| int64_t value = attr.getInt(); |
| if (value > std::numeric_limits<int>::max() || |
| value < std::numeric_limits<int>::min()) { |
| return failure(); |
| } |
| |
| *attr_i32 = IntegerAttr::get( |
| IntegerType::get(attr.getContext(), /*width=*/32), value); |
| return success(); |
| } |
| |
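| // Converts tf.ConcatV2 to tfl.concatenation. Illustrative example (assuming |
| // a constant i64 axis operand; types elided): |
| //   %axis = "tf.Const"() {value = dense<0> : tensor<i64>} |
| //   %0 = "tf.ConcatV2"(%a, %b, %axis) |
| // becomes |
| //   %0 = "tfl.concatenation"(%a, %b) {axis = 0 : i32, |
| //                                     fused_activation_function = "NONE"} |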
| LogicalResult ConvertTFConcatV2Op::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_concat_op = cast<TF::ConcatV2Op>(op); |
| |
| auto values = tf_concat_op.values(); |
| auto output_type = tf_concat_op.output().getType(); |
| // Extract the axis attribute from the constant axis tensor. |
| ElementsAttr axis; |
| if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure(); |
| IntegerAttr axis_int = ExtractSingleElementAsInteger(axis); |
| |
| // The "axis" operand could be an i64 tensor. Resolve it to i32 here. |
| IntegerAttr axis_i32; |
| if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure(); |
| |
| StringAttr fused_activation_function = |
| StringAttr::get(rewriter.getContext(), "NONE"); |
| rewriter.replaceOpWithNewOp<ConcatenationOp>( |
| op, output_type, values, axis_i32, fused_activation_function); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFMatMulOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_matmul_op = cast<TF::MatMulOp>(op); |
| auto lhs = op->getOperand(0); |
| auto rhs = op->getOperand(1); |
| auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> { |
| RankedTensorType type = |
| input.getType().dyn_cast_or_null<RankedTensorType>(); |
| if (!type || type.getRank() != 2) return {failure(), nullptr}; |
| |
| auto permute_attr = DenseIntElementsAttr::get( |
| RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0}); |
| auto permute = rewriter.create<arith::ConstantOp>( |
| op->getLoc(), permute_attr.getType(), permute_attr); |
| llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1], |
| type.getShape()[0]}; |
| auto output = rewriter.create<TFL::TransposeOp>( |
| op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()), |
| input, permute); |
| return {success(), output}; |
| }; |
| |
| // TODO(jpienaar): Remove once handled via dialect conversion. |
| if (tf_matmul_op.transpose_a()) { |
| LogicalResult result = success(); |
| std::tie(result, lhs) = transpose(lhs); |
| if (failed(result)) return failure(); |
| } |
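| // Note the negation below: TFL FullyConnected expects its filter in |
| // [num_units, input_size] layout, i.e. the transpose of MatMul's rhs, so |
| // rhs must be transposed exactly when transpose_b is *not* set. |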
| if (!tf_matmul_op.transpose_b()) { |
| LogicalResult result = success(); |
| std::tie(result, rhs) = transpose(rhs); |
| if (failed(result)) return failure(); |
| } |
| |
| Type output_type = tf_matmul_op.getResult().getType(); |
| auto no_input = rewriter.create<TFL::NoValueOp>( |
| op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); |
| auto fc_op = rewriter.create<FullyConnectedOp>( |
| op->getLoc(), ArrayRef<Type>{output_type}, |
| /*input=*/lhs, /*filter=*/rhs, /*bias=*/no_input, |
| /*fused_activation_function=*/rewriter.getStringAttr("NONE"), |
| /*weights_format=*/rewriter.getStringAttr("DEFAULT"), |
| /*keep_num_dims=*/rewriter.getBoolAttr(false), |
| /*asymmetric_quantize_inputs=*/mlir::BoolAttr()); |
| rewriter.replaceOp(op, {fc_op.getResult(0)}); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFPackOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_pack_op = cast<TF::PackOp>(op); |
| |
| SmallVector<Value, 4> values(tf_pack_op.values()); |
| auto output_type = tf_pack_op.output().getType(); |
| auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N()); |
| // Axis can be negative. |
| auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis()); |
| |
| rewriter.replaceOpWithNewOp<PackOp>(op, output_type, values, values_count, |
| axis); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFSplitOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_split_op = cast<TF::SplitOp>(op); |
| |
| // Number of splits cannot be negative. |
| auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split()); |
| |
| rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, tf_split_op.output().getTypes(), |
| tf_split_op.split_dim(), |
| tf_split_op.value(), num_split); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFSplitVOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_splitv_op = cast<TF::SplitVOp>(op); |
| |
| // Number of splits cannot be negative. |
| auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split()); |
| |
| rewriter.replaceOpWithNewOp<TFL::SplitVOp>( |
| op, tf_splitv_op.output().getTypes(), tf_splitv_op.value(), |
| tf_splitv_op.size_splits(), tf_splitv_op.split_dim(), num_split); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFUnpackOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_unpack_op = cast<TF::UnpackOp>(op); |
| |
| auto input = tf_unpack_op.value(); |
| auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num()); |
| // Axis can be negative. |
| auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis()); |
| |
| rewriter.replaceOpWithNewOp<UnpackOp>(op, tf_unpack_op.output().getTypes(), |
| input, num, axis); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFConv3DOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| if (!TFDataFormatIsNDHWC(op)) return failure(); |
| |
| auto tf_op = cast<TF::Conv3DOp>(op); |
| |
| IntegerAttr stride_depth, stride_height, stride_width; |
| if (!TFIntListIs1XYZ1(op, "strides", &stride_depth, &stride_height, |
| &stride_width)) |
| return failure(); |
| |
| IntegerAttr dilation_depth_factor, dilation_height_factor, |
| dilation_width_factor; |
| if (!TFIntListIs1XYZ1(op, "dilations", &dilation_depth_factor, |
| &dilation_height_factor, &dilation_width_factor)) { |
| // If the 'dilations' attribute is missing, we use the default value (1) |
| // for the depth, height, and width dilation factors. |
| dilation_depth_factor = rewriter.getI32IntegerAttr(1); |
| dilation_height_factor = rewriter.getI32IntegerAttr(1); |
| dilation_width_factor = rewriter.getI32IntegerAttr(1); |
| } |
| |
| StringAttr padding; |
| if (!TFPaddingIsSameOrValid(op, &padding)) return failure(); |
| |
| // TensorFlow's Conv3D has no bias; optimization patterns that fuse Conv3D |
| // with other ops can fill in the bias later. |
| Value none = rewriter.create<TFL::NoValueOp>( |
| op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); |
| |
| rewriter.replaceOpWithNewOp<TFL::Conv3DOp>( |
| op, tf_op.getType(), tf_op.input(), tf_op.filter(), |
| /*bias=*/none, dilation_depth_factor, dilation_height_factor, |
| dilation_width_factor, |
| /*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding, |
| stride_depth, stride_height, stride_width); |
| |
| return success(); |
| } |
| |
| LogicalResult ConvertTFConv3DBackpropInputV2Op::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| if (!TFDataFormatIsNDHWC(op)) return failure(); |
| |
| auto tf_op = cast<TF::Conv3DBackpropInputV2Op>(op); |
| |
| IntegerAttr stride_depth, stride_height, stride_width; |
| if (!TFIntListIs1XYZ1(op, "strides", &stride_depth, &stride_height, |
| &stride_width)) |
| return failure(); |
| |
| IntegerAttr dilation_depth_factor, dilation_height_factor, |
| dilation_width_factor; |
| if (!TFIntListIs1XYZ1(op, "dilations", &dilation_depth_factor, |
| &dilation_height_factor, &dilation_width_factor)) { |
| // If the 'dilations' attribute is missing, we use the default value (1) |
| // for the depth, height, and width dilation factors. |
| dilation_depth_factor = rewriter.getI32IntegerAttr(1); |
| dilation_height_factor = rewriter.getI32IntegerAttr(1); |
| dilation_width_factor = rewriter.getI32IntegerAttr(1); |
| } |
| |
| StringAttr padding; |
| if (!TFPaddingIsSameOrValid(op, &padding)) return failure(); |
| |
| // TensorFlow's Conv3D has no bias; optimization patterns that fuse Conv3D |
| // with other ops can fill in the bias later. |
| Value none = rewriter.create<TFL::NoValueOp>( |
| op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); |
| |
| Value output_shape = |
| CreateCastToInt32(tf_op.input_sizes(), op->getLoc(), rewriter); |
| |
| rewriter.replaceOpWithNewOp<TFL::Conv3DTransposeOp>( |
| op, tf_op.getType(), output_shape, tf_op.filter(), tf_op.out_backprop(), |
| /*bias=*/none, dilation_depth_factor, dilation_height_factor, |
| dilation_width_factor, |
| /*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding, |
| stride_depth, stride_height, stride_width); |
| |
| return success(); |
| } |
| |
| // MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. This attribute |
| // only has effects when processing multiple diagonals. Since TFLite converts |
| // MatrixDiagV{2,3} to MatrixDiag, which only takes single-diagonal inputs, we |
| // can safely ignore this V3 attribute. |
| // We can't pass `rewriter` by reference because clang-tidy will want it to be |
| // constant (`const PatternRewriter& rewriter`). If we do that, we won't be |
| // able to call `rewriter->replaceOpWithNewOp`, which is not a const member |
| // function. |
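| // In other words (summarizing the checks below), the rewrite only fires for |
| // the default-diagonal case: k == 0, num_rows == num_cols == -1, and an |
| // all-zero padding_value, which matches tf.MatrixDiag's defaults. |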
| template <typename MatrixDiagV2OrV3Op> |
| bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) { |
| auto tf_matrix_diag_v2_or_v3_op = cast<MatrixDiagV2OrV3Op>(op); |
| |
| if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false; |
| |
| auto input = tf_matrix_diag_v2_or_v3_op.diagonal(); |
| auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType(); |
| |
| // Extract k constant tensor and check value = 0. |
| ElementsAttr k; |
| if (!matchPattern(tf_matrix_diag_v2_or_v3_op.k(), m_Constant(&k))) |
| return false; |
| if (ExtractSingleElementAsInteger(k).getInt() != 0) return false; |
| |
| // Extract num_rows constant tensor and check value = -1. |
| ElementsAttr num_rows; |
| if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_rows(), |
| m_Constant(&num_rows))) |
| return false; |
| if (ExtractSingleElementAsInteger(num_rows).getInt() != -1) return false; |
| |
| // Extract num_cols constant tensor and check value = -1. |
| ElementsAttr num_cols; |
| if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_cols(), |
| m_Constant(&num_cols))) |
| return false; |
| if (ExtractSingleElementAsInteger(num_cols).getInt() != -1) return false; |
| |
| // Verify padding_value is a tensor with all 0s. |
| mlir::Value padding_value = tf_matrix_diag_v2_or_v3_op.padding_value(); |
| mlir::Type element_type = |
| padding_value.getType().cast<ShapedType>().getElementType(); |
| if (element_type.isa<FloatType>()) { |
| DenseFPElementsAttr padding_attr; |
| if (!matchPattern(padding_value, m_Constant(&padding_attr)) || |
| !padding_attr.isSplat() || |
| !padding_attr.getSplatValue<APFloat>().isZero()) { |
| return false; |
| } |
| } else if (element_type.isa<IntegerType>()) { |
| DenseIntElementsAttr padding_attr; |
| if (!matchPattern(padding_value, m_Constant(&padding_attr)) || |
| !padding_attr.isSplat() || |
| !padding_attr.getSplatValue<APInt>().isZero()) { |
| return false; |
| } |
| } else { |
| // If the padding value is neither float nor int, conservatively assume it |
| // contains nonzeros. |
| return false; |
| } |
| |
| rewriter->replaceOpWithNewOp<MatrixDiagOp>(op, output_type, input); |
| return true; |
| } |
| |
| LogicalResult ConvertTFMatrixDiagV2Op::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV2Op>(op, &rewriter)) |
| return success(); |
| return failure(); |
| } |
| |
| LogicalResult ConvertTFMatrixDiagV3Op::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV3Op>(op, &rewriter)) |
| return success(); |
| return failure(); |
| } |
| |
| // TF Lite doesn't support Assert, so we simply drop the op from the graph. |
| LogicalResult ConvertTFAssertOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| |
| // Legalize unidirectional sequence lstm. |
| struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { |
| explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context) |
| : RewritePattern(kUnidirectionalSequenceLstm, 1, context) {} |
| |
| LogicalResult matchAndRewrite(Operation* op, |
| PatternRewriter& rewriter) const override { |
| auto tflite_indices_attr = |
| op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices); |
| if (!tflite_indices_attr) return failure(); |
| |
| SmallVector<int64_t, 20> tflite_indices; |
| for (auto index_attr : tflite_indices_attr.getValue()) { |
| IntegerAttr index = index_attr.cast<IntegerAttr>(); |
| tflite_indices.push_back(index.getInt()); |
| } |
| |
| // Optional input placeholder. |
| Value none = rewriter.create<TFL::NoValueOp>( |
| op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); |
| |
| // Populate inputs. |
| // UnidirectionalSequenceLstm is expected to have 24 inputs. |
| SmallVector<Value, 24> inputs; |
| int count = 0; |
| int total_ophint_converted_inputs = tflite_indices.size(); |
| for (int i = 0; i < 24; ++i) { |
| if (count < total_ophint_converted_inputs && tflite_indices[count] == i) { |
| // Specified input. |
| inputs.push_back(op->getOperand(i)); |
| count++; |
| } else { |
| // Unspecified input: use the `none` placeholder. |
| inputs.push_back(none); |
| } |
| } |
| |
| // Populate outputs. |
| // UnidirectionalSequenceLstm should only have 1 output, and that is the |
| // original ophint-converted node's 3rd output. |
| SmallVector<Type, 4> result_types; |
| result_types.push_back(op->getOpResult(2).getType()); |
| |
| // Populate attributes. |
| SmallVector<NamedAttribute, 4> attributes; |
| // Activation will always be tanh. |
| attributes.push_back(rewriter.getNamedAttr("fused_activation_function", |
| rewriter.getStringAttr("TANH"))); |
| // cell_clip. |
| attributes.push_back( |
| rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(0.0))); |
| // proj_clip. |
| attributes.push_back( |
| rewriter.getNamedAttr("proj_clip", rewriter.getF32FloatAttr(0.0))); |
| // time_major will always be true. |
| attributes.push_back( |
| rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true))); |
| |
| Value lstm_result = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>( |
| op->getLoc(), result_types, inputs, attributes); |
| |
| // Rewire the output. |
| rewriter.replaceOp(op, {nullptr, nullptr, lstm_result}); |
| return success(); |
| } |
| }; |
| |
| // Legalize unidirectional sequence rnn. |
| struct LegalizeUnidirectionalSequenceRnn : public RewritePattern { |
| explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context) |
| : RewritePattern(kUnidirectionalSequenceRnn, 1, context) {} |
| |
| LogicalResult matchAndRewrite(Operation* op, |
| PatternRewriter& rewriter) const override { |
| auto tflite_indices_attr = |
| op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices); |
| if (!tflite_indices_attr) return failure(); |
| |
| if (op->getNumOperands() != 5) { |
| op->emitError() |
| << "We're expecting 5 inputs for UnidirectionalSequenceRNN, only " |
| << op->getNumOperands() << " provided"; |
| return failure(); |
| } |
| |
| if (op->getNumResults() != 2) { |
| op->emitError() |
| << "We're expecting 2 outputs for UnidirectionalSequenceRNN, only " |
| << op->getNumResults() << " found"; |
| return failure(); |
| } |
| |
| // Populate inputs. |
| // UnidirectionalSequenceRnn is expected to have 5 inputs, and none of them |
| // are optional inputs. |
| SmallVector<Value, 5> inputs; |
| for (int i = 0; i < 5; ++i) { |
| inputs.push_back(op->getOperand(i)); |
| } |
| |
| // Populate outputs. |
| // UnidirectionalSequenceRnn should only have 1 output, and that is the |
| // original ophint-converted node's 2nd output. |
| SmallVector<Type, 4> result_types; |
| result_types.push_back(op->getOpResult(1).getType()); |
| |
| // Populate attributes. |
| SmallVector<NamedAttribute, 2> attributes; |
| // Activation will always be tanh. |
| attributes.push_back(rewriter.getNamedAttr("fused_activation_function", |
| rewriter.getStringAttr("TANH"))); |
| |
| // time_major will always be true. |
| attributes.push_back( |
| rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true))); |
| |
| Value rnn_result = rewriter.create<TFL::UnidirectionalSequenceRNNOp>( |
| op->getLoc(), result_types, inputs, attributes); |
| |
| // Rewire the output. |
| rewriter.replaceOp(op, {nullptr, rnn_result}); |
| |
| return success(); |
| } |
| }; |
| |
| // Puts two TF BroadcastTo ops in front of the given TF binary broadcast op |
| // to make the binary broadcastable-op conversion always succeed without |
| // requiring the flex delegate. |
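| // |
| // For example (illustrative, static-shape path): |
| //   %0 = "tf.AddV2"(%lhs, %rhs) |
| //          : (tensor<1x4xf32>, tensor<3x1xf32>) -> tensor<3x4xf32> |
| // becomes |
| //   %s = "tf.Const"() {value = dense<[3, 4]> : tensor<2xi64>} |
| //   %l = "tf.BroadcastTo"(%lhs, %s) : (...) -> tensor<3x4xf32> |
| //   %r = "tf.BroadcastTo"(%rhs, %s) : (...) -> tensor<3x4xf32> |
| //   %0 = "tf.AddV2"(%l, %r) |
| //          : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> |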
| template <typename SourceOp> |
| class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> { |
| public: |
| using OpRewritePattern<SourceOp>::OpRewritePattern; |
| |
| LogicalResult rewriteOpWithDynamicInput(Operation* op, |
| PatternRewriter& rewriter) const { |
| auto lhs = op->getOperand(0); |
| auto rhs = op->getOperand(1); |
| auto out = op->getResult(0); |
| |
| // Calculates the symbolic broadcast shape; it is only used in types. |
| SmallVector<int64_t, 4> symbolic_broadcast_shape; |
| // Matching fails when lhs or rhs is an unranked tensor. |
| // TODO(b/176202543): Support unranked tensor. |
| if (!lhs.getType().cast<ShapedType>().hasRank() || |
| !rhs.getType().cast<ShapedType>().hasRank()) { |
| return failure(); |
| } |
| if (!OpTrait::util::getBroadcastedShape( |
| lhs.getType().cast<ShapedType>().getShape(), |
| rhs.getType().cast<ShapedType>().getShape(), |
| symbolic_broadcast_shape)) { |
| return failure(); |
| } |
| |
| // Calculates the broadcast shape using BroadcastArgs op. |
| Value lhs_shape = GetShape(lhs, op->getLoc(), rewriter); |
| Value rhs_shape = GetShape(rhs, op->getLoc(), rewriter); |
| auto broadcast_shape = |
| rewriter |
| .create<TF::BroadcastArgsOp>( |
| op->getLoc(), |
| RankedTensorType::get(symbolic_broadcast_shape.size(), |
| rewriter.getIntegerType(64)), |
| lhs_shape, rhs_shape) |
| .r0(); |
| |
| // Broadcasts inputs using BroadcastTo op. |
| auto broadcast_type = RankedTensorType::get( |
| symbolic_broadcast_shape, getElementTypeOrSelf(lhs.getType())); |
| auto broadcasted_lhs = |
| rewriter |
| .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs, |
| broadcast_shape) |
| .output(); |
| auto broadcasted_rhs = |
| rewriter |
| .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs, |
| broadcast_shape) |
| .output(); |
| |
| // Recreate an op with the above BroadcastTo op results. |
| RankedTensorType result_type = RankedTensorType::get( |
| symbolic_broadcast_shape, getElementTypeOrSelf(out.getType())); |
| rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, broadcasted_lhs, |
| broadcasted_rhs); |
| return success(); |
| } |
| |
| LogicalResult matchAndRewrite(SourceOp src_op, |
| PatternRewriter& rewriter) const override { |
| Operation* op = static_cast<Operation*>(src_op); |
| auto lhs = op->getOperand(0); |
| auto rhs = op->getOperand(1); |
| |
| if (!lhs.getType().cast<ShapedType>().hasStaticShape() || |
| !rhs.getType().cast<ShapedType>().hasStaticShape()) { |
| return rewriteOpWithDynamicInput(op, rewriter); |
| } |
| |
| auto lhs_shape = lhs.getType().cast<ShapedType>().getShape(); |
| auto rhs_shape = rhs.getType().cast<ShapedType>().getShape(); |
| |
| if (lhs_shape == rhs_shape) { |
| return failure(); |
| } |
| |
| // Calculate the broadcasted shape. |
| SmallVector<int64_t, 4> result_shape; |
| if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape, |
| result_shape)) { |
| return failure(); |
| } |
| |
| RankedTensorType result_type = RankedTensorType::get( |
| result_shape, getElementTypeOrSelf(op->getResult(0).getType())); |
| |
| // Create a const op that stores the above broadcasted shape. |
| auto new_shape_attr = mlir::DenseIntElementsAttr::get( |
| RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)), |
| result_shape); |
| auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr); |
| |
| // Apply BroadcastTo ops to each input. |
| auto broadcast_type = RankedTensorType::get( |
| result_shape, getElementTypeOrSelf(lhs.getType())); |
| |
| if (result_type.getShape() != lhs_shape) { |
| lhs = rewriter |
| .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs, |
| new_shape) |
| .output(); |
| } |
| if (result_type.getShape() != rhs_shape) { |
| rhs = rewriter |
| .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs, |
| new_shape) |
| .output(); |
| } |
| |
| // Recreate an op with the above Broadcast op results. |
| rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs); |
| return success(); |
| } |
| }; |
| |
| // This specialization is for the TF SelectV2 op. SelectV2 has three inputs, |
| // which should have broadcastable shapes. |
| template <> |
| class ApplyExplicitBroadcasting<TF::SelectV2Op> |
| : public OpRewritePattern<TF::SelectV2Op> { |
| public: |
| using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern; |
| |
| LogicalResult rewriteOpWithDynamicInput(Operation* op, |
| PatternRewriter& rewriter) const { |
| auto cond = op->getOperand(0); |
| auto lhs = op->getOperand(1); |
| auto rhs = op->getOperand(2); |
| auto out = op->getResult(0); |
| |
| // Matching fails when lhs, rhs, or cond is an unranked tensor. |
| // TODO(b/176202543): Support unranked tensor. |
| if (!lhs.getType().cast<ShapedType>().hasRank() || |
| !rhs.getType().cast<ShapedType>().hasRank() || |
| !cond.getType().cast<ShapedType>().hasRank()) { |
| return failure(); |
| } |
| |
| // Calculates the symbolic broadcast shape; it is only used in types. |
| SmallVector<int64_t, 4> symbolic_broadcast_lhs_rhs_shape; |
| if (!OpTrait::util::getBroadcastedShape( |
| lhs.getType().cast<ShapedType>().getShape(), |
| rhs.getType().cast<ShapedType>().getShape(), |
| symbolic_broadcast_lhs_rhs_shape)) { |
| return failure(); |
| } |
| SmallVector<int64_t, 4> symbolic_broadcast_shape; |
| if (!OpTrait::util::getBroadcastedShape( |
| cond.getType().cast<ShapedType>().getShape(), |
| symbolic_broadcast_lhs_rhs_shape, symbolic_broadcast_shape)) { |
| return failure(); |
| } |
| |
| // Calculates the broadcast shape using BroadcastArgs op. |
| Value cond_shape = GetShape(cond, op->getLoc(), rewriter); |
| Value lhs_shape = GetShape(lhs, op->getLoc(), rewriter); |
| Value rhs_shape = GetShape(rhs, op->getLoc(), rewriter); |
| auto broadcast_shape_value = |
| rewriter |
| .create<TF::BroadcastArgsOp>(op->getLoc(), lhs_shape.getType(), |
| lhs_shape, rhs_shape) |
| .r0(); |
| broadcast_shape_value = |
| rewriter |
| .create<TF::BroadcastArgsOp>(op->getLoc(), lhs_shape.getType(), |
| broadcast_shape_value, cond_shape) |
| .r0(); |
| |
| // Broadcasting inputs using BroadcastTo op. |
| auto broadcast_type = RankedTensorType::get( |
| symbolic_broadcast_shape, getElementTypeOrSelf(out.getType())); |
| auto broadcasted_cond = |
| rewriter |
| .create<TF::BroadcastToOp>( |
| op->getLoc(), |
| RankedTensorType::get(symbolic_broadcast_shape, |
| rewriter.getIntegerType(1)), |
| cond, broadcast_shape_value) |
| .output(); |
| auto broadcasted_lhs = |
| rewriter |
| .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs, |
| broadcast_shape_value) |
| .output(); |
| auto broadcasted_rhs = |
| rewriter |
| .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs, |
| broadcast_shape_value) |
| .output(); |
| |
| // Recreate an op with the above BroadcastTo op results. |
| rewriter.replaceOpWithNewOp<TF::SelectV2Op>( |
| op, broadcast_type, broadcasted_cond, broadcasted_lhs, broadcasted_rhs); |
| return success(); |
| } |
| |
| LogicalResult matchAndRewrite(TF::SelectV2Op src_op, |
| PatternRewriter& rewriter) const override { |
| Operation* op = static_cast<Operation*>(src_op); |
| auto cond = op->getOperand(0); |
| auto lhs = op->getOperand(1); |
| auto rhs = op->getOperand(2); |
| |
| // Should have static shapes to calculate the broadcasted shape. |
| if (!lhs.getType().cast<ShapedType>().hasStaticShape() || |
| !rhs.getType().cast<ShapedType>().hasStaticShape() || |
| !cond.getType().cast<ShapedType>().hasStaticShape()) { |
| return rewriteOpWithDynamicInput(op, rewriter); |
| } |
| |
| auto lhs_shape = lhs.getType().cast<ShapedType>().getShape(); |
| auto rhs_shape = rhs.getType().cast<ShapedType>().getShape(); |
| auto cond_shape = cond.getType().cast<ShapedType>().getShape(); |
| |
| if (lhs_shape == rhs_shape && cond_shape == lhs_shape) { |
| return failure(); |
| } |
| |
| // Calculate the broadcasted shape. |
| SmallVector<int64_t, 4> broadcasted_shape; |
| if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape, |
| broadcasted_shape)) { |
| return failure(); |
| } |
| |
| SmallVector<int64_t, 4> result_shape; |
| if (!OpTrait::util::getBroadcastedShape(broadcasted_shape, cond_shape, |
| result_shape)) { |
| return failure(); |
| } |
| |
| // Create a const op that stores the above broadcasted shape. |
| auto shape_type = |
| RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)); |
| auto new_shape_attr = |
| mlir::DenseIntElementsAttr::get(shape_type, result_shape); |
| auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr); |
| |
| // Apply BroadcastTo ops to each input. |
| auto cond_result_type = |
| RankedTensorType::get(result_shape, rewriter.getIntegerType(1)); |
| auto result_type = RankedTensorType::get( |
| result_shape, getElementTypeOrSelf(lhs.getType())); |
| |
| if (result_shape != cond_shape) { |
| cond = rewriter |
| .create<TF::BroadcastToOp>(op->getLoc(), cond_result_type, |
| cond, new_shape) |
| .output(); |
| } |
| if (result_shape != lhs_shape) { |
| lhs = rewriter |
| .create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs, |
| new_shape) |
| .output(); |
| } |
| if (result_shape != rhs_shape) { |
| rhs = rewriter |
| .create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs, |
| new_shape) |
| .output(); |
| } |
| |
| // Recreate an op with the above Broadcast op results. |
| rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs, |
| rhs); |
| return success(); |
| } |
| }; |
| |
| void addPatterns(MLIRContext* context, RewritePatternSet& patterns, |
| bool preserve_assert_op) { |
| // Add TF->TF lowering patterns. |
| TF::PopulateLoweringTFPatterns(context, &patterns); |
| |
| // Add the generated patterns to the list. |
| populateWithGenerated(patterns); |
| patterns.add<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op, |
| ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp, |
| ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFConv3DOp, |
| ConvertTFConv3DBackpropInputV2Op>(context); |
| if (!preserve_assert_op) patterns.add<ConvertTFAssertOp>(context); |
| |
| // Patterns for TF nodes produced by the ophint Python converter. |
| patterns.add<LegalizeUnidirectionalSequenceLstm, |
| LegalizeUnidirectionalSequenceRnn>(context); |
| } |
| |
| bool applyPatterns(func::FuncOp func, ConversionTarget& target, |
| FrozenRewritePatternSet& frozenPatterns) { |
| // Keep trying to convert. |
| // TODO(karimnosseir): This is similar to what the greedy pattern driver |
| // does. Check whether there is a function that retries until it converges. |
| // Currently the unit tests don't do multiple tries, so we need this. |
| const int max_iterations = 15; |
| for (int i = 0; i < max_iterations; ++i) { |
| if (failed(applyPartialConversion(func, target, frozenPatterns))) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| void LegalizeTF::runOnOperation() { |
| auto* context = &getContext(); |
| auto func = getOperation(); |
| |
| ConversionTarget target(*context); |
| // It is legal to still have TF ops in the graph: they can be consumed |
| // later, or, in the SELECT case, we allow TF ops in the final graph. |
| target.addLegalOp<mlir::arith::ConstantOp>(); |
| target.addLegalOp<mlir::func::ConstantOp>(); |
| target.addLegalOp<TFL::NoValueOp>(); |
| target.addLegalOp<ConstOp>(); |
| target.addLegalOp<DequantizeOp>(); |
| target.addLegalOp<QConstOp>(); |
| if (run_tfl_runtime_verification_) { |
| target.addDynamicallyLegalDialect<TensorFlowLiteDialect>([](Operation* op) { |
| auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op); |
| if (!tfl_op) return false; |
| return succeeded(tfl_op.VerifyTflRuntimeConstraints( |
| op, /*emit_error_on_verify_fail=*/false)); |
| }); |
| } else { |
| target.addLegalDialect<TensorFlowLiteDialect>(); |
| } |
| |
| RewritePatternSet stage1Patterns(&getContext()); |
| |
| addPatterns(context, stage1Patterns, preserve_assert_op_); |
| |
| FrozenRewritePatternSet stage1FrozenPatterns(std::move(stage1Patterns)); |
| if (!applyPatterns(func, target, stage1FrozenPatterns)) |
| return signalPassFailure(); |
| |
| // Explicit BroadcastTo addition for left-over broadcastable ops. |
| // The following pattern matching should be done after the other legalization |
| // rules so as not to add unnecessary BroadcastTo ops. |
| RewritePatternSet stage2Patterns(&getContext()); |
| |
| addPatterns(context, stage2Patterns, preserve_assert_op_); |
| |
| stage2Patterns.add<ApplyExplicitBroadcasting<TF::LessEqualOp>, |
| ApplyExplicitBroadcasting<TF::GreaterEqualOp>, |
| ApplyExplicitBroadcasting<TF::NotEqualOp>, |
| ApplyExplicitBroadcasting<TF::GreaterOp>, |
| ApplyExplicitBroadcasting<TF::LessOp>, |
| ApplyExplicitBroadcasting<TF::EqualOp>, |
| ApplyExplicitBroadcasting<TF::AddOp>, |
| ApplyExplicitBroadcasting<TF::AddV2Op>, |
| ApplyExplicitBroadcasting<TF::MulOp>, |
| ApplyExplicitBroadcasting<TF::DivOp>, |
| ApplyExplicitBroadcasting<TF::RealDivOp>, |
| ApplyExplicitBroadcasting<TF::SubOp>, |
| ApplyExplicitBroadcasting<TF::FloorDivOp>, |
| ApplyExplicitBroadcasting<TF::FloorModOp>, |
| ApplyExplicitBroadcasting<TF::PowOp>, |
| ApplyExplicitBroadcasting<TF::MaximumOp>, |
| ApplyExplicitBroadcasting<TF::MinimumOp>, |
| ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>, |
| ApplyExplicitBroadcasting<TF::SelectV2Op>>(context); |
| |
| FrozenRewritePatternSet stage2FrozenPatterns(std::move(stage2Patterns)); |
| if (!applyPatterns(func, target, stage2FrozenPatterns)) |
| return signalPassFailure(); |
| } |
| |
| } // namespace |
| |
| // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. |
| std::unique_ptr<OperationPass<func::FuncOp>> CreateLegalizeTFPass( |
| bool run_tfl_runtime_verification, bool preserve_assert_op) { |
| return std::make_unique<LegalizeTF>(run_tfl_runtime_verification, |
| preserve_assert_op); |
| } |
| |
| static PassRegistration<LegalizeTF> pass; |
| |
| } // namespace TFL |
| } // namespace mlir |