| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
// This transformation pass prepares for legalization to the TFLite dialect by
// converting operations in the TensorFlow dialect into operations that can be
// legalized to the TensorFlow Lite dialect with simple replacements. The newly
// created operations are in the TensorFlow dialect if the operation can be
// represented using a TensorFlow op; otherwise, a TensorFlow Lite dialect op
// is used. For example, TFLite's Conv2D, which uses the OHWI data format for
// filters, cannot be expressed in TensorFlow because TensorFlow requires
// filters in the HWIO data format.
| // |
// The motivation for preparing for TFLite legalization before the actual
// legalization is to exploit constant-folding opportunities in any newly
// created ops by leveraging the constant-folding support for TensorFlow ops.
// This way, TFLite can be used solely as a serialization format and
// optimizations do not require access to the TFLite runtime, which is a
// requirement from the TFLite team.
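//
// As an illustrative sketch (op syntax simplified), a tf.Conv2D with an HWIO
// filter is prepared for legalization by materializing a TensorFlow transpose
// that produces the OHWI filter layout expected by tfl.conv_2d:
//
//   %f = "tf.Const"() {...} : () -> tensor<3x3x8x16xf32>   // HWIO
//   %y = "tf.Conv2D"(%x, %f) {...}
//
// becomes
//
//   %perm = "tf.Const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>}
//   %f_t  = "tf.Transpose"(%f, %perm) : ... -> tensor<16x3x3x8xf32>  // OHWI
//   %bias = "tf.Const"() {value = dense<0.0> : tensor<16xf32>}
//   %y    = "tfl.conv_2d"(%x, %f_t, %bias) {...}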
| |
| #include <climits> |
| #include <cstdint> |
| |
| #include "absl/memory/memory.h" |
| #include "absl/numeric/bits.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/Debug.h" |
| #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project |
| #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project |
| #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "mlir/Transforms/DialectConversion.h" // from @llvm-project |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" |
| #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" |
| #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" |
| #include "tensorflow/compiler/mlir/lite/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" |
| #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" |
| #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h" |
| #include "tensorflow/compiler/mlir/lite/utils/validators.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h" |
| #include "tensorflow/compiler/mlir/xla/transforms/passes.h" |
| |
| #define DEBUG_TYPE "tf-tfl-legalization" |
| |
| namespace mlir { |
| namespace TFL { |
| namespace { |
| // Returns a TF_CastOp to I32. This function is used for CastOps that are |
| // intermediate nodes in a TableGen pattern result. In such a case, the |
| // destination type is not inferred and must be given explicitly. |
| // |
| // Preconditions: The given value must have a ShapedType. |
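//
// For example (an illustrative sketch of the produced IR): given a value %x of
// type tensor<2x3xi64>, this builds
//   %0 = "tf.Cast"(%x) {Truncate = false}
//          : (tensor<2x3xi64>) -> tensor<2x3xi32>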
| static Value CreateTFCastOpI32(OpBuilder *builder, Location loc, Value x, |
| BoolAttr truncate) { |
| auto x_type = x.getType().dyn_cast_or_null<ShapedType>(); |
| if (!x_type) llvm_unreachable("unsupported type"); |
| Type type = x_type.clone(builder->getI32Type()); |
| return builder->create<TF::CastOp>(loc, type, x, truncate); |
| } |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // The actual PrepareTF Pass. |
| // |
| // TODO(hinsu): Add and use TensorFlow dialect ops for the ops created in this |
| // pass. |
| namespace { |
| |
| // Prepare TF operations in functions for subsequent legalization. |
| class PrepareTFPass |
| : public PassWrapper<PrepareTFPass, OperationPass<func::FuncOp>> { |
| public: |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareTFPass) |
| |
| PrepareTFPass() = default; |
| PrepareTFPass(const PrepareTFPass &) {} |
| explicit PrepareTFPass(bool unfold_batch_matmul, |
| bool allow_bf16_and_f16_type_legalization, |
| bool use_fake_quant_num_bits = false) { |
| unfold_batch_matmul_ = unfold_batch_matmul; |
| allow_bf16_and_f16_type_legalization_ = |
| allow_bf16_and_f16_type_legalization; |
| use_fake_quant_num_bits_ = use_fake_quant_num_bits; |
| } |
| |
| StringRef getArgument() const final { |
    // This is the argument used to refer to the pass in
    // the textual format (on the command line, for example).
| return "tfl-prepare-tf"; |
| } |
| StringRef getDescription() const final { |
| // This is a brief description of the pass. |
| return "Prepare TF for legalization to TensorFlow Lite dialect"; |
| } |
| |
| void runOnOperation() override; |
| |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<mhlo::MhloDialect, quant::QuantizationDialect, |
| TFL::TensorFlowLiteDialect>(); |
| } |
| |
| private: |
| Option<bool> unfold_batch_matmul_{ |
| *this, "tfl-unfold-batch-matmul", |
| llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."), |
| llvm::cl::init(true)}; |
| |
| Option<bool> allow_bf16_and_f16_type_legalization_{ |
| *this, "tfl-allow-bf16-and-f16-type-legalization", |
| llvm::cl::desc("Allow bf16 type legalization."), llvm::cl::init(false)}; |
| |
| Option<bool> use_fake_quant_num_bits_{ |
| *this, "tfl-use-fake-quant-num-bits", |
| llvm::cl::desc("Use quantization calculated from fake quant attributes."), |
| llvm::cl::init(false)}; |
| }; |
| |
| // Transient state for preserving data from match to rewrite |
| struct ConvertTFConvOpMatchState { |
| IntegerAttr dilation_height_factor; |
| IntegerAttr dilation_width_factor; |
| StringAttr padding; |
| IntegerAttr stride_height; |
| IntegerAttr stride_width; |
| }; |
| |
| // Templated class for declaring a converter from some TensorFlow convolution |
| // op into its counterpart in TensorFlow Lite. |
| // |
| // The `ConcreteType` deriving from this template must provide the following |
| // method for constructing TensorFlow Lite op: |
| // |
| // TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state, |
| // PatternRewriter &rewriter, Location loc, |
| // Type result_type, Value input, |
| // Value filter, Value bias) const; |
| // |
| // And also the following method for getting the dimension for bias tensor: |
| // |
| // int64_t getBiasDim(ArrayRef<int64_t> filterShape) const; |
| template <typename ConcreteType, typename TFConvOpType> |
| class ConvertTFConvOp : public RewritePattern { |
| public: |
| ConvertTFConvOp(MLIRContext *context, |
| bool allow_bf16_and_f16_type_legalization) |
| : RewritePattern(TFConvOpType::getOperationName(), 1, context), |
| intAttrOne(Builder(context).getI32IntegerAttr(1)), |
| allow_bf16_and_f16_type_legalization_( |
| allow_bf16_and_f16_type_legalization) {} |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| // Assumes TensorFlow convolution op is already verified to be |
| // in valid form. |
| |
| // Match a TFConvOpType under the following conditions: |
| // * The 'T' attribute must exist and be of value DT_FLOAT. |
| // * The 'data_format' attribute must exist and be of value "NHWC". |
    // * The 'strides' attribute must exist and be of the form [1, X, Y, 1].
    // * The 'dilations' attribute is optional, but it must be of the form
    //   [1, X, Y, 1] if it exists.
| |
| TFConvOpType tf_op = cast<TFConvOpType>(op); |
| if (!TFTypeIsFloat32Tensor(tf_op.input()) && |
| !(allow_bf16_and_f16_type_legalization_ && |
| TFTypeIsBFloat16OrHalfTensor(tf_op.input()))) |
| return failure(); |
| |
| if (!TFDataFormatIsNHWC(op)) return failure(); |
| |
| IntegerAttr height, width; |
| if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure(); |
| |
| ConvertTFConvOpMatchState state; |
| state.stride_height = height; |
| state.stride_width = width; |
| |
| if (TFIntListIs1XY1(op, "dilations", &height, &width)) { |
| state.dilation_height_factor = height; |
| state.dilation_width_factor = width; |
| } else { |
| // If the 'dilations' attribute is missing, we use the default value (1) |
| // for both dilation height and width factor. |
| state.dilation_height_factor = intAttrOne; |
| state.dilation_width_factor = intAttrOne; |
| } |
| |
| TFPaddingIsSameOrValid(op, &state.padding); |
| |
| // Additionally, we require the filter operand to be of 4-D tensor type so |
| // that we can extract info from the shape (e.g., for constructing bias |
| // tensor, for setting depth_multiplier attribute, etc.). |
| auto filter = tf_op.filter(); |
| auto filter_type = filter.getType().template dyn_cast<RankedTensorType>(); |
| if (!filter_type || filter_type.getRank() != 4 || |
| !filter_type.hasStaticShape()) |
| return failure(); |
| |
| Value input = tf_op.input(); |
| RankedTensorType input_type = |
| input.getType().template dyn_cast<RankedTensorType>(); |
    // A rank-four input is guaranteed by the tf.Conv2D operator verification;
    // here we additionally require the channel dimension (dim 3) to be static.
| if (!input_type || input_type.isDynamicDim(3)) { |
| return failure(); |
| } |
| // Check if the given op is based on grouped convolution. |
| // Dim size zero will be verified by the tf.Conv2D operator verification. |
| if (input_type.getDimSize(3) % filter_type.getDimSize(2) != 0) { |
| return failure(); |
| } |
| |
    // The TensorFlow convolution op has only two inputs, while the TFLite one
    // has three, with the bias vector marked as optional. However, TOCO has a
    // dedicated pass, EnsureBiasVectors, to create default bias vectors for
    // all those missing. So we model the TFLite convolution op as requiring
    // three inputs to achieve the legalization task of EnsureBiasVectors. This
    // requires the filter tensor to have a static shape.
| |
| // TODO(antiagainst): also handle the case of tf.Add(tf.[op], <bias>) |
| |
    // Get a splat-zero tensor with the expected dimension for the bias tensor.
| auto elem_type = filter_type.getElementType(); |
| auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim( |
| filter_type.getShape()); |
| auto bias_type = RankedTensorType::get({bias_dim}, elem_type); |
| auto bias_attr = rewriter.getZeroAttr(bias_type); |
| auto bias = |
| rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr); |
| |
| if (op->getAttrOfType<StringAttr>("padding").getValue() == "EXPLICIT") { |
| // Add Const op for padding value. |
| ArrayRef<Attribute> padding_attr_array = |
| op->getAttrOfType<ArrayAttr>("explicit_paddings").getValue(); |
| |
| auto get_int = [](Attribute attr) { |
| return attr.template cast<IntegerAttr>().getInt(); |
| }; |
| |
| SmallVector<int32_t> padding_values(padding_attr_array.size()); |
| for (int i = 0; i < padding_attr_array.size(); i++) { |
| padding_values[i] = |
| static_cast<int32_t>(get_int(padding_attr_array[i])); |
| } |
| |
| RankedTensorType padding_attr_type = RankedTensorType::get( |
| {filter_type.getRank(), 2}, rewriter.getIntegerType(32)); |
| auto padding_attr = |
| mlir::DenseIntElementsAttr::get(padding_attr_type, padding_values); |
| |
| auto padding_const = |
| rewriter.create<TF::ConstOp>(op->getLoc(), padding_attr); |
| |
| // Add Pad op. |
| auto pad_output_type = UnrankedTensorType::get(elem_type); |
| input = rewriter.create<TF::PadOp>(op->getLoc(), pad_output_type, input, |
| padding_const); |
| |
| // Set Conv padding to `VALID` since padding has been handled by Pad op. |
| state.padding = rewriter.getStringAttr("VALID"); |
| } |
| auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp( |
| &state, rewriter, op->getLoc(), tf_op.getType(), input, filter, bias); |
| |
| rewriter.replaceOp(op, conv_op.getResult()); |
| return success(); |
| } |
| |
| const IntegerAttr intAttrOne; |
| |
| private: |
| bool allow_bf16_and_f16_type_legalization_; |
| }; |
| |
| class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> { |
| public: |
| using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>; |
| |
| ConvertTFConv2D(MLIRContext *context, bool allow_bf16_type_legalization) |
| : BaseType(context, allow_bf16_type_legalization) {} |
| |
| int64_t getBiasDim(ArrayRef<int64_t> filterShape) const { |
| return filterShape.back(); |
| } |
| |
| TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state, |
| PatternRewriter &rewriter, Location loc, |
| Type result_type, Value input, Value filter, |
| Value bias) const { |
| filter = legalizeFilter(rewriter, loc, filter); |
| return rewriter.create<TFL::Conv2DOp>( |
| loc, result_type, input, filter, bias, |
| /*dilation_h_factor=*/state->dilation_height_factor, |
| /*dilation_w_factor=*/state->dilation_width_factor, |
| /*fused_activation_function=*/rewriter.getStringAttr("NONE"), |
| /*padding=*/state->padding, |
| /*stride_h=*/state->stride_height, |
| /*stride_w=*/state->stride_width); |
| } |
| |
| private: |
  // Legalize the given filter by converting it from the TensorFlow filter data
  // format HWIO to the TFLite Conv2D op filter data format OHWI, and return
  // the Value for the converted filter. Requires that the match method has
  // verified the filter to be a 4-D RankedTensorType.
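  //
  // For example (illustrative): a filter of type tensor<3x3x8x16xf32> (HWIO),
  // transposed with permutation [3, 0, 1, 2], becomes tensor<16x3x3x8xf32>
  // (OHWI).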
| Value legalizeFilter(PatternRewriter &rewriter, Location loc, |
| Value filter) const { |
| // Create a constant op for HWIO to OHWI transpose permutation. |
| SmallVector<int, 4> perm = {3, 0, 1, 2}; |
| auto perm_type = RankedTensorType::get({static_cast<int>(perm.size())}, |
| rewriter.getIntegerType(32)); |
| auto perm_attr = |
| DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm)); |
| auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr); |
| |
| // Create tensor type for the transpose result. |
| auto filter_type = filter.getType().cast<RankedTensorType>(); |
| auto result_shape = |
| llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) { |
| return filter_type.getDimSize(dim); |
| })); |
| auto elem_type = filter_type.getElementType(); |
| auto result_type = RankedTensorType::get(result_shape, elem_type); |
| |
| return rewriter.create<TF::TransposeOp>(loc, result_type, filter, perm_op); |
| } |
| }; |
| |
| class ConvertTFDepthwiseConv2dNative |
| : public ConvertTFConvOp<ConvertTFDepthwiseConv2dNative, |
| TF::DepthwiseConv2dNativeOp> { |
| public: |
| using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative, |
| TF::DepthwiseConv2dNativeOp>; |
| |
| ConvertTFDepthwiseConv2dNative(MLIRContext *context, |
| bool allow_bf16_type_legalization) |
| : BaseType(context, allow_bf16_type_legalization) {} |
| |
| int64_t getBiasDim(ArrayRef<int64_t> filterShape) const { |
| return filterShape[2] * filterShape[3]; |
| } |
| |
| TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state, |
| PatternRewriter &rewriter, Location loc, |
| Type result_type, Value input, |
| Value filter, Value bias) const { |
| // Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional |
| // 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not |
| // have a corresponding 'depth_multiplier' attribute; the multiplier is the |
| // fourth dimension in the 4-D filter tensor. We query the multiplier from |
| // tf.DepthwiseConv2dNative and set it as the attribute value accordingly. |
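    //
    // For example (illustrative): a filter of type tensor<3x3x8x2xf32>, i.e.
    // [filter_height, filter_width, in_channels, channel_multiplier], gives
    // depth_multiplier = 2, and legalizeFilter below reshapes it to
    // tensor<1x3x3x16xf32>.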
| auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3); |
| |
| filter = legalizeFilter(rewriter, loc, filter); |
| return rewriter.create<TFL::DepthwiseConv2DOp>( |
| loc, result_type, input, filter, bias, |
| /*dilation_h_factor=*/state->dilation_height_factor, |
| /*dilation_w_factor=*/state->dilation_width_factor, |
| /*fused_activation_function=*/rewriter.getStringAttr("NONE"), |
| /*padding=*/state->padding, |
| /*stride_h=*/state->stride_height, |
| /*stride_w=*/state->stride_width, |
| /*depth_multiplier=*/rewriter.getI32IntegerAttr(multiplier)); |
| } |
| |
| private: |
  /// Legalize the given filter by converting it from the TensorFlow filter
  /// data format to the TFLite DepthwiseConv2D op filter data format, and
  /// return the Value for the converted filter. The TensorFlow filter data
  /// format is [filter_height, filter_width, in_channels, channel_multiplier]
  /// and the TFLite filter data format is [1, filter_height, filter_width,
  /// out_channels]. Requires that the match method has verified the filter to
  /// be a 4-D RankedTensorType.
| Value legalizeFilter(PatternRewriter &rewriter, Location loc, |
| Value filter) const { |
| auto filter_type = filter.getType().cast<RankedTensorType>(); |
| auto filterShape = filter_type.getShape(); |
| SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1], |
| filterShape[2] * filterShape[3]}; |
| auto elem_type = filter_type.getElementType(); |
| auto result_type = RankedTensorType::get(result_shape, elem_type); |
    // The TensorFlow Lite `Reshape` op currently only supports an int32 shape
    // tensor.
| auto shape_type = RankedTensorType::get({4}, rewriter.getIntegerType(32)); |
| SmallVector<Attribute, 4> result_shape_data(4); |
| for (int i = 0; i < 4; ++i) { |
| result_shape_data[i] = |
| rewriter.getI32IntegerAttr(static_cast<int32_t>(result_shape[i])); |
| } |
| auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data); |
| auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr); |
| |
| return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape); |
| } |
| }; |
| |
// StridedSlice can have complicated attributes like begin_mask, end_mask,
// ellipsis_mask, new_axis_mask, and shrink_axis_mask. These masks complicate
// the strided_slice computation logic; we can simplify the logic by inserting
// a reshape op to pad the inputs so the strided_slice is easier to handle.
//
// So the graph may look like below:
//   original_input -> strided_slice -> output
//      (transforms)
//   original_input -> reshape -> strided_slice -> output
//
// And the new shape is computed based on the masks.
//
// An example for new_axis_mask: say the new_axis_mask is 9, which represents
// [1 0 0 1] (bits 0 and 3 set). That means we're inserting two new axes at
// dims 0 & 3, so if the original shape is [2, 3], we reshape it into
// [1, 2, 3, 1].
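//
// As an illustrative IR sketch (attribute details elided), with
// new_axis_mask = 2 (one new axis at dim 1) on a tensor<2x3xf32> input,
//
//   %out = "tf.StridedSlice"(%in, ...) {new_axis_mask = 2 : i64}
//
// becomes
//
//   %shape = arith.constant dense<[2, 1, 3]> : tensor<3xi32>
//   %r = "tf.Reshape"(%in, %shape) : ... -> tensor<2x1x3xf32>
//   %out = "tf.StridedSlice"(%r, ...) {new_axis_mask = 0 : i64}
//
// with bit 1 of the old new_axis_mask also OR'ed into the revised begin_mask
// and end_mask.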
| struct ConvertTFStridedSlice : public RewritePattern { |
| explicit ConvertTFStridedSlice(MLIRContext *context) |
| : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {} |
| |
| LogicalResult RewriteNewAxisMask(Operation *op, |
| PatternRewriter &rewriter) const { |
| TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op); |
| uint64_t new_axis_mask = strided_slice_op.new_axis_mask(); |
| |
| if (strided_slice_op.ellipsis_mask() != 0) { |
      // The ellipsis mask should have been lowered away prior to invoking this
      // function.
| op->emitError() << "encountered a logical error"; |
| return failure(); |
| } |
| |
| // Insert a new reshape op. |
| Value original_input = strided_slice_op.input(); |
| RankedTensorType original_input_type = |
| original_input.getType().dyn_cast<RankedTensorType>(); |
| if (!original_input_type) { |
| return failure(); |
| } |
| |
| const ArrayRef<int64_t> &original_input_shape = |
| original_input_type.getShape(); |
| SmallVector<int64_t, 4> revised_shape; |
| int index = 0; |
| const int original_input_rank = original_input_shape.size(); |
| while (index < original_input_rank || new_axis_mask) { |
| if (new_axis_mask & 1) { |
| revised_shape.emplace_back(1); |
| } else { |
| revised_shape.emplace_back(original_input_shape[index++]); |
| } |
| new_axis_mask >>= 1; |
| } |
| |
| if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure(); |
| |
| const int dim_size = revised_shape.size(); |
| Location loc = strided_slice_op.getLoc(); |
| auto shape_type = |
| RankedTensorType::get({dim_size}, rewriter.getIntegerType(32)); |
| SmallVector<Attribute, 4> result_shape_data(dim_size); |
| for (int i = 0; i < dim_size; ++i) { |
| result_shape_data[i] = |
| rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i])); |
| } |
| |
| auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data); |
| auto shape = |
| rewriter.create<arith::ConstantOp>(loc, shape_type, shape_attr); |
| auto revised_output_type = RankedTensorType::get( |
| revised_shape, original_input_type.getElementType()); |
| TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>( |
| loc, revised_output_type, original_input, shape); |
| |
| // Replace the original strided_slice. |
| uint64_t revised_begin_mask = strided_slice_op.begin_mask(); |
| uint64_t revised_end_mask = strided_slice_op.end_mask(); |
| // Since we expand the dims, we need to apply them to the begin_mask & |
| // end_mask. |
| revised_begin_mask |= strided_slice_op.new_axis_mask(); |
| revised_end_mask |= strided_slice_op.new_axis_mask(); |
| |
| // Enforce operator precedence. |
| uint64_t revised_shrink_axis_mask = |
| strided_slice_op.shrink_axis_mask() & ~strided_slice_op.new_axis_mask(); |
| |
| auto attribute_type = rewriter.getIntegerType(64); |
| rewriter.replaceOpWithNewOp<TF::StridedSliceOp>( |
| op, strided_slice_op.getType(), reshape, strided_slice_op.begin(), |
| strided_slice_op.end(), strided_slice_op.strides(), |
| rewriter.getIntegerAttr(attribute_type, revised_begin_mask), |
| rewriter.getIntegerAttr(attribute_type, revised_end_mask), |
| rewriter.getIntegerAttr(attribute_type, |
| strided_slice_op.ellipsis_mask()), |
| rewriter.getI64IntegerAttr(0), |
| rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask)); |
| return success(); |
| } |
| |
| LogicalResult RewriteEllipsisMask(Operation *op, |
| PatternRewriter &rewriter) const { |
| TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op); |
| |
| uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask(); |
| uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask(); |
| uint64_t new_axis_mask = strided_slice_op.new_axis_mask(); |
| |
| // Enforce operator precedence. |
| shrink_axis_mask &= ~ellipsis_mask; |
| new_axis_mask &= ~ellipsis_mask; |
| |
| DenseIntElementsAttr begin_dense_elem_attr; |
| Value begin = strided_slice_op.begin(); |
| auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>(); |
| if (!begin_ranked_attr_type || |
| !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) { |
| return failure(); |
| } |
| |
| DenseIntElementsAttr end_dense_elem_attr; |
| Value end = strided_slice_op.end(); |
| auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>(); |
| if (!end_ranked_attr_type || |
| !matchPattern(end, m_Constant(&end_dense_elem_attr))) { |
| return failure(); |
| } |
| |
| DenseIntElementsAttr stride_dense_elem_attr; |
| Value stride = strided_slice_op.strides(); |
| auto stride_ranked_attr_type = |
| stride.getType().dyn_cast<RankedTensorType>(); |
| if (!stride_ranked_attr_type || |
| !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) { |
| return failure(); |
| } |
| |
| Value input = strided_slice_op.input(); |
| RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>(); |
| if (!input_type) { |
| return failure(); |
| } |
| const ArrayRef<int64_t> input_shape = input_type.getShape(); |
| |
| const int input_size = input_shape.size(); |
| |
| RankedTensorType begin_type = begin.getType().cast<RankedTensorType>(); |
| const ArrayRef<int64_t> begin_shape = begin_type.getShape(); |
| const int begin_dim = begin_shape.size(); |
| |
| if (begin_dim != 1) return failure(); |
| |
    // The ellipsis fill might exceed the current output shape because we are
    // also taking into account any to-be-inserted new axes.
| const int ellipsis_filled_dim_size = |
| input_size - begin_shape[0] + 1 + absl::popcount(new_axis_mask); |
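
    // For example (illustrative): with a rank-4 input, a begin tensor of shape
    // [3] (the ellipsis standing in for the remaining dimensions), and no new
    // axes, ellipsis_filled_dim_size = 4 - 3 + 1 + 0 = 2, i.e. the ellipsis
    // expands to two dimensions.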
| |
| int64_t begin_mask = strided_slice_op.begin_mask(); |
| int64_t end_mask = strided_slice_op.end_mask(); |
| int64_t revised_begin_mask = 0; |
| int64_t revised_end_mask = 0; |
| int64_t revised_shrink_axis_mask = 0; |
| int64_t revised_new_axis_mask = 0; |
| |
| SmallVector<int32_t, 4> padded_begin; |
| SmallVector<int32_t, 4> padded_end; |
| SmallVector<int32_t, 4> padded_stride; |
| |
| // Before the ellipsis. |
| int index = 0; |
| int new_index = 0; |
| while (((ellipsis_mask >> index) & 1) == 0) { |
| padded_begin.push_back(begin_dense_elem_attr.getValues<int32_t>()[index]); |
| padded_end.push_back(end_dense_elem_attr.getValues<int32_t>()[index]); |
| padded_stride.push_back( |
| stride_dense_elem_attr.getValues<int32_t>()[index]); |
| if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index); |
| if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index); |
| if ((shrink_axis_mask >> index) & 1) |
| revised_shrink_axis_mask |= (1 << new_index); |
| |
| if ((new_axis_mask >> index) & 1) |
| revised_new_axis_mask |= (1 << new_index); |
| |
| ++index; |
| ++new_index; |
| } |
| |
| // Ellipsis. |
| for (; new_index < index + ellipsis_filled_dim_size; ++new_index) { |
| revised_begin_mask |= (1 << new_index); |
| revised_end_mask |= (1 << new_index); |
| |
| // Mimic the begin/end/strides mask behavior. |
| padded_begin.push_back(0); |
| padded_end.push_back(0); |
| padded_stride.push_back(1); |
| } |
| |
| // Account for ellipsis mask. |
| ++index; |
| |
| // After the ellipsis. |
| for (; index < begin_shape[0];) { |
| padded_begin.push_back(begin_dense_elem_attr.getValues<int32_t>()[index]); |
| padded_end.push_back(end_dense_elem_attr.getValues<int32_t>()[index]); |
| padded_stride.push_back( |
| stride_dense_elem_attr.getValues<int32_t>()[index]); |
| |
| if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index); |
| if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index); |
| if ((shrink_axis_mask >> index) & 1) |
| revised_shrink_axis_mask |= (1 << new_index); |
| if ((new_axis_mask >> index) & 1) |
| revised_new_axis_mask |= (1 << new_index); |
| |
| ++index; |
| ++new_index; |
| } |
| |
| auto attribute_type = rewriter.getIntegerType(64); |
| |
| int full_dim_count = padded_begin.size(); |
| auto type = |
| RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32)); |
| |
| auto begin_attr = DenseElementsAttr::get<int32_t>(type, padded_begin); |
| auto begin_op = |
| rewriter.create<arith::ConstantOp>(op->getLoc(), type, begin_attr); |
| auto end_attr = DenseElementsAttr::get<int32_t>(type, padded_end); |
| auto end_op = |
| rewriter.create<arith::ConstantOp>(op->getLoc(), type, end_attr); |
| auto stride_attr = DenseElementsAttr::get<int32_t>(type, padded_stride); |
| auto stride_op = |
| rewriter.create<arith::ConstantOp>(op->getLoc(), type, stride_attr); |
| |
| rewriter.replaceOpWithNewOp<TF::StridedSliceOp>( |
| op, strided_slice_op.getType(), input, begin_op.getResult(), |
| end_op.getResult(), stride_op.getResult(), |
| rewriter.getIntegerAttr(attribute_type, revised_begin_mask), |
| rewriter.getIntegerAttr(attribute_type, revised_end_mask), |
| /*ellipsis_mask=*/rewriter.getI64IntegerAttr(0), |
| rewriter.getIntegerAttr(attribute_type, revised_new_axis_mask), |
| rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask)); |
| |
| return success(); |
| } |
| |
| void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr, |
| SmallVectorImpl<int32_t> &val, |
| SmallVectorImpl<int32_t> &padded_val, |
| ArrayRef<int32_t> padding_val, |
| int *mask) const { |
| for (const auto &idx : dense_elem_attr.getValues<APInt>()) { |
| val.push_back(idx.getSExtValue()); |
| padded_val.push_back(idx.getSExtValue()); |
| } |
| int attr_dim_count = val.size(); |
| int full_dim_count = padding_val.size(); |
| for (int i = attr_dim_count; i < full_dim_count; ++i) { |
| padded_val.push_back(padding_val[i]); |
| if (mask) *mask |= 1 << i; |
| } |
| } |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op); |
| |
| // Handle ellipsis mask. |
| if (strided_slice_op.ellipsis_mask() != 0) { |
| return RewriteEllipsisMask(strided_slice_op, rewriter); |
| } |
| |
| // Handle new axis mask. |
| if (strided_slice_op.new_axis_mask() != 0) { |
| return RewriteNewAxisMask(strided_slice_op, rewriter); |
| } |
| |
| auto ranked_input_type = |
| strided_slice_op.input().getType().dyn_cast<RankedTensorType>(); |
| if (!ranked_input_type) { |
| return failure(); |
| } |
| |
| auto begin_attr = strided_slice_op.begin(); |
| auto end_attr = strided_slice_op.end(); |
| auto strides_attr = strided_slice_op.strides(); |
| |
| auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>(); |
| auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>(); |
| auto strides_attr_type = |
| strides_attr.getType().dyn_cast<RankedTensorType>(); |
| |
| DenseIntElementsAttr begin_elem_attr; |
| DenseIntElementsAttr end_elem_attr; |
| DenseIntElementsAttr strides_elem_attr; |
| |
| if (!begin_attr_type || |
| !matchPattern(begin_attr, m_Constant(&begin_elem_attr))) { |
| return failure(); |
| } |
| if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) { |
| return failure(); |
| } |
| if (!strides_attr_type || |
| !matchPattern(strides_attr, m_Constant(&strides_elem_attr))) { |
| return failure(); |
| } |
| |
| SmallVector<int32_t, 4> begin, end, strides; |
| SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides; |
| |
| int num_input_dims = ranked_input_type.getRank(); |
| SmallVector<int32_t, 4> padding_begin(num_input_dims, 0); |
| auto input_shape = ranked_input_type.getShape(); |
| SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end()); |
| SmallVector<int32_t, 4> padding_strides(num_input_dims, 1); |
| |
| int begin_mask = strided_slice_op.begin_mask(); |
| int end_mask = strided_slice_op.end_mask(); |
| |
| PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin, |
| padding_begin, &begin_mask); |
| PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end, |
| &end_mask); |
| PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides, |
| padding_strides, nullptr); |
| |
| if (begin == padded_begin && end == padded_end && |
| strides == padded_strides && |
| begin_mask == strided_slice_op.begin_mask() && |
| end_mask == strided_slice_op.end_mask()) { |
| return failure(); |
| } |
| |
| auto begin_end_type = |
| RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32)); |
| auto new_begin_attr = rewriter.create<arith::ConstantOp>( |
| op->getLoc(), begin_end_type, |
| DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin)); |
| auto new_end_attr = rewriter.create<arith::ConstantOp>( |
| op->getLoc(), begin_end_type, |
| DenseElementsAttr::get<int32_t>(begin_end_type, padded_end)); |
| auto strides_type = |
| RankedTensorType::get({static_cast<long>(padded_strides.size())}, |
| rewriter.getIntegerType(32)); |
| auto new_strides_attr = rewriter.create<arith::ConstantOp>( |
| op->getLoc(), strides_type, |
| DenseElementsAttr::get<int32_t>(strides_type, padded_strides)); |
| |
| auto attribute_type = rewriter.getIntegerType(64); |
| rewriter.replaceOpWithNewOp<TF::StridedSliceOp>( |
| op, strided_slice_op.output().getType(), strided_slice_op.input(), |
| new_begin_attr, new_end_attr, new_strides_attr, |
| rewriter.getIntegerAttr(attribute_type, begin_mask), |
| rewriter.getIntegerAttr(attribute_type, end_mask), |
| rewriter.getIntegerAttr(attribute_type, |
| strided_slice_op.ellipsis_mask()), |
| rewriter.getIntegerAttr(attribute_type, |
| strided_slice_op.new_axis_mask()), |
| rewriter.getIntegerAttr(attribute_type, |
| strided_slice_op.shrink_axis_mask())); |
| |
| return success(); |
| } |
| }; |
| |
| struct ConvertTFBroadcastTo : public RewritePattern { |
| explicit ConvertTFBroadcastTo(MLIRContext *context) |
| : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {} |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op); |
| auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>(); |
| auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>(); |
| auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>(); |
| Type element_type = input_type.getElementType(); |
| |
    // Allow lowering when low-dimensional inputs are given and the element
    // type is bf16, f32, i16, or i32.
| if (!((output_type.hasRank() && output_type.getRank() <= 4) || |
| (shape_type.hasStaticShape() && shape_type.getRank() == 1 && |
| shape_type.getDimSize(0) <= 4))) |
| return failure(); |
| |
| if (!(element_type.isa<BFloat16Type, Float32Type>() || |
| element_type.isInteger(32) || element_type.isInteger(16))) |
| return failure(); |
| |
| auto status_or_const_op = |
| CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1); |
| if (!status_or_const_op.ok()) { |
| return failure(); |
| } |
| |
| auto tf_fill_op = rewriter.create<TF::FillOp>( |
| op->getLoc(), output_type, tf_broadcast_to_op.shape(), |
| status_or_const_op.ValueOrDie()); |
| |
| auto mul_op = rewriter.create<TF::MulOp>( |
| op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op); |
| rewriter.replaceOp(op, mul_op.getResult()); |
| return success(); |
| } |
| }; |
| |
// The pattern below is equivalent to the DRR rule that follows it. The checks
// depend on generated values, so we can't add the checks on intermediate
// values; ideally we should find equivalent checks that guarantee the
// resultant ops are valid. The extra conditions are the broadcasting
// conditions.
//
// The pattern lowers FusedBatchNormV3 to arithmetic ops.
// Specifically, it performs the following calculation:
//
//   (x - mean) * scale / sqrt(variance + epsilon) + offset
//
// Let multiplier = scale / sqrt(variance + epsilon); computing
//   (x - mean) * scale / sqrt(variance + epsilon) + offset
// then reduces to computing
//   (x * multiplier) + (offset - mean * multiplier).
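//
// Expanding the first expression makes the equivalence explicit:
//   (x - mean) * multiplier + offset
//     = x * multiplier - mean * multiplier + offset
//     = (x * multiplier) + (offset - mean * multiplier)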
| // |
| // def : Pattern< |
| // (TF_FusedBatchNormV3Op:$root |
| // $x, $scale, $offset, $mean, $variance, |
| // F32Attr:$epsilon, $exponential_avg_factor, |
| // $data_format, FalseBoolAttr:$is_training), |
| // [(TF_AddOp |
| // (TF_MulOp |
| // $x, |
| // (TF_MulOp:$multiplier |
| // $scale, |
| // (TF_RsqrtOp |
| // (TF_AddOp $variance, |
| // (TF_ConstOp $epsilon))))), |
| // (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), |
| // // We already guaranteed that the last five results have no use so it does |
| // // not matter what value we provide here for replacement. |
| // /*batch_mean=*/(replaceWithValue $x), |
| // /*batch_variance=*/(replaceWithValue $x), |
| // /*reserve_space_1=*/(replaceWithValue $x), |
| // /*reserve_space_2=*/(replaceWithValue $x), |
| // /*reserve_space_3=*/(replaceWithValue $x)], |
| // [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), |
| // (HasNoUseOf:$root__3), (HasNoUseOf:$root__4), |
| // (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>; |
| // |
// When is_training is set to true, the given variance and mean are not used.
// In the above calculation, they are replaced by new values. These new mean
// and variance are calculated as follows:
//   new_mean = mean(x, axis=[0, 1, 2])
//   new_variance = mean(squared_difference(x, new_mean), axis=[0, 1, 2])
//
// The DRR rule for the case where is_training equals true is as follows:
| // def : Pattern< |
| // (TF_FusedBatchNormV3Op:$root |
| // $x, $scale, $offset, $mean, $variance, |
| // F32Attr:$epsilon, $exponential_avg_factor, |
| // $data_format, FalseBoolAttr:$is_training), |
| // [(TF_AddOp |
| // (TF_MulOp |
| // $x, |
| // (TF_MulOp:$multiplier |
| // $scale, |
| // (TF_RsqrtOp |
| // (TF_AddOp |
| // (TF_MeanOp |
| // (TF_SquaredDifferenceOp $x, $new_mean), |
| // (TF_ConstOp [0,1,2])), |
| // (TF_ConstOp $epsilon))))), |
| // (TF_SubOp |
| // $offset, |
| // (TF_MulOp |
| // (TF_MeanOp $x, (TF_ConstOp [0,1,2])), |
| // $multiplier))), |
| // // We already guaranteed that the last five results have no use so it does |
| // // not matter what value we provide here for replacement. |
| // /*batch_mean=*/(replaceWithValue $x), |
| // /*batch_variance=*/(replaceWithValue $x), |
| // /*reserve_space_1=*/(replaceWithValue $x), |
| // /*reserve_space_2=*/(replaceWithValue $x), |
| // /*reserve_space_3=*/(replaceWithValue $x)], |
| // [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), |
| // (HasNoUseOf:$root__3), (HasNoUseOf:$root__4), |
| // (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>; |
| |
| struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { |
| explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context) |
| : ::mlir::RewritePattern( |
| "tf.FusedBatchNormV3", 1, context, |
| {"tf.Add", "tf.Const", "tf.Mul", "tf.Rsqrt", "tf.Sub"}) {} |
| |
| ::mlir::LogicalResult matchAndRewrite( |
| ::mlir::Operation *fused_batch_norm, |
| ::mlir::PatternRewriter &rewriter) const override { |
| // Variables for capturing values and attributes used for creating ops |
| Operation::operand_range mean(fused_batch_norm->getOperands()); |
| ::mlir::FloatAttr exponential_avg_factor; |
| ::mlir::TF::FusedBatchNormV3Op root; |
| Operation::operand_range offset(fused_batch_norm->getOperands()); |
| Operation::operand_range x(fused_batch_norm->getOperands()); |
| Operation::operand_range scale(fused_batch_norm->getOperands()); |
| Operation::operand_range variance(fused_batch_norm->getOperands()); |
| ::mlir::FloatAttr epsilon; |
| ::mlir::BoolAttr is_training; |
| |
| // Match |
| auto fused_batch_norm_op = |
| dyn_cast_or_null<::mlir::TF::FusedBatchNormV3Op>(fused_batch_norm); |
| root = fused_batch_norm_op; |
| x = fused_batch_norm_op.getODSOperands(0); |
| scale = fused_batch_norm_op.getODSOperands(1); |
| offset = fused_batch_norm_op.getODSOperands(2); |
| mean = fused_batch_norm_op.getODSOperands(3); |
| variance = fused_batch_norm_op.getODSOperands(4); |
| |
| ::mlir::Value mean_value = (*mean.begin()); |
| ::mlir::Value variance_value = (*variance.begin()); |
| |
| if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure(); |
| |
| { |
| epsilon = |
| fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>("epsilon"); |
| if (!epsilon) |
| epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f); |
| |
| if (!(((epsilon.isa<::mlir::FloatAttr>())) && |
| ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) { |
| return rewriter.notifyMatchFailure( |
| fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { |
| diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to " |
| "satisfy constraint: 32-bit float attribute"; |
| }); |
| } |
| } |
| { |
| exponential_avg_factor = |
| fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>( |
| "exponential_avg_factor"); |
| if (!exponential_avg_factor) |
| exponential_avg_factor = |
| rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f); |
| } |
| if (!TFDataFormatIsNHWC(fused_batch_norm_op) && |
| !TFDataFormatIsNDHWC(fused_batch_norm_op)) |
| return failure(); |
| |
| if (!(((*root.getODSResults(1).begin()).use_empty()))) { |
| return rewriter.notifyMatchFailure( |
| fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { |
| diag << "entities '' failed to satisfy constraint: has no use"; |
| }); |
| } |
| |
| if (!(((*root.getODSResults(2).begin()).use_empty()))) { |
| return rewriter.notifyMatchFailure( |
| fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { |
| diag << "entities '' failed to satisfy constraint: has no use"; |
| }); |
| } |
| |
| if (!(((*root.getODSResults(3).begin()).use_empty()))) { |
| return rewriter.notifyMatchFailure( |
| fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { |
| diag << "entities '' failed to satisfy constraint: has no use"; |
| }); |
| } |
| |
| if (!(((*root.getODSResults(4).begin()).use_empty()))) { |
| return rewriter.notifyMatchFailure( |
| fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { |
| diag << "entities '' failed to satisfy constraint: has no use"; |
| }); |
| } |
| |
| if (!(((*root.getODSResults(5).begin()).use_empty()))) { |
| return rewriter.notifyMatchFailure( |
| fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { |
| diag << "entities '' failed to satisfy constraint: has no use"; |
| }); |
| } |
| |
| is_training = |
| fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training"); |
| auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()}); |
| |
| // We need to make sure input and output shapes are compatible. |
| int64_t last_dim = -1; |
| { |
| auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) { |
| auto v_type = v.getType().dyn_cast_or_null<RankedTensorType>(); |
| if (!v_type) return true; |
| int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1); |
| if (v_last_dim == -1) return true; |
| if (last_dim != -1 && v_last_dim != last_dim) return false; |
| last_dim = v_last_dim; |
| return true; |
| }; |
| |
| if (!is_last_dim_compatible(*x.begin(), last_dim) || |
| !is_last_dim_compatible(*scale.begin(), last_dim) || |
| !is_last_dim_compatible(*offset.begin(), last_dim)) { |
| return rewriter.notifyMatchFailure( |
| fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { |
| diag << "Shapes of scale and offset should be 1D and " |
| "compatible with x"; |
| }); |
| } |
| |
| if (!is_training.getValue()) { |
| if (!is_last_dim_compatible(mean_value, last_dim) || |
| !is_last_dim_compatible(variance_value, last_dim)) { |
| return rewriter.notifyMatchFailure( |
| fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { |
| diag << "Shapes of mean and variance should be 1D and " |
| "compatible with x"; |
| }); |
| } |
| } |
| |
| // Check if output shape and input shape are compatible. |
| auto x_type = (*x.begin()).getType(); |
| auto y_type = (*root.getODSResults(0).begin()).getType(); |
| if (!OpTrait::util::getBroadcastedType(x_type, y_type)) { |
| return rewriter.notifyMatchFailure( |
| fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { |
| diag << "Shapes of x and the first output should be compatible"; |
| }); |
| } |
| } |
| |
    // For training, mean and variance are calculated from the input values.
| if (is_training.getValue()) { |
| auto input_type = fused_batch_norm_op.x() |
| .getType() |
| .dyn_cast_or_null<RankedTensorType>(); |
| if (!input_type || input_type.getRank() != 4) { |
| return rewriter.notifyMatchFailure( |
| fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { |
| diag << "op 'tf.FusedBatchNormV3' that has 'is_training' equals " |
| "True is only supported with input of rank 4"; |
| }); |
| } |
| |
| ::mlir::TF::ConstOp reduce_dim_op; |
| { |
| auto reduce_dim_type = |
| ::mlir::RankedTensorType::get({3}, rewriter.getIntegerType(32)); |
| ::mlir::SmallVector<int32_t, 3> reduce_dim_values = {0, 1, 2}; |
| reduce_dim_op = rewriter.create<TF::ConstOp>( |
| odsLoc, ::mlir::DenseIntElementsAttr::get(reduce_dim_type, |
| reduce_dim_values)); |
| } |
| |
| auto new_mean_type = |
| ::mlir::RankedTensorType::get({last_dim}, rewriter.getF32Type()); |
| ::mlir::TF::MeanOp mean_op_1; |
| { |
| ::mlir::Value x_value = (*x.begin()); |
| mean_op_1 = rewriter.create<TF::MeanOp>( |
| odsLoc, new_mean_type, x_value, reduce_dim_op, |
| /*keep_dims=*/rewriter.getBoolAttr(false)); |
| } |
| |
| ::mlir::TF::SquaredDifferenceOp square_diff_op; |
| { |
| ::mlir::Value tblgen_value_0 = (*x.begin()); |
| ::mlir::Value tblgen_value_1 = (*mean_op_1.getODSResults(0).begin()); |
        // If x has shape [b, h, w, c], the result of mean_op_1 will have
        // shape [c]. Therefore, their shapes are always compatible.
| square_diff_op = rewriter.create<::mlir::TF::SquaredDifferenceOp>( |
| odsLoc, tblgen_value_0, tblgen_value_1); |
| } |
| |
| ::mlir::TF::MeanOp mean_op_2; |
| { |
| ::mlir::Value input_value = (*square_diff_op.getODSResults(0).begin()); |
| mean_op_2 = rewriter.create<TF::MeanOp>( |
| odsLoc, new_mean_type, input_value, reduce_dim_op, |
| /*keep_dims=*/rewriter.getBoolAttr(false)); |
| } |
| |
| mean_value = (*mean_op_1.getODSResults(0).begin()); |
| variance_value = (*mean_op_2.getODSResults(0).begin()); |
| } // End is_training equals true if. |
| |
| ::llvm::SmallVector<::mlir::Value, 4> replace_values; |
| ::mlir::TF::ConstOp epsilon_const_op; |
| { |
| epsilon_const_op = |
| rewriter.create<::mlir::TF::ConstOp>(odsLoc, |
| /*value=*/epsilon); |
| } |
| ::mlir::TF::AddOp add_op_1; |
| { |
| ::mlir::Value epsilon_value = |
| (*epsilon_const_op.getODSResults(0).begin()); |
      // Adding a constant; no need to check broadcastability.
| add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc, |
| /*x=*/variance_value, |
| /*y=*/epsilon_value); |
| } |
| ::mlir::TF::RsqrtOp rsqrt_op; |
| { |
| ::mlir::SmallVector<::mlir::Value, 4> tblgen_values; |
| ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; |
| tblgen_values.push_back((*add_op_1.getODSResults(0).begin())); |
| rsqrt_op = rewriter.create<::mlir::TF::RsqrtOp>(odsLoc, tblgen_values, |
| tblgen_attrs); |
| } |
| ::mlir::TF::MulOp multiplier; |
| { |
| ::mlir::Value tblgen_value_0 = (*scale.begin()); |
| ::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin()); |
| multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc, |
| /*x=*/tblgen_value_0, |
| /*y=*/tblgen_value_1); |
| } |
| ::mlir::TF::MulOp mul_op_1; |
| { |
| ::mlir::Value tblgen_value_0 = (*x.begin()); |
| ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin()); |
| mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc, |
| /*x=*/tblgen_value_0, |
| /*y=*/tblgen_value_1); |
| } |
| ::mlir::TF::MulOp mul_op_2; |
| { |
| ::mlir::Value multiplier_value = (*multiplier.getODSResults(0).begin()); |
| mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc, |
| /*x=*/mean_value, |
| /*y=*/multiplier_value); |
| } |
| ::mlir::TF::SubOp sub_op; |
| { |
| ::mlir::Value tblgen_value_0 = (*offset.begin()); |
| ::mlir::Value tblgen_value_1 = (*mul_op_2.getODSResults(0).begin()); |
| sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc, |
| /*x=*/tblgen_value_0, |
| /*y=*/tblgen_value_1); |
| } |
| ::mlir::TF::AddOp add_op_2; |
| { |
| ::mlir::SmallVector<::mlir::Value, 4> tblgen_values; |
| ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; |
| tblgen_values.push_back((*mul_op_1.getODSResults(0).begin())); |
| tblgen_values.push_back((*sub_op.getODSResults(0).begin())); |
| ::mlir::SmallVector<::mlir::Type, 4> tblgen_types; |
| for (auto v : fused_batch_norm_op.getODSResults(0)) { |
| tblgen_types.push_back(v.getType()); |
| } |
| add_op_2 = rewriter.create<::mlir::TF::AddOp>( |
| odsLoc, tblgen_types, tblgen_values, tblgen_attrs); |
| } |
| for (auto v : |
| ::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) { |
| replace_values.push_back(v); |
| } |
    // The remaining five results (batch_mean, batch_variance, and
    // reserve_space_1/2/3) are guaranteed to have no uses, so any value works
    // as their replacement; reuse x for each of them.
    ::mlir::Value x_value = (*x.begin());
    for (int i = 0; i < 5; ++i) {
      replace_values.push_back(x_value);
    }
| rewriter.replaceOp(fused_batch_norm, replace_values); |
| return success(); |
| }; |
| }; |
| |
| #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc" |
| |
| // Returns success if all the operations in the `op`'s regions including `op` |
| // itself are legal in a TFLite pipeline. |
| LogicalResult ValidateOp(Operation *op) { |
| bool has_illegal_ops = false; |
| op->walk([&](Operation *op) { |
| if (isa<TF::VariableV2Op>(op)) { |
| has_illegal_ops = true; |
| op->emitOpError() << "is illegal in a TFLite pipeline"; |
| } |
| }); |
| |
| return failure(has_illegal_ops); |
| } |
| |
// Converts a set of TF2XLA ops into pure TF ops for future legalizations, as
// TF2XLA ops aren't supported by later stages.
| LogicalResult ConvertTf2XlaOps(func::FuncOp func, MLIRContext *context) { |
| ConversionTarget target(*context); |
| target.addLegalDialect<arith::ArithmeticDialect>(); |
| target.addLegalDialect<func::FuncDialect>(); |
| target.addLegalDialect<TF::TensorFlowDialect>(); |
| target.addLegalOp<ModuleOp>(); |
| target.addLegalOp<func::FuncOp>(); |
| target.addIllegalOp<TF::XlaConvOp>(); |
| target.addIllegalOp<TF::XlaGatherOp>(); |
| |
| RewritePatternSet patterns(context); |
| mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns, context); |
| mhlo::PopulateLegalizeTfPatterns(context, &patterns); |
| TF::PopulateLegalizeHloToTfPatterns(&patterns, context); |
| mhlo::GatherOp::getCanonicalizationPatterns(patterns, context); |
| |
| return applyPartialConversion(func, target, std::move(patterns)); |
| } |
| |
| // Convert rfft to rfft2d. |
| // The transformation pattern looks like below: |
| // |
| // input fft_len |
| // \ / |
| // rfft |
| // |
| // || |
| // \/ |
| // |
| // input fft_len |
| // \ / |
| // expand_dim concat with [1] at the front |
| // \ / |
| // rfft_2d |
| // | |
| // squeeze |
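//
// As an illustrative example of the shapes involved: an RFFT over
// tensor<8x128xf32> with fft_length = [64] yields tensor<8x33xcomplex<f32>>
// (33 = 64 / 2 + 1). The rewrite expands the input to tensor<8x1x128xf32>,
// concatenates fft_length into [1, 64], computes an RFFT2D producing
// tensor<8x1x33xcomplex<f32>>, and squeezes axis -2 to recover
// tensor<8x33xcomplex<f32>>.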
| struct ConvertRfftToRfft2d : public RewritePattern { |
| explicit ConvertRfftToRfft2d(MLIRContext *context) |
| : RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {} |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| auto rfft_op = dyn_cast<TF::RFFTOp>(op); |
| |
| auto input = rfft_op.input(); |
| auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>(); |
| if (!input_type) return failure(); |
| auto fft_len = rfft_op.fft_length(); |
| auto fft_len_type = fft_len.getType().dyn_cast_or_null<ShapedType>(); |
| if (!fft_len_type) return failure(); |
| |
| auto output_type = |
| rfft_op.getResult().getType().dyn_cast_or_null<RankedTensorType>(); |
| if (!output_type) return failure(); |
| |
| // Expanded inputs. |
| // Insert at -2 location. |
| auto one_ele_type = |
| mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32)); |
| auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(), |
| one_ele_type, -2); |
| |
| SmallVector<int64_t, 4> expanded_input_shape; |
| SmallVector<int64_t, 4> expanded_output_shape; |
| int expanded_rank = input_type.getRank() + 1; |
| int r = 0; |
| for (int i = 0; i < expanded_rank; ++i) { |
| if (i == expanded_rank - 2) { |
| expanded_input_shape.push_back(1); |
| expanded_output_shape.push_back(1); |
| } else { |
| expanded_input_shape.push_back(input_type.getDimSize(r)); |
| expanded_output_shape.push_back(output_type.getDimSize(r)); |
| r++; |
| } |
| } |
| |
    auto expanded_input_type = mlir::RankedTensorType::get(
        expanded_input_shape, input_type.getElementType());
    TF::ExpandDimsOp expanded_input = rewriter.create<TF::ExpandDimsOp>(
        rfft_op.getLoc(), expanded_input_type, input, minus_two->getResult());
| |
| // Expanded fft_len. |
| auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1}); |
| |
| auto one = rewriter.create<TF::ConstOp>(rfft_op.getLoc(), one_attr); |
| |
| auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(), |
| one_ele_type, 0); |
| |
| auto expanded_fft_len_type = |
| mlir::RankedTensorType::get({2}, fft_len_type.getElementType()); |
| |
| TF::ConcatV2Op expanded_fft_len = rewriter.create<TF::ConcatV2Op>( |
| rfft_op.getLoc(), expanded_fft_len_type, |
| SmallVector<Value, 2>({one.getResult(), fft_len}), zero->getResult()); |
| |
| // Insert the rfft_2d. |
| auto rfft2d_out_type = mlir::RankedTensorType::get( |
| expanded_output_shape, output_type.getElementType()); |
| TF::RFFT2DOp rfft2d = rewriter.create<TF::RFFT2DOp>( |
| rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(), |
| expanded_fft_len.getResult()); |
| |
| // Insert the squeeze op. |
| auto squeeze_dim = rewriter.getI64ArrayAttr({-2}); |
| TF::SqueezeOp squeeze = rewriter.create<TF::SqueezeOp>( |
| rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim); |
| |
| rewriter.replaceOp(op, squeeze.getResult()); |
| |
| return success(); |
| } |
| }; |
| |
// Replaces the Identity op with its input in either of the following
// scenarios: 1) the Identity op's input and output have the same types/shapes,
// or 2) the result of the Identity op is only used by TF ops.
| struct RemoveIdentity : public OpRewritePattern<TF::IdentityOp> { |
| using OpRewritePattern<TF::IdentityOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(TF::IdentityOp identity, |
| PatternRewriter &rewriter) const override { |
| // Replace the op with the input if input and result have the same type. |
| if (identity.input().getType() == identity.getType()) { |
| rewriter.replaceOp(identity, identity.input()); |
| return success(); |
| } |
    // Replace the op with the input if the output is only used by TF ops.
    // Currently this is more on the conservative side since we need to ensure
    // every consumer op is a TF op before applying this pattern. We can
    // consider revisiting this in the future if it turns out to be too
    // restrictive.
| for (Operation *user : identity->getUsers()) { |
| if (user->getDialect()->getNamespace() != "tf") { |
| return failure(); |
| } |
| } |
| |
| rewriter.replaceOp(identity, identity.input()); |
| return success(); |
| } |
| }; |
| |
| void PrepareTFPass::runOnOperation() { |
| MLIRContext *ctx = &getContext(); |
| RewritePatternSet patterns(ctx); |
| RewritePatternSet phase_2_patterns(ctx); |
| auto func = getOperation(); |
| |
  // Check for illegal ops in a TFLite pipeline (e.g. training-only ops), since
  // PrepareTFPass is the very first TFLite pass in the pipeline.
  // TODO(jingpu): It might be better to split this check into its own pass
  // to make things more modular.
| if (failed(ValidateOp(func))) { |
| func.emitError() << "tfl-prepare-tf pass failed."; |
| signalPassFailure(); |
| return; |
| } |
| |
| if (failed(ConvertTf2XlaOps(func, ctx))) { |
| signalPassFailure(); |
| return; |
| } |
| |
  // This pattern will try to identify and optimize for dilated convolution,
  // e.g. patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be
  // replaced with a single Conv op with a dilation parameter.
| patterns.add<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat, |
| ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx); |
| |
| patterns.add<RemoveIdentity>(ctx); |
| TFL::populateWithGenerated(patterns); |
  // TODO(karimnosseir): Split into a separate pass, probably after
  // deciding on a long-term plan for this optimization.
  // This will allow optimizing any TF_Mul->TF_Conv in the graph,
  // including those expanded from FusedBatchNorm. We need to do this
  // before converting TF_Conv to TFL_Conv.
| (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); |
| |
  // Remove the wrapper of the tf.FakeQuant* ops and also insert the
  // tfl.quantize and tfl.dequantize ops to preserve the quantization
  // parameters. This is done after the first round of optimization to make
  // sure all the min/max operands of the tf.FakeQuant* ops are constants to be
  // matched. The following round of optimization will fold the unwrapped
  // tf.FakeQuant* ops with the weight constants.
| if (failed(ConvertFakeQuantOps(func, ctx, use_fake_quant_num_bits_))) { |
| signalPassFailure(); |
| return; |
| } |
| |
  // Load the generated patterns again, so the new quantization pass-through
  // will be applied.
| TFL::populateWithGenerated(phase_2_patterns); |
| if (unfold_batch_matmul_) { |
| TF::PopulateUnrollTfBatchMatMul(ctx, phase_2_patterns); |
| } |
| phase_2_patterns |
| .add<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFStridedSlice, |
| ConvertRfftToRfft2d, RemoveIdentity>(ctx); |
| phase_2_patterns.add<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>( |
| ctx, allow_bf16_and_f16_type_legalization_); |
| |
| (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns)); |
| } |
| |
| } // namespace |
| |
| // Creates an instance of the TensorFlow Lite dialect PrepareTF pass. |
| std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareTFPass( |
| bool unfold_batch_matmul, bool allow_bf16_and_f16_type_legalization, |
| bool use_fake_quant_num_bits) { |
| return std::make_unique<PrepareTFPass>(unfold_batch_matmul, |
| allow_bf16_and_f16_type_legalization, |
| use_fake_quant_num_bits); |
| } |
| |
| static PassRegistration<PrepareTFPass> pass; |
| |
| } // namespace TFL |
| } // namespace mlir |