| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // This transformation pass converts operations in the TensorFlow dialect into |
| // operations that are legal in the TensorFlow Lite dialect. Operations that |
| // can be legalized to the TensorFlow Lite dialect with simple replacements |
| // belong in this pass; operations whose legalization may create extra ops |
| // belong in the PrepareTF pass, which should be run before this pass. That |
| // way, any constant-folding opportunities created by the extra ops can be |
| // exploited by the constant-folding support for the TensorFlow ops. |
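| // |
| // For simple elementwise ops the legalization is a one-to-one replacement. |
| // Schematically (exact attribute sets depend on the op definitions): |
| // |
| //   %0 = "tf.AddV2"(%a, %b) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> |
| //     ==> |
| //   %0 = "tfl.add"(%a, %b) {fused_activation_function = "NONE"} |
| //          : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> |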
| |
| #include <climits> |
| #include <complex> |
| #include <cstdint> |
| |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project |
| #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "mlir/IR/PatternMatch.h" // from @llvm-project |
| #include "mlir/IR/StandardTypes.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Support/Functional.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Transforms/DialectConversion.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" |
| #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" |
| #include "tensorflow/compiler/mlir/lite/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" |
| #include "tensorflow/compiler/mlir/lite/utils/validators.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" |
| #include "tensorflow/compiler/xla/status.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/tensor_shape.pb.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/lib/random/philox_random.h" |
| #include "tensorflow/core/lib/random/random_distributions.h" |
| #include "tensorflow/core/protobuf/error_codes.pb.h" |
| |
| namespace mlir { |
| namespace TFL { |
| |
| //===----------------------------------------------------------------------===// |
| // The actual LegalizeTF Pass. |
| namespace { |
| |
| using xla::Status; |
| using xla::StatusOr; |
| |
| constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm"; |
| constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn"; |
| constexpr char kTfLiteInputIndices[] = "_tflite_input_indices"; |
| |
| // Legalize operations in functions. |
| class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> { |
| public: |
| LegalizeTF() = default; |
| LegalizeTF(const LegalizeTF&) {} |
| explicit LegalizeTF(bool run_tfl_runtime_verification) { |
| run_tfl_runtime_verification_ = run_tfl_runtime_verification; |
| } |
| |
| /// Performs the lowering to TFLite dialect. |
| void runOnFunction() override; |
| |
| private: |
| Option<bool> run_tfl_runtime_verification_{ |
| *this, "run-tfl-runtime-verification", |
| llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)}; |
| }; |
| |
| // Returns true if all of `op`'s operands are tensors with the same static |
| // shape. |
| bool HasSameStaticShapes(Operation* op) { |
| auto values = op->getOperands(); |
| int index = 0; |
| ArrayRef<int64_t> shape; |
| for (Value value : values) { |
| auto shaped_type = value.getType().dyn_cast<ShapedType>(); |
| if (!shaped_type || !shaped_type.hasStaticShape()) { |
| return false; |
| } |
| if (index == 0) { |
| shape = shaped_type.getShape(); |
| } else { |
| if (shape != shaped_type.getShape()) { |
| return false; |
| } |
| } |
| ++index; |
| } |
| return true; |
| } |
| |
| #include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc" |
| |
| #define DECL_CONVERT_OP(tf_op) \ |
| struct ConvertTF##tf_op##Op : public RewritePattern { \ |
| explicit ConvertTF##tf_op##Op(MLIRContext* context) \ |
| : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \ |
| LogicalResult matchAndRewrite(Operation* op, \ |
| PatternRewriter& rewriter) const override; \ |
| } |
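| // For example, DECL_CONVERT_OP(MatMul); below expands to exactly: |
| // |
| //   struct ConvertTFMatMulOp : public RewritePattern { |
| //     explicit ConvertTFMatMulOp(MLIRContext* context) |
| //         : RewritePattern(TF::MatMulOp::getOperationName(), 1, context) {} |
| //     LogicalResult matchAndRewrite(Operation* op, |
| //                                   PatternRewriter& rewriter) const override; |
| //   }; |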
| |
| // TODO(antiagainst): Define this pattern in a table-driven manner once variadic |
| // operands are properly supported in declarative rewrite rule specification. |
| |
| DECL_CONVERT_OP(Assert); |
| DECL_CONVERT_OP(Concat); |
| DECL_CONVERT_OP(ConcatV2); |
| DECL_CONVERT_OP(MatMul); |
| DECL_CONVERT_OP(MatrixDiagV2); |
| DECL_CONVERT_OP(MatrixDiagV3); |
| DECL_CONVERT_OP(Pack); |
| DECL_CONVERT_OP(Reshape); |
| DECL_CONVERT_OP(Split); |
| DECL_CONVERT_OP(SplitV); |
| DECL_CONVERT_OP(StridedSlice); |
| DECL_CONVERT_OP(Unpack); |
| DECL_CONVERT_OP(Reciprocal); |
| DECL_CONVERT_OP(RandomUniform); |
| DECL_CONVERT_OP(BroadcastTo); |
| |
| #undef DECL_CONVERT_OP |
| |
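| // Converts a tf.RandomUniform with a nonzero fixed seed and a static f32 |
| // result type into a constant computed with the same Philox generator the |
| // TF kernel uses. Schematically: |
| // |
| //   %0 = "tf.RandomUniform"(%shape) {seed = 1 : i64, seed2 = 2 : i64} |
| //          : (tensor<2xi32>) -> tensor<2x3xf32> |
| //     ==> |
| //   %0 = constant dense<...> : tensor<2x3xf32> |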
| LogicalResult ConvertTFRandomUniformOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto random_uniform_op = cast<TF::RandomUniformOp>(op); |
| if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) { |
| return failure(); |
| } |
| if (!random_uniform_op.dtype().isF32()) { |
| return failure(); |
| } |
| typedef tensorflow::random::UniformDistribution< |
| tensorflow::random::PhiloxRandom, float> |
| Distribution; |
| |
| tensorflow::random::PhiloxRandom generator( |
| random_uniform_op.seed().getSExtValue(), |
| random_uniform_op.seed2().getSExtValue()); |
| Distribution dist; |
| int num_elements = 0; |
| if (auto output_type = |
| random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) { |
| if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) { |
| if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) { |
| return failure(); |
| } |
| num_elements = output_type.getNumElements(); |
| size_t offset = 0; |
| size_t num_samples = Distribution::kResultElementCount; |
| llvm::SmallVector<float, 32> data; |
| data.resize(num_elements); |
| while (offset < num_elements) { |
| const typename Distribution::ResultType samples = dist(&generator); |
| std::copy(&samples[0], |
| &samples[0] + std::min(num_samples, data.size() - offset), |
| &data[0] + offset); |
| offset += num_samples; |
| } |
| auto output_data = DenseFPElementsAttr::get(output_type, data); |
| rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data); |
| return success(); |
| } |
| } |
| return failure(); |
| } |
| |
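| // Converts tf.Concat to tfl.concatenation when the concat dimension is a |
| // constant. Schematically: |
| // |
| //   %0 = "tf.Concat"(%dim, %a, %b) // %dim is a constant 0 |
| //          : (tensor<i32>, tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32> |
| //     ==> |
| //   %0 = "tfl.concatenation"(%a, %b) |
| //          {axis = 0 : i32, fused_activation_function = "NONE"} |
| //          : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32> |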
| LogicalResult ConvertTFConcatOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_concat_op = cast<TF::ConcatOp>(op); |
| |
| auto values = tf_concat_op.values(); |
| auto output_type = tf_concat_op.output().getType(); |
| // Extract the axis attribute from the constant concat_dim tensor. |
| ElementsAttr axis; |
| if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis))) |
| return failure(); |
| |
| StringAttr fused_activation_function = |
| StringAttr::get("NONE", rewriter.getContext()); |
| rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>( |
| op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis), |
| fused_activation_function); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFConcatV2Op::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_concat_op = cast<TF::ConcatV2Op>(op); |
| |
| auto values = tf_concat_op.values(); |
| auto output_type = tf_concat_op.output().getType(); |
| // Extract the axis attribute from the constant axis tensor. |
| ElementsAttr axis; |
| if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure(); |
| |
| StringAttr fused_activation_function = |
| StringAttr::get("NONE", rewriter.getContext()); |
| rewriter.replaceOpWithNewOp<ConcatenationOp>( |
| op, output_type, values, ExtractSingleElementAsInteger(axis), |
| fused_activation_function); |
| return success(); |
| } |
| |
| // The following is effectively: |
| // def : Pat< |
| // (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a, |
| // ConstBoolAttrTrue:$transpose_b), |
| // (TFL_FullyConnectedOp:$__0 $a, $b, |
| // NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>; |
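| // |
| // Schematically, for a MatMul whose RHS is already transposed: |
| // |
| //   %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} |
| //          : (tensor<1x4xf32>, tensor<8x4xf32>) -> tensor<1x8xf32> |
| //     ==> |
| //   %none = constant unit |
| //   %0 = "tfl.fully_connected"(%a, %b, %none) |
| //          {fused_activation_function = "NONE", keep_num_dims = false, |
| //           weights_format = "DEFAULT"} |
| //          : (tensor<1x4xf32>, tensor<8x4xf32>, none) -> tensor<1x8xf32> |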
| LogicalResult ConvertTFMatMulOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_matmul_op = cast<TF::MatMulOp>(op); |
| if (tf_matmul_op.transpose_a()) return failure(); |
| if (!tf_matmul_op.transpose_b()) return failure(); |
| |
| Type output_type = tf_matmul_op.getResult().getType(); |
| // TODO(jpienaar): Follow up post shuffle discussion. |
| auto no_input = rewriter.create<ConstantOp>( |
| op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); |
| auto fc_op = rewriter.create<FullyConnectedOp>( |
| op->getLoc(), ArrayRef<Type>{output_type}, op->getOperand(0), |
| op->getOperand(1), no_input, rewriter.getStringAttr("NONE"), |
| rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false)); |
| rewriter.replaceOp(op, {fc_op.getResult(0)}); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFPackOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_pack_op = cast<TF::PackOp>(op); |
| |
| SmallVector<Value, 4> values(tf_pack_op.values()); |
| auto output_type = tf_pack_op.output().getType(); |
| auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N()); |
| // Axis can be negative. |
| auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue()); |
| |
| rewriter.replaceOpWithNewOp<PackOp>(op, output_type, values, values_count, |
| axis); |
| return success(); |
| } |
| |
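| // Converts tf.Reshape to tfl.reshape, inserting a cast when the shape |
| // operand is not i32. Schematically, for an i64 shape: |
| // |
| //   %0 = "tf.Reshape"(%t, %s) |
| //          : (tensor<4xf32>, tensor<2xi64>) -> tensor<2x2xf32> |
| //     ==> |
| //   %c = "tf.Cast"(%s) {Truncate = false} : (tensor<2xi64>) -> tensor<2xi32> |
| //   %0 = "tfl.reshape"(%t, %c) |
| //          : (tensor<4xf32>, tensor<2xi32>) -> tensor<2x2xf32> |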
| LogicalResult ConvertTFReshapeOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_reshape_op = cast<TF::ReshapeOp>(op); |
| |
| auto input = tf_reshape_op.tensor(); |
| auto shape = tf_reshape_op.shape(); |
| |
| ShapedType shape_type = shape.getType().cast<ShapedType>(); |
| // The second operand of tfl.reshape must be an i32 tensor, so we have to |
| // cast the shape if its element type is anything else. |
| if (!shape_type.getElementType().isSignlessInteger(32)) { |
| auto new_shape = shape_type.getShape(); |
| IntegerType new_ele_type = rewriter.getIntegerType(32); |
| ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type); |
| // Use TF::CastOp so the cast gets folded away when the shape input is a |
| // constant. |
| shape = rewriter |
| .create<TF::CastOp>(op->getLoc(), new_type, shape, |
| rewriter.getBoolAttr(false)) |
| .y(); |
| } |
| rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(), |
| input, shape); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFSplitOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_split_op = cast<TF::SplitOp>(op); |
| |
| auto output_types = functional::map([](Value v) { return v.getType(); }, |
| tf_split_op.output()); |
| // Number of splits cannot be negative. |
| auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split()); |
| |
| rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, output_types, |
| tf_split_op.split_dim(), |
| tf_split_op.value(), num_split); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFSplitVOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_splitv_op = cast<TF::SplitVOp>(op); |
| |
| auto output_types = functional::map([](Value v) { return v.getType(); }, |
| tf_splitv_op.output()); |
| // Number of splits cannot be negative. |
| auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split()); |
| |
| rewriter.replaceOpWithNewOp<TFL::SplitVOp>( |
| op, output_types, tf_splitv_op.value(), tf_splitv_op.size_splits(), |
| tf_splitv_op.split_dim(), num_split); |
| return success(); |
| } |
| |
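| // Pads a constant begin/end/strides attribute array of a strided slice up to |
| // the input rank. For example, with padding_val = [0, 0, 0] (a rank-3 input) |
| // and a constant begin attribute of [1], the returned constant is [1, 0, 0], |
| // and bits 1 and 2 are set in *mask so the padded dimensions take their |
| // mask-driven defaults. |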
| Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter, |
| Value attribute, |
| ArrayRef<int32_t> padding_val, int* mask) { |
| DenseIntElementsAttr dense_elem_attr; |
| SmallVector<int32_t, 8> padded_val; |
| |
| auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>(); |
| if (!ranked_attr_type || |
| !matchPattern(attribute, m_Constant(&dense_elem_attr))) { |
| // If the input attribute is not a ranked tensor type or not a constant, we |
| // can't do any padding. Instead we just return it as-is. |
| return attribute; |
| } |
| for (auto idx : dense_elem_attr.getIntValues()) { |
| padded_val.push_back(idx.getSExtValue()); |
| } |
| auto attr_dim_count = ranked_attr_type.getShape()[0]; |
| int full_dim_count = padding_val.size(); |
| for (int i = attr_dim_count; i < full_dim_count; ++i) { |
| padded_val.push_back(padding_val[i]); |
| if (mask) *mask |= 1 << i; |
| } |
| auto type = |
| RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32)); |
| auto attr = DenseElementsAttr::get<int32_t>(type, padded_val); |
| return rewriter.create<ConstantOp>(op->getLoc(), type, attr); |
| } |
| |
| LogicalResult ConvertTFStridedSliceOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op); |
| auto ranked_input_type = |
| tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>(); |
| if (!ranked_input_type) { |
| // If input is not a ranked tensor, we can't deduce the padding dimensions |
| // from it, so we just do a plain conversion here. |
| rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>( |
| op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(), |
| tf_strided_slice_op.begin(), tf_strided_slice_op.end(), |
| tf_strided_slice_op.strides(), |
| rewriter.getI32IntegerAttr( |
| tf_strided_slice_op.begin_mask().getSExtValue()), |
| rewriter.getI32IntegerAttr( |
| tf_strided_slice_op.end_mask().getSExtValue()), |
| rewriter.getI32IntegerAttr( |
| tf_strided_slice_op.ellipsis_mask().getSExtValue()), |
| rewriter.getI32IntegerAttr( |
| tf_strided_slice_op.new_axis_mask().getSExtValue()), |
| rewriter.getI32IntegerAttr( |
| tf_strided_slice_op.shrink_axis_mask().getSExtValue())); |
| return success(); |
| } |
| |
| int num_input_dims = ranked_input_type.getRank(); |
| // Pad `begin` array with zero values and update the `begin_mask`. |
| SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0); |
| int begin_mask = tf_strided_slice_op.begin_mask().getSExtValue(); |
| Value padded_begin = PadStridedSliceAttributeArray( |
| op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask); |
| // Pad `end` array with `input_shape` and update the `end_mask`. |
| int end_mask = tf_strided_slice_op.end_mask().getSExtValue(); |
| auto input_shape = ranked_input_type.getShape(); |
| SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end()); |
| Value padded_end = PadStridedSliceAttributeArray( |
| op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask); |
| // Pad `strides` array with ones. |
| SmallVector<int32_t, 8> strides_pad_val(num_input_dims, 1); |
| Value padded_strides = PadStridedSliceAttributeArray( |
| op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr); |
| rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>( |
| op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(), |
| padded_begin, padded_end, padded_strides, |
| rewriter.getI32IntegerAttr(begin_mask), |
| rewriter.getI32IntegerAttr(end_mask), |
| rewriter.getI32IntegerAttr( |
| tf_strided_slice_op.ellipsis_mask().getSExtValue()), |
| rewriter.getI32IntegerAttr( |
| tf_strided_slice_op.new_axis_mask().getSExtValue()), |
| rewriter.getI32IntegerAttr( |
| tf_strided_slice_op.shrink_axis_mask().getSExtValue())); |
| return success(); |
| } |
| |
| LogicalResult ConvertTFUnpackOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_unpack_op = cast<TF::UnpackOp>(op); |
| |
| auto input = tf_unpack_op.value(); |
| auto output_types = functional::map([](Value v) { return v.getType(); }, |
| tf_unpack_op.output()); |
| auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num()); |
| // Axis can be negative. |
| auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue()); |
| |
| rewriter.replaceOpWithNewOp<UnpackOp>(op, output_types, input, num, axis); |
| return success(); |
| } |
| |
| // MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. This attribute |
| // only has effects when processing multiple diagonals. Since TFLite converts |
| // MatrixDiagV{2,3} to MatrixDiag, which only takes single-diagonal inputs, we |
| // can safely ignore this V3 attribute. |
| // We can't pass `rewriter` by reference because clang-tidy will want it to be |
| // constant (`const PatternRewriter& rewriter`). If we do that, we won't be |
| // able to call `PatternRewriter::replaceOpWithNewOp`, which is not a const |
| // member function. |
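| // |
| // Schematically, when k = 0, num_rows = num_cols = -1 and padding_value is |
| // all zeros, the V2/V3 op reduces to the single-diagonal case: |
| // |
| //   %0 = "tf.MatrixDiagV3"(%diag, %k, %rows, %cols, %pad) |
| //          : (tensor<3xf32>, tensor<i32>, tensor<i32>, tensor<i32>, |
| //             tensor<f32>) -> tensor<3x3xf32> |
| //     ==> |
| //   %0 = "tfl.matrix_diag"(%diag) : (tensor<3xf32>) -> tensor<3x3xf32> |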
| template <typename MatrixDiagV2OrV3Op> |
| bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) { |
| auto tf_matrix_diag_v2_or_v3_op = cast<MatrixDiagV2OrV3Op>(op); |
| |
| if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false; |
| |
| auto input = tf_matrix_diag_v2_or_v3_op.diagonal(); |
| auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType(); |
| |
| // Extract k constant tensor and check value = 0. |
| ElementsAttr k; |
| if (!matchPattern(tf_matrix_diag_v2_or_v3_op.k(), m_Constant(&k))) |
| return false; |
| if (ExtractSingleElementAsInteger(k).getInt() != 0) return false; |
| |
| // Extract num_rows constant tensor and check value = -1. |
| ElementsAttr num_rows; |
| if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_rows(), |
| m_Constant(&num_rows))) |
| return false; |
| if (ExtractSingleElementAsInteger(num_rows).getInt() != -1) return false; |
| |
| // Extract num_cols constant tensor and check value = -1. |
| ElementsAttr num_cols; |
| if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_cols(), |
| m_Constant(&num_cols))) |
| return false; |
| if (ExtractSingleElementAsInteger(num_cols).getInt() != -1) return false; |
| |
| // Verify padding_value is an integer tensor with all 0s. |
| ElementsAttr padding_value; |
| if (!matchPattern(tf_matrix_diag_v2_or_v3_op.padding_value(), |
| m_Constant(&padding_value))) |
| return false; |
| for (auto value : padding_value.getValues<APInt>()) { |
| if (value != 0) return false; |
| } |
| |
| rewriter->replaceOpWithNewOp<MatrixDiagOp>(op, output_type, input); |
| return true; |
| } |
| |
| LogicalResult ConvertTFMatrixDiagV2Op::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV2Op>(op, &rewriter)) |
| return success(); |
| return failure(); |
| } |
| |
| LogicalResult ConvertTFMatrixDiagV3Op::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV3Op>(op, &rewriter)) |
| return success(); |
| return failure(); |
| } |
| |
| // TF Lite doesn't support Assert, so we simply drop the op from the graph. |
| LogicalResult ConvertTFAssertOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| |
| StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter, |
| Location loc, |
| ShapedType shaped_type, |
| int value) { |
| Type element_type = shaped_type.getElementType(); |
| ShapedType scalar_type = RankedTensorType::get({}, element_type); |
| Attribute attr; |
| switch (element_type.getKind()) { |
| case mlir::StandardTypes::F16: { |
| auto floatType = mlir::FloatType::getF16(element_type.getContext()); |
| auto floatAttr = |
| mlir::FloatAttr::get(floatType, static_cast<float>(value)); |
| std::vector<Attribute> floatValues({floatAttr}); |
| attr = DenseElementsAttr::get(scalar_type, floatValues); |
| break; |
| } |
| case mlir::StandardTypes::F32: { |
| attr = |
| DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value)); |
| break; |
| } |
| case mlir::StandardTypes::Complex: { |
| auto etype = element_type.cast<mlir::ComplexType>().getElementType(); |
| if (etype.isF32()) { |
| auto dialect = etype.getContext()->getRegisteredDialect("tf"); |
| tensorflow::TensorProto repr; |
| repr.set_dtype(tensorflow::DT_COMPLEX64); |
| |
| tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape(); |
| shape->set_unknown_rank(false); |
| shape->add_dim()->set_size(int64_t{1}); |
| std::string content; |
| auto complex_value = |
| std::complex<float>(static_cast<float>(value), 0.0f); |
| content.assign(reinterpret_cast<const char*>(&complex_value), |
| sizeof(complex_value)); |
| repr.set_tensor_content(content); |
| std::string mangled = tensorflow::mangling_util::MangleTensor(repr); |
| |
| attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled); |
| break; |
| } |
| return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type"); |
| } |
| case mlir::StandardTypes::Integer: { |
| const auto& itype = element_type.cast<mlir::IntegerType>(); |
| switch (itype.getWidth()) { |
| case 8: |
| attr = DenseElementsAttr::get<int8_t>(scalar_type, |
| static_cast<int8_t>(value)); |
| break; |
| case 16: |
| attr = DenseElementsAttr::get<int16_t>(scalar_type, |
| static_cast<int16_t>(value)); |
| break; |
| case 32: |
| attr = DenseElementsAttr::get<int32_t>(scalar_type, |
| static_cast<int32_t>(value)); |
| break; |
| case 64: |
| attr = DenseElementsAttr::get<int64_t>(scalar_type, |
| static_cast<int64_t>(value)); |
| break; |
| default: |
| return Status(tensorflow::error::INVALID_ARGUMENT, |
| "Unsupported type"); |
| } |
| break; |
| } |
| default: |
| return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type"); |
| } |
| return rewriter->create<ConstantOp>(loc, scalar_type, attr); |
| } |
| |
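| // Converts tf.Reciprocal into a broadcasting tfl.div by a scalar one. |
| // Schematically: |
| // |
| //   %0 = "tf.Reciprocal"(%x) : (tensor<8xf32>) -> tensor<8xf32> |
| //     ==> |
| //   %one = constant dense<1.0> : tensor<f32> |
| //   %0 = "tfl.div"(%one, %x) {fused_activation_function = "NONE"} |
| //          : (tensor<f32>, tensor<8xf32>) -> tensor<8xf32> |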
| LogicalResult ConvertTFReciprocalOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_reciprocal_op = cast<TF::ReciprocalOp>(op); |
| |
| auto status_or_const_op = CreateConstOpWithSingleValue( |
| &rewriter, op->getLoc(), |
| tf_reciprocal_op.x().getType().cast<ShapedType>(), 1); |
| if (!status_or_const_op.ok()) { |
| return failure(); |
| } |
| |
| StringAttr fused_activation_function = |
| StringAttr::get("NONE", rewriter.getContext()); |
| |
| rewriter.replaceOpWithNewOp<TFL::DivOp>(op, status_or_const_op.ValueOrDie(), |
| tf_reciprocal_op.x(), |
| fused_activation_function); |
| return success(); |
| } |
| |
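| // Converts tf.BroadcastTo into a fill-then-multiply, relying on tfl.mul's |
| // implicit broadcasting. Schematically: |
| // |
| //   %0 = "tf.BroadcastTo"(%in, %shape) |
| //          : (tensor<3xf32>, tensor<2xi32>) -> tensor<2x3xf32> |
| //     ==> |
| //   %one = constant dense<1.0> : tensor<f32> |
| //   %fill = "tfl.fill"(%shape, %one) |
| //             : (tensor<2xi32>, tensor<f32>) -> tensor<2x3xf32> |
| //   %0 = "tfl.mul"(%in, %fill) {fused_activation_function = "NONE"} |
| //          : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> |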
| LogicalResult ConvertTFBroadcastToOp::matchAndRewrite( |
| Operation* op, PatternRewriter& rewriter) const { |
| auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op); |
| auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>(); |
| auto output_type = tf_broadcast_to_op.output().getType(); |
| |
| auto status_or_const_op = |
| CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1); |
| if (!status_or_const_op.ok()) { |
| return failure(); |
| } |
| |
| auto tfl_fill_op = rewriter.create<TFL::FillOp>( |
| op->getLoc(), output_type, tf_broadcast_to_op.shape(), |
| status_or_const_op.ValueOrDie()); |
| |
| StringAttr fused_activation_function = |
| StringAttr::get("NONE", rewriter.getContext()); |
| |
| rewriter.replaceOpWithNewOp<TFL::MulOp>( |
| op, output_type, tf_broadcast_to_op.input(), tfl_fill_op, |
| fused_activation_function); |
| return success(); |
| } |
| |
| // Legalizes the ophint-converted unidirectional sequence LSTM op. |
| struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { |
| explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context) |
| : RewritePattern(kUnidirectionalSequenceLstm, 1, context) {} |
| |
| LogicalResult matchAndRewrite(Operation* op, |
| PatternRewriter& rewriter) const override { |
| auto tflite_indices_attr = |
| op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices); |
| if (!tflite_indices_attr) return failure(); |
| |
| SmallVector<int64_t, 20> tflite_indices; |
| for (auto index_attr : tflite_indices_attr.getValue()) { |
| IntegerAttr index = index_attr.cast<IntegerAttr>(); |
| tflite_indices.push_back(index.getInt()); |
| } |
| |
| // Optional input placeholder. |
| Value none = rewriter.create<mlir::ConstantOp>( |
| op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); |
| |
| // Populate inputs. |
| // UnidirectionalSequenceLstm is expected to have 24 inputs. |
| SmallVector<Value, 24> inputs; |
| int count = 0; |
| int total_ophint_converted_inputs = tflite_indices.size(); |
| for (int i = 0; i < 24; ++i) { |
| if (count < total_ophint_converted_inputs && tflite_indices[count] == i) { |
| // specified input. |
| inputs.push_back(op->getOperand(i)); |
| count++; |
| } else { |
| // Unspecified input; use the none placeholder. |
| inputs.push_back(none); |
| } |
| } |
| |
| // Populate outputs. |
| // UnidirectionalSequenceLstm should only have 1 output, and that is the |
| // original ophint converted node's 3rd output. |
| SmallVector<Type, 4> result_types; |
| result_types.push_back(op->getOpResult(2).getType()); |
| |
| // Populate attributes. |
| SmallVector<NamedAttribute, 4> attributes; |
| // Activation will always be tanh. |
| attributes.push_back(rewriter.getNamedAttr("fused_activation_function", |
| rewriter.getStringAttr("TANH"))); |
| // cell_clip. |
| attributes.push_back( |
| rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(0.0))); |
| // proj_clip. |
| attributes.push_back( |
| rewriter.getNamedAttr("proj_clip", rewriter.getF32FloatAttr(0.0))); |
| // Will always be time-major. |
| attributes.push_back( |
| rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true))); |
| |
| auto lstm_op = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>( |
| op->getLoc(), result_types, inputs, attributes); |
| |
| // Rewire the output. |
| op->getResult(2).replaceAllUsesWith(lstm_op.getResult()); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| // Legalizes the ophint-converted unidirectional sequence RNN op. |
| struct LegalizeUnidirectionalSequenceRnn : public RewritePattern { |
| explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context) |
| : RewritePattern(kUnidirectionalSequenceRnn, 1, context) {} |
| |
| LogicalResult matchAndRewrite(Operation* op, |
| PatternRewriter& rewriter) const override { |
| auto tflite_indices_attr = |
| op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices); |
| if (!tflite_indices_attr) return failure(); |
| |
| if (op->getNumOperands() != 5) { |
| op->emitError() |
| << "We're expecting 5 inputs for UnidirectionalSequenceRNN, only " |
| << op->getNumOperands() << " provided"; |
| return failure(); |
| } |
| |
| if (op->getNumResults() != 2) { |
| op->emitError() |
| << "We're expecting 2 inputs for UnidirectionalSequenceRNN, only " |
| << op->getNumResults() << " found"; |
| return failure(); |
| } |
| |
| // Populate inputs. |
| // UnidirectionalSequenceRnn is expected to have 5 inputs, and none of them |
| // are optional inputs. |
| SmallVector<Value, 5> inputs; |
| for (int i = 0; i < 5; ++i) { |
| inputs.push_back(op->getOperand(i)); |
| } |
| |
| // Populate outputs. |
| // UnidirectionalSequenceRnn should only have 1 output, and that is the |
| // original ophint converted node's 2nd output. |
| SmallVector<Type, 4> result_types; |
| result_types.push_back(op->getOpResult(1).getType()); |
| |
| // Populate attributes. |
| SmallVector<NamedAttribute, 2> attributes; |
| // Activation will always be tanh. |
| attributes.push_back(rewriter.getNamedAttr("fused_activation_function", |
| rewriter.getStringAttr("TANH"))); |
| |
| // Will always be time-major. |
| attributes.push_back( |
| rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true))); |
| |
| auto rnn_op = rewriter.create<TFL::UnidirectionalSequenceRNNOp>( |
| op->getLoc(), result_types, inputs, attributes); |
| |
| // Rewire the output. |
| op->getResult(1).replaceAllUsesWith(rnn_op.getResult()); |
| rewriter.eraseOp(op); |
| |
| return success(); |
| } |
| }; |
| |
| void LegalizeTF::runOnFunction() { |
| OwningRewritePatternList patterns; |
| auto* context = &getContext(); |
| auto func = getFunction(); |
| |
| // Add the generated patterns to the list. |
| populateWithGenerated(context, &patterns); |
| patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp, |
| ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op, |
| ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp, |
| ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp, |
| ConvertTFAssertOp, ConvertTFReciprocalOp, |
| ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context); |
| |
| // Patterns for TF nodes produced by the ophint Python converter. |
| patterns.insert<LegalizeUnidirectionalSequenceLstm, |
| LegalizeUnidirectionalSequenceRnn>(context); |
| |
| ConversionTarget target(*context); |
| // It is legal to still have TF ops in the graph: they can be used later, |
| // and in the SELECT case we allow TF ops in the final graph. |
| target.addLegalOp<mlir::ConstantOp>(); |
| target.addLegalOp<ConstOp>(); |
| if (run_tfl_runtime_verification_) { |
| target.addDynamicallyLegalDialect<TensorFlowLiteDialect>( |
| Optional<ConversionTarget::DynamicLegalityCallbackFn>( |
| [](Operation* op) { |
| auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op); |
| if (!tfl_op) return false; |
| return succeeded(tfl_op.VerifyTflRuntimeConstraints( |
| tfl_op.getOperation(), |
| /*failure_on_operand_type_mismatch=*/false)); |
| })); |
| } else { |
| target.addLegalDialect<TensorFlowLiteDialect>(); |
| } |
| // Keep trying to convert. |
| // TODO(karimnosseir): This is similar to what greedy pattern application |
| // does. Check whether there is a function that retries until it converges. |
| // Currently the unit tests don't do multiple tries, so we need this loop. |
| const int max_iterations = 15; |
| for (int i = 0; i < max_iterations; ++i) { |
| if (failed(applyPartialConversion(func, target, patterns))) { |
| return; |
| } |
| } |
| } |
| |
| } // namespace |
| |
| // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. |
| std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass( |
| bool run_tfl_runtime_verification) { |
| return std::make_unique<LegalizeTF>(run_tfl_runtime_verification); |
| } |
| |
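| // The pass registered below can be exercised in isolation through the usual |
| // opt-style driver, e.g. (assuming the tf-opt tool built in this repository): |
| // |
| //   tf-opt -tfl-legalize-tf input.mlir |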
| static PassRegistration<LegalizeTF> pass( |
| "tfl-legalize-tf", "Legalize from TensorFlow to TensorFlow Lite dialect"); |
| |
| } // namespace TFL |
| } // namespace mlir |