| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" |
| |
| #include <algorithm> |
| #include <cstdint> |
| #include <functional> |
| #include <iterator> |
| #include <limits> |
| #include <numeric> |
| #include <string> |
| #include <tuple> |
| #include <type_traits> |
| |
| #include "llvm/ADT/APFloat.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/Optional.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/Sequence.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include "llvm/ADT/iterator_range.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/Dialect/Traits.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Diagnostics.h" // from @llvm-project |
| #include "mlir/IR/DialectImplementation.h" // from @llvm-project |
| #include "mlir/IR/Location.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/Matchers.h" // from @llvm-project |
| #include "mlir/IR/OpDefinition.h" // from @llvm-project |
| #include "mlir/IR/OpImplementation.h" // from @llvm-project |
| #include "mlir/IR/PatternMatch.h" // from @llvm-project |
| #include "mlir/IR/TypeUtilities.h" // from @llvm-project |
| #include "mlir/IR/Types.h" // from @llvm-project |
| #include "mlir/IR/Value.h" // from @llvm-project |
| #include "mlir/Parser/Parser.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "mlir/Transforms/InliningUtils.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" |
| #include "tensorflow/core/framework/kernel_shape_util.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/util/padding.h" |
| #include "tensorflow/core/util/tensor_format.h" |
| |
| namespace mlir { |
| namespace TF { |
| |
| namespace { |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // AddOp |
| //===----------------------------------------------------------------------===// |
| |
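| // The generated AddToAddV2 pattern rewrites tf.Add to tf.AddV2 (the two ops |
| // agree except that only tf.Add accepts strings, which the pattern is |
| // expected to leave untouched). A rough sketch on illustrative IR: |
| // |
| //   %0 = "tf.Add"(%x, %y) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> |
| // |
| // becomes |
| // |
| //   %0 = "tf.AddV2"(%x, %y) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> |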
| void AddOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<AddToAddV2>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AddNOp |
| //===----------------------------------------------------------------------===// |
| |
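| // Folds tf.AddN when at most one operand is known non-zero. An illustrative |
| // sketch (assuming %zero is a splat zero constant of matching type): |
| // |
| //   %0 = "tf.AddN"(%x, %zero, %zero) |
| // |
| // folds to %x when its type equals the result type, and an all-zero AddN |
| // with a fully static result shape folds to a splat zero constant. |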
| OpFoldResult AddNOp::fold(ArrayRef<Attribute> operands) { |
| if (operands.size() == 1) return *inputs().begin(); |
| |
| // Fold if there is only a single non-zero operand or all operands are zero. |
| int non_zero_index = -1; |
| auto IsKnownZero = [](Attribute attr) { |
| if (!attr) return false; |
| auto splat = attr.dyn_cast<SplatElementsAttr>(); |
| if (!splat) return false; |
| Type element_ty = splat.getType().getElementType(); |
| if (element_ty.isa<FloatType>()) |
| return splat.getSplatValue<llvm::APFloat>().isZero(); |
| if (element_ty.isa<IntegerType>()) |
| return splat.getSplatValue<llvm::APInt>().getSExtValue() == 0; |
| return false; |
| }; |
| |
| for (auto it : llvm::enumerate(operands)) { |
| if (IsKnownZero(it.value())) continue; |
| if (non_zero_index != -1) { |
| // Don't fold if we find more than 1 non-zero operand. |
| return {}; |
| } |
| non_zero_index = it.index(); |
| } |
| |
| // Only fold when the result shape is fully static. |
| auto result_ty = getType().dyn_cast<ShapedType>(); |
| if (!result_ty || !result_ty.hasStaticShape()) return {}; |
| |
| if (non_zero_index == -1) { |
| return SplatElementsAttr::get( |
| result_ty, |
| operands.begin()->cast<DenseElementsAttr>().getSplatValue<Attribute>()); |
| } |
| |
| // Check that the non-zero operand's shape matches the result shape. |
| if (result_ty == inputs()[non_zero_index].getType()) |
| return inputs()[non_zero_index]; |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AddV2Op |
| //===----------------------------------------------------------------------===// |
| |
| void AddV2Op::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<AddV2OfNegLeft, AddV2OfNegRight>(context); |
| } |
| |
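| // Folding is delegated to IdentityArithmeticOpFolder, which, e.g., elides an |
| // addition with a zero identity operand: "tf.AddV2"(%x, %zeros) folds to %x |
| // when the types agree (illustrative sketch, not from a test). |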
| OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) { |
| return IdentityArithmeticOpFolder<AddV2Op>(*this, operands); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AllOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult AllOp::verify() { |
| AllOp op = *this; |
| return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), |
| op.getLoc()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AnyOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult AnyOp::verify() { |
| AnyOp op = *this; |
| return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), |
| op.getLoc()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AssertOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| // Removes Assert with constant true predicate. |
| struct AssertWithTrue : public OpRewritePattern<AssertOp> { |
| using OpRewritePattern<AssertOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(AssertOp op, |
| PatternRewriter &rewriter) const override { |
| ElementsAttr cst; |
| if (matchPattern(op.condition(), m_Constant(&cst))) { |
| if (cst.getValues<bool>()[0]) { |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| } |
| return failure(); |
| } |
| }; |
| } // namespace |
| |
| void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<AssertWithTrue>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BatchMatMulV2Op & BatchMatMulOp |
| //===----------------------------------------------------------------------===// |
| |
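| // Illustrative example of the batch-dimension rules checked below: for |
| // tf.BatchMatMulV2, lhs tensor<4x1x3x2xf32> and rhs tensor<5x2x6xf32> are |
| // compatible because batch shapes [4, 1] and [5] broadcast to [4, 5], while |
| // tf.BatchMatMul requires the (known) batch dimensions to match exactly. |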
| template <typename OpT, |
| typename std::enable_if<llvm::is_one_of< |
| OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type * = nullptr> |
| static LogicalResult Verify(OpT op) { |
| if (!HasRankAtLeast(op.x(), 2)) { |
| return op.emitOpError("requires lhs operand to have rank at least two"); |
| } |
| if (!HasRankAtLeast(op.y(), 2)) { |
| return op.emitOpError("requires rhs operand to have rank at least two"); |
| } |
| |
| RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x()); |
| RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y()); |
| |
| if (!x_ty || !y_ty) return success(); |
| |
| ArrayRef<int64_t> x_shape = x_ty.getShape(); |
| ArrayRef<int64_t> y_shape = y_ty.getShape(); |
| |
| llvm::SmallVector<int64_t, 4> result_batch_shape; |
| llvm::ArrayRef<int64_t> x_batches = x_shape.drop_back(2); |
| llvm::ArrayRef<int64_t> y_batches = y_shape.drop_back(2); |
| |
| // Check compatibility of batch dimensions if both input shapes are known. |
| // BatchMatMul should have exactly the same batch dimensions and |
| // BatchMatMulV2 should have broadcastable batch dimensions. |
| // |
| // The last two dimensions are non-batch dimensions that don't need to |
| // participate in batch dimension compatibility check. |
| if (std::is_same<OpT, BatchMatMulOp>()) { |
| for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) { |
| int64_t x_dim = std::get<0>(dim_pairs); |
| int64_t y_dim = std::get<1>(dim_pairs); |
| if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) && |
| x_dim != y_dim) { |
| return op.emitOpError() |
| << "found mismatching batch dimensions for lhs shape " << x_ty |
| << " and rhs shape " << y_ty; |
| } |
| } |
| } else { |
| if (!OpTrait::util::getBroadcastedShape(x_batches, y_batches, |
| result_batch_shape)) |
| return op.emitOpError() |
| << "found incompatible broadcast batch dimensions for lhs shape " |
| << x_ty << " and rhs shape " << y_ty; |
| } |
| |
| RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output()); |
| if (!output_ty) return success(); |
| |
| int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank()); |
| if (output_ty.getRank() != expected_output_rank) |
| return op.emitOpError() |
| << "found invalid output rank, expected " << expected_output_rank |
| << " but got " << output_ty.getRank(); |
| |
| // Check output batch dim with potential broadcasting. |
| ArrayRef<int64_t> output_shape = output_ty.getShape(); |
| for (int i = 0; i < result_batch_shape.size(); ++i) { |
| if (output_shape[i] != ShapedType::kDynamicSize && |
| result_batch_shape[i] != ShapedType::kDynamicSize && |
| output_shape[i] != result_batch_shape[i]) |
| return op.emitOpError() |
| << "has mismatching input batch dimension " |
| << result_batch_shape[i] << " and output batch dimension " |
| << output_shape[i]; |
| } |
| |
| // Check output shape for the non-batch dimensions, per the documentation below: |
| // https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul |
| int64_t x_row_dim = x_shape[x_shape.size() - 2]; |
| int64_t x_col_dim = x_shape[x_shape.size() - 1]; |
| int64_t y_row_dim = y_shape[y_shape.size() - 2]; |
| int64_t y_col_dim = y_shape[y_shape.size() - 1]; |
| int64_t out_row_dim = output_shape[output_shape.size() - 2]; |
| int64_t out_col_dim = output_shape[output_shape.size() - 1]; |
| |
| int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim; |
| int64_t expected_out_col_dim = op.adj_y() ? y_row_dim : y_col_dim; |
| |
| if (expected_out_row_dim != ShapedType::kDynamicSize && |
| out_row_dim != ShapedType::kDynamicSize && |
| out_row_dim != expected_out_row_dim) |
| return op.emitOpError() |
| << "found invalid output dimension on row, expected " |
| << expected_out_row_dim << " but got " << out_row_dim; |
| if (expected_out_col_dim != ShapedType::kDynamicSize && |
| out_col_dim != ShapedType::kDynamicSize && |
| out_col_dim != expected_out_col_dim) |
| return op.emitOpError() |
| << "found invalid output dimension on col, expected " |
| << expected_out_col_dim << " but got " << out_col_dim; |
| |
| return success(); |
| } |
| LogicalResult BatchMatMulOp::verify() { return Verify(*this); } |
| LogicalResult BatchMatMulV2Op::verify() { return Verify(*this); } |
| |
| void BatchMatMulOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<BatchMatMulToV2>(context); |
| } |
| |
| void BatchMatMulV2Op::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<BatchMatMulV2ToMatMul>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BatchToSpaceOp |
| //===----------------------------------------------------------------------===// |
| |
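| // Worked example of the constraints checked below (illustrative): with |
| // block_size = 2, input tensor<4x1x1x1xf32>, and crops [[0, 0], [0, 0]], the |
| // output must be tensor<1x2x2x1xf32>: batch is 4 / (2 * 2) = 1 and each |
| // spatial dim is 1 * 2 - 0 - 0 = 2, while depth is unchanged. |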
| LogicalResult BatchToSpaceOp::verify() { |
| BatchToSpaceOp op = *this; |
| // Op already has a constraint that block_size >= 2. |
| int64_t block_size = op.block_size(); |
| |
| llvm::SmallVector<int64_t, 4> input_shape(4, ShapedType::kDynamicSize); |
| auto input_type = op.input().getType().cast<TensorType>(); |
| if (input_type.hasRank()) { |
| if (input_type.getRank() != 4) |
| return op.emitOpError() |
| << "requires input to be a 4D tensor, but got " << input_type; |
| |
| int64_t input_batch = input_type.getDimSize(0); |
| if (input_batch != ShapedType::kDynamicSize && |
| input_batch % (block_size * block_size) != 0) { |
| return op.emitOpError() |
| << "requires input batch (dimension 0) to be evenly divisible " |
| "by (block_size * block_size), but got input batch " |
| << input_batch << " and block_size " << block_size; |
| } |
| |
| input_shape.assign(input_type.getShape().begin(), |
| input_type.getShape().end()); |
| } |
| |
| auto crops_type = op.crops().getType().cast<TensorType>(); |
| if (crops_type.hasRank()) { |
| if (crops_type.getRank() != 2) |
| return op.emitOpError() |
| << "requires crops to be a 2D tensor, but got " << crops_type; |
| |
| auto dim_of_size = [&](int64_t dim, int64_t size) { |
| if (crops_type.isDynamicDim(dim)) return true; |
| return crops_type.getDimSize(dim) == size; |
| }; |
| if (!dim_of_size(0, 2) || !dim_of_size(1, 2)) |
| return op.emitOpError() |
| << "requires crops to be a tensor<2x2>, but got " << crops_type; |
| } |
| |
| DenseIntElementsAttr crops_attr; |
| // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]], |
| // and flattened as [crop_top, crop_bottom, crop_left, crop_right] |
| llvm::SmallVector<int64_t, 4> crops_values; |
| if (matchPattern(op.crops(), m_Constant(&crops_attr))) { |
| assert(crops_attr.getNumElements() == 4 && |
| "tf.BatchToSpace crops must have 4 elements"); |
| |
| auto crops_range = crops_attr.getValues<APInt>(); |
| for (const auto &crops_value : crops_range) { |
| int64_t crops_value_int = crops_value.getSExtValue(); |
| if (crops_value_int < 0) |
| return op.emitOpError() |
| << "requires all crop values to be nonnegative, but got " |
| << crops_attr; |
| |
| crops_values.push_back(crops_value_int); |
| } |
| } |
| |
| auto output_type = op.output().getType().cast<TensorType>(); |
| if (output_type.hasRank()) { |
| if (output_type.getRank() != 4) |
| return op.emitOpError() |
| << "requires output to be a 4D tensor, but got " << output_type; |
| |
| auto static_dims = [](int64_t dim_a, int64_t dim_b) { |
| return dim_a != ShapedType::kDynamicSize && |
| dim_b != ShapedType::kDynamicSize; |
| }; |
| |
| auto output_shape = output_type.getShape(); |
| |
| // output batch = input batch / (block_size * block_size). |
| int64_t input_batch = input_shape[0]; |
| int64_t output_batch = output_shape[0]; |
| if (static_dims(input_batch, output_batch) && |
| (output_batch * block_size * block_size) != input_batch) |
| return op.emitOpError() |
| << "requires output batch (dimension 0) to be equal to input " |
| "batch (dimension 0) / (block_size * block_size), but got " |
| "output batch " |
| << output_batch << ", input batch " << input_batch |
| << ", and block_size " << block_size; |
| |
| auto check_spatial_dim = [&](int64_t spatial_dim_index, |
| llvm::StringRef dim_name, |
| llvm::StringRef crop_a_name, |
| llvm::StringRef crop_b_name) -> LogicalResult { |
| int64_t input_dim = input_shape[spatial_dim_index]; |
| int64_t output_dim = output_shape[spatial_dim_index]; |
| if (!static_dims(input_dim, output_dim)) return success(); |
| |
| int64_t input_dim_pad = input_dim * block_size; |
| // If crops are unknown, the maximum output spatial dim size is input |
| // spatial dim size * block_size, as the minimum crop is 0. |
| if (crops_values.empty() && output_dim > input_dim * block_size) |
| return op.emitOpError() |
| << "requires output " << dim_name << " (dimension " |
| << spatial_dim_index << ") to be less than or equal to input " |
| << dim_name << " (dimension " << spatial_dim_index |
| << ") * block_size, but got output " << dim_name << " " |
| << output_dim << ", input " << dim_name << " " << input_dim |
| << ", and block_size " << block_size; |
| |
| if (!crops_values.empty()) { |
| // output spatial dim = input spatial dim * block_size - crops. |
| int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)]; |
| int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1]; |
| if (output_dim != input_dim_pad - crop_a - crop_b) |
| return op.emitOpError() |
| << "requires output " << dim_name << " (dimension " |
| << spatial_dim_index << ") to be equal to input " << dim_name |
| << " (dimension " << spatial_dim_index << ") * block_size - " |
| << crop_a_name << " - " << crop_b_name << ", but got output " |
| << dim_name << " " << output_dim << ", input " << dim_name |
| << " " << input_dim << ", " << crop_a_name << " " << crop_a |
| << ", " << crop_b_name << " " << crop_b << ", and block_size " |
| << block_size; |
| } |
| |
| return success(); |
| }; |
| |
| if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) || |
| failed(check_spatial_dim(2, "width", "crop_left", "crop_right"))) |
| return failure(); |
| |
| int64_t input_depth = input_shape[3]; |
| int64_t output_depth = output_shape[3]; |
| if (static_dims(input_depth, output_depth) && output_depth != input_depth) |
| return op.emitOpError() |
| << "requires output depth (dimension 3) to be equal to input " |
| "depth (dimension 3), but got output depth " |
| << output_depth << " and input depth " << input_depth; |
| } |
| |
| return success(); |
| } |
| |
| void BatchToSpaceOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<BatchToSpaceToBatchToSpaceND>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BatchToSpaceNDOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult BatchToSpaceNDOp::verify() { |
| BatchToSpaceNDOp op = *this; |
| auto block_shape_ty = op.block_shape().getType().cast<ShapedType>(); |
| auto crops_ty = op.crops().getType().cast<ShapedType>(); |
| |
| if (block_shape_ty.hasStaticShape() && crops_ty.hasStaticShape()) { |
| const int block_rank = block_shape_ty.getShape().front(); |
| if (crops_ty.getRank() != 2 || crops_ty.getShape().front() != block_rank || |
| crops_ty.getShape()[1] != 2) { |
| op.emitOpError() << "crops should have shape [" << block_rank |
| << ", 2] instead of " << crops_ty.getShape(); |
| return failure(); |
| } |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BiasAddOp |
| //===----------------------------------------------------------------------===// |
| |
| // Verifies that, |
| // * the value and bias operands have valid ranks or are unranked. |
| // * Channel dimension of the value operand and length of bias match if they |
| // are not unknown. |
| // |
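| // For example (illustrative): with `NHWC` data format, value |
| // tensor<1x4x4x8xf32> and bias tensor<8xf32> verify, whereas bias |
| // tensor<16xf32> is rejected because the channel dimension is 8. |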
| LogicalResult BiasAddOp::verify() { |
| BiasAddOp op = *this; |
| absl::string_view data_format(op.data_format().data(), |
| op.data_format().size()); |
| tensorflow::TensorFormat format; |
| bool is_valid = FormatFromString(data_format, &format); |
| DCHECK(is_valid) << data_format; |
| if (format == tensorflow::TensorFormat::FORMAT_NHWC) { |
| if (!HasRankAtLeast(op.value(), 2)) |
| return op.emitOpError( |
| "requires value operand to have rank at least two with `NHWC` data " |
| "format"); |
| } else { |
| // Op definition requires data_format to be either NHWC or NCHW. |
| DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW); |
| if (!HasRankAtLeast(op.value(), 3)) |
| return op.emitOpError( |
| "requires value operand to have rank at least three with `NCHW` data " |
| "format"); |
| } |
| |
| if (!IsOfRankOrUnranked(op.bias(), 1)) |
| return op.emitOpError("requires bias operand to have rank exactly one"); |
| |
| RankedTensorType value_ty = op.value().getType().dyn_cast<RankedTensorType>(); |
| RankedTensorType bias_ty = op.bias().getType().dyn_cast<RankedTensorType>(); |
| if (!bias_ty || !value_ty) return success(); |
| |
| int64_t feature_dim_idx = |
| tensorflow::GetTensorFeatureDimIndex(value_ty.getRank(), format); |
| int64_t feature_dim = value_ty.getDimSize(feature_dim_idx); |
| int64_t bias_len = bias_ty.getDimSize(0); |
| if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) { |
| return op.emitOpError() |
| << "requires channel dimension and feature dimension to match; " |
| "found " |
| << feature_dim << " and " << bias_len << ", respectively"; |
| } |
| return success(); |
| } |
| |
| LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) { |
| return ::mlir::TF::UpdateDataFormat(data_format, this); |
| } |
| |
| StringRef BiasAddOp::GetOptimalLayout(const RuntimeDevices &devices) { |
| // Keep the current data format if no GPUs are available or if explicit |
| // placement does not allow using the GPU for this operation. |
| if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) |
| return data_format(); |
| |
| // Prefer NHWC for GPU devices. |
| return "NHWC"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BiasAddGradOp |
| //===----------------------------------------------------------------------===// |
| |
| // Verifies that, |
| // * the out_backprop operands have valid ranks or are unranked. |
| // |
| LogicalResult BiasAddGradOp::verify() { |
| BiasAddGradOp op = *this; |
| absl::string_view data_format(op.data_format().data(), |
| op.data_format().size()); |
| tensorflow::TensorFormat format; |
| bool is_valid = FormatFromString(data_format, &format); |
| DCHECK(is_valid) << data_format; |
| if (format == tensorflow::TensorFormat::FORMAT_NHWC) { |
| if (!HasRankAtLeast(op.out_backprop(), 2)) |
| return op.emitOpError( |
| "requires out_backprop operand to have rank at least two with `NHWC` " |
| "data format"); |
| } else { |
| // Op definition requires data_format to be either NHWC or NCHW. |
| DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW); |
| if (!HasRankAtLeast(op.out_backprop(), 3)) |
| return op.emitOpError( |
| "requires out_backprop operand to have rank at least three with " |
| "`NCHW` data format"); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BiasAddV1Op |
| //===----------------------------------------------------------------------===// |
| |
| void BiasAddV1Op::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<BiasAddV1ToBiasAdd>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BitcastOp |
| //===----------------------------------------------------------------------===// |
| |
| void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<BitcastSameType, BitcastNested>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BroadcastToOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult BroadcastToOp::verify() { |
| // TODO(antiagainst): check that |
| // * The 'shape' input is a 1-D int tensor. |
| // * Each dimension pair of the source and target shapes is either equal, |
| //   or one of them is one. |
| return success(); |
| } |
| |
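| // Folds tf.BroadcastTo in two cases, sketched on illustrative IR: a no-op |
| // broadcast whose operand already has the (fully static) result type folds |
| // to the operand itself, and broadcasting a splat constant folds to a splat |
| // constant of the result type, e.g. a tensor<f32> splat 1.0 broadcast to |
| // tensor<2x2xf32> becomes a 2x2 splat of 1.0. |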
| OpFoldResult BroadcastToOp::fold(ArrayRef<Attribute> operands) { |
| Value input = this->input(); |
| |
| // Fold broadcast if operand and result types are the same and all dimensions |
| // are statically known (no-op broadcast). |
| auto result_ty = getType().dyn_cast<ShapedType>(); |
| if (!result_ty || !result_ty.hasStaticShape()) return {}; |
| |
| if (result_ty == input.getType()) return input; |
| |
| DenseIntElementsAttr cst_attr; |
| if (!matchPattern(input, m_Constant(&cst_attr))) return {}; |
| if (!cst_attr.isSplat()) return {}; |
| |
| return DenseElementsAttr::get(result_ty, cst_attr.getSplatValue<Attribute>()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BroadcastGradientArgsOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| // Returns `true` if both s0 & s1 are defined via constant ops, and fills |
| // s0_shape & s1_shape. |
| bool ExtractInputConstShape(BroadcastGradientArgsOp op, |
| DenseIntElementsAttr &s0, DenseIntElementsAttr &s1, |
| SmallVectorImpl<int64_t> &s0_shape, |
| SmallVectorImpl<int64_t> &s1_shape) { |
| if (!matchPattern(op.s0(), m_Constant(&s0))) return false; |
| if (!matchPattern(op.s1(), m_Constant(&s1))) return false; |
| |
| for (auto s : s0.getValues<APInt>()) s0_shape.push_back(s.getSExtValue()); |
| for (auto s : s1.getValues<APInt>()) s1_shape.push_back(s.getSExtValue()); |
| |
| return true; |
| } |
| |
| // Calculates r0 & r1 output based on inputs and calculated broadcasted shape. |
| // |
| // For the given bcasted_shape, s0_shape, and s1_shape, the broadcasted |
| // dimensions are calculated and pushed back to the corresponding result, r0 |
| // or r1. For example, for s0_shape [1,4] and s1_shape [4,4], bcasted_shape |
| // is computed to be [4,4] - this leads r0 to be [0], as the first dimension |
| // of s0 is broadcasted, and r1 to be <>, as no broadcasting happens for s1. |
| void GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape, |
| ArrayRef<int64_t> s0_shape, |
| ArrayRef<int64_t> s1_shape, |
| SmallVectorImpl<int64_t> &r0, |
| SmallVectorImpl<int64_t> &r1) { |
| r0.clear(); |
| r1.clear(); |
| |
| // No broadcasting is required if both the shapes are equal. |
| if (s0_shape == s1_shape) return; |
| |
| for (int i = bcasted_shape.size(); i > 0; --i) { |
| int idx = bcasted_shape.size() - i; |
| int s0_idx = i > s0_shape.size() ? -1 : s0_shape.size() - i; |
| int s1_idx = i > s1_shape.size() ? -1 : s1_shape.size() - i; |
| if (s0_idx == -1) { |
| r0.push_back(idx); |
| if (s1_shape[s1_idx] == 1) r1.push_back(idx); |
| } else if (s1_idx == -1) { |
| r1.push_back(idx); |
| if (s0_shape[s0_idx] == 1) r0.push_back(idx); |
| } else if (s0_shape[s0_idx] != s1_shape[s1_idx]) { |
| if (s0_shape[s0_idx] != bcasted_shape[idx]) |
| r0.push_back(idx); |
| else |
| r1.push_back(idx); |
| } else if (s0_shape[s0_idx] == 1) { |
| // This op is used to compute the gradient dimensions requiring reduction |
| // to match the input dimensions. In case both the dimensions are one, |
| // reducing the dimension has no effect. We choose to reduce such |
| // dimensions to match the TensorFlow kernel behavior. However, note that |
| // the TF behavior in this case is inconsistent with the case with the |
| // same shapes. |
| r0.push_back(idx); |
| r1.push_back(idx); |
| } |
| } |
| } |
| } // namespace |
| |
| // Verifies that, |
| // * Input shapes are broadcast compatible. |
| // * Output shape dimensions match the expected dimension sizes for the |
| // input shapes. |
| LogicalResult BroadcastGradientArgsOp::verify() { |
| BroadcastGradientArgsOp op = *this; |
| SmallVector<int64_t, 4> s0_shape, s1_shape; |
| DenseIntElementsAttr s0, s1; |
| if (!ExtractInputConstShape(op, s0, s1, s0_shape, s1_shape)) return success(); |
| |
| // If both shapes are known constants, validate the output shapes as well. |
| SmallVector<int64_t, 4> bcasted_shape; |
| if (!OpTrait::util::getBroadcastedShape(s0_shape, s1_shape, bcasted_shape)) |
| return op.emitOpError() << "requires broadcast compatible shape tensors " |
| "for 's0' and 's1', but got " |
| << s0 << " and " << s1; |
| |
| SmallVector<int64_t, 4> r0, r1; |
| GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0, |
| r1); |
| |
| // Verify that output types are of rank one and matches the computed result |
| // shape. |
| auto r0_ty = op.r0().getType().dyn_cast<RankedTensorType>(); |
| auto r1_ty = op.r1().getType().dyn_cast<RankedTensorType>(); |
| if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getDimSize(0) != r0.size()) |
| return op.emitOpError() << "requires dimension 0 size of 'r0' to be " |
| << r0.size() << " but got " << r0_ty.getShape()[0]; |
| if (r1_ty && r1_ty.hasStaticShape() && r1_ty.getDimSize(0) != r1.size()) |
| return op.emitOpError() << "requires dimension 0 size of 'r1' to be " |
| << r1.size() << " but got " << r1_ty.getShape()[0]; |
| |
| return success(); |
| } |
| |
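| // Folds tf.BroadcastGradientArgs into two constants when both input shapes |
| // are constant. For example (illustrative): s0 = [1, 4] and s1 = [4, 4] |
| // broadcast to [4, 4], so r0 folds to the constant [0] (dimension 0 of s0 is |
| // broadcasted) and r1 folds to an empty constant <>. |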
| LogicalResult BroadcastGradientArgsOp::fold( |
| ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { |
| SmallVector<int64_t, 4> s0_shape, s1_shape; |
| DenseIntElementsAttr s0, s1; |
| if (!ExtractInputConstShape(*this, s0, s1, s0_shape, s1_shape)) |
| return failure(); |
| |
| // Fold BroadcastGradientArgs into two constants if both of the inputs have |
| // known shape. |
| SmallVector<int64_t, 4> bcasted_shape; |
| // Verifier should already ensure the broadcast compatibility. |
| bool bcast_compatible = |
| OpTrait::util::getBroadcastedShape(s0_shape, s1_shape, bcasted_shape); |
| assert(bcast_compatible); |
| (void)bcast_compatible; |
| |
| SmallVector<int64_t, 4> r0, r1; |
| GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0, |
| r1); |
| |
| auto build_out_dense_element = [](SmallVectorImpl<int64_t> &shape, |
| Type input_type) { |
| Type element_type = input_type.cast<mlir::TensorType>().getElementType(); |
| RankedTensorType type = RankedTensorType::get( |
| {static_cast<int64_t>(shape.size())}, element_type); |
| // Input could only be i32 or i64. For i32, downcast to int32_t array. |
| if (element_type.isInteger(32)) { |
| SmallVector<int32_t, 4> i32_shape; |
| for (auto s : shape) i32_shape.push_back(static_cast<int32_t>(s)); |
| return DenseIntElementsAttr::get(type, i32_shape); |
| } else { |
| assert(element_type.isInteger(64)); |
| return DenseIntElementsAttr::get(type, shape); |
| } |
| }; |
| |
| results.push_back(build_out_dense_element(r0, this->s0().getType())); |
| results.push_back(build_out_dense_element(r1, this->s1().getType())); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CaseOp |
| //===----------------------------------------------------------------------===// |
| |
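| // Folds tf.Case with a constant branch_index into a direct call of the |
| // selected branch. A rough sketch on illustrative IR: |
| // |
| //   %0 = "tf.Case"(%index, %arg) {branches = [@a, @b], ...} |
| // |
| // with %index a constant 1 becomes |
| // |
| //   %0 = "tf.PartitionedCall"(%arg) {f = @b, ...} |
| // |
| // Out-of-range indices select the last branch, mirroring tf.Case's |
| // documented out-of-range behavior. |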
| class FoldConstantCaseOp : public OpRewritePattern<TF::CaseOp> { |
| public: |
| explicit FoldConstantCaseOp(MLIRContext *context) |
| : OpRewritePattern<TF::CaseOp>(context) {} |
| LogicalResult matchAndRewrite(TF::CaseOp op, |
| PatternRewriter &rewriter) const override; |
| }; |
| |
| LogicalResult FoldConstantCaseOp::matchAndRewrite( |
| TF::CaseOp op, PatternRewriter &rewriter) const { |
| // Extract the constant cond value. |
| DenseIntElementsAttr branch; |
| if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure(); |
| |
| int index = *branch.getValues<int>().begin(); |
| if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1; |
| |
| auto func = op.branches()[index].cast<SymbolRefAttr>(); |
| auto empty = rewriter.getStringAttr(""); |
| ReplaceTfOpWithNewOp<PartitionedCallOp>( |
| rewriter, op, op.getResultTypes(), op.getOperands().drop_front(), func, |
| /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); |
| return success(); |
| } |
| |
| void CaseOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<FoldConstantCaseOp, DropAttributes<CaseOp>>(context); |
| } |
| |
| static LogicalResult VerifyCaseOpBase(Operation *op, Value branch_index) { |
| if (!IsOfRankOrUnranked(branch_index, 0)) |
| return op->emitOpError() |
| << "expects 'branch_index' to be a scalar, but got " |
| << branch_index.getType(); |
| return success(); |
| } |
| |
| static LogicalResult VerifyCaseOrIfOpBranchFunctions( |
| SymbolTableCollection &symbol_table, Operation *op, |
| ArrayRef<Attribute> branches, |
| llvm::function_ref<std::string(unsigned branch_index)> branch_name) { |
| SmallVector<FunctionType, 2> branch_types; |
| branch_types.reserve(branches.size()); |
| |
| if (llvm::any_of(op->getOperands(), |
| [](Value value) { return value == nullptr; })) |
| return op->emitOpError("operation has null operand"); |
| |
| // Functions have one fewer operand than the op, as the first operand is |
| // elided (`cond` of `tf.If` and `branch_index` of `tf.Case`). |
| TypeRangeWithDesc input{op->getOperands().drop_front().getTypes(), "input"}; |
| TypeRangeWithDesc result{op->getResultTypes(), "result"}; |
| |
| for (auto branch : llvm::enumerate(branches)) { |
| auto branch_func = symbol_table.lookupNearestSymbolFrom<func::FuncOp>( |
| op, branch.value().cast<SymbolRefAttr>()); |
| if (!branch_func) |
| return op->emitOpError() |
| << "expects " << branch_name(branch.index()) << " (" |
| << branch.value() << ") to point to a defined function"; |
| |
| FunctionType branch_type = branch_func.getFunctionType(); |
| std::string desc = branch_name(branch.index()) + " input"; |
| TypeRangeWithDesc branch_input{branch_type.getInputs(), desc}; |
| if (failed(VerifyTypeRangesAreCompatible(op, branch_input, input))) |
| return failure(); |
| |
| desc = branch_name(branch.index()) + " result"; |
| TypeRangeWithDesc branch_result{branch_type.getResults(), desc}; |
| if (failed(VerifyTypeRangesAreCompatible(op, branch_result, result))) |
| return failure(); |
| |
| branch_types.push_back(branch_type); |
| } |
| |
| // If branches have incompatible input types that means that no tensor can |
| // serve as input to all the functions. Hence, the op is invalid. |
| int expected_num_inputs = op->getNumOperands() - 1; |
| for (int i = 0; i < expected_num_inputs; ++i) { |
| SmallVector<Type, 2> branch_input_i_types; |
| branch_input_i_types.reserve(branches.size()); |
| llvm::transform( |
| branch_types, std::back_inserter(branch_input_i_types), |
| [i](FunctionType &branch_type) { return branch_type.getInput(i); }); |
| if (!AreCastCompatible(branch_input_i_types)) { |
| std::string input_types_str; |
| llvm::raw_string_ostream os(input_types_str); |
| llvm::interleaveComma(branch_input_i_types, os); |
| return op->emitOpError() |
| << "expects all branch input type(s) (" << os.str() |
| << ") at index " << i << " to be cast compatible"; |
| } |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult CaseOp::verify() { |
| CaseOp op = *this; |
| return VerifyCaseOpBase(op, op.branch_index()); |
| } |
| |
| LogicalResult CaseOp::verifySymbolUses(SymbolTableCollection &symbol_table) { |
| auto branch_name = [](unsigned index) { |
| return llvm::formatv("branch #{0}", index).str(); |
| }; |
| return VerifyCaseOrIfOpBranchFunctions(symbol_table, *this, |
| branches().getValue(), branch_name); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CaseRegionOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult CaseRegionOp::verify() { |
| CaseRegionOp op = *this; |
| if (op.branches().empty()) |
| return op.emitOpError() << "expects to have at least 1 region"; |
| |
| if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure(); |
| |
| TypeRangeWithDesc results{op.getResultTypes(), "result"}; |
| |
| for (auto region_and_idx : llvm::enumerate(op.branches())) { |
| std::string description = |
| llvm::formatv("branch #{0} result", region_and_idx.index()).str(); |
| Operation *yield = region_and_idx.value().front().getTerminator(); |
| TypeRangeWithDesc branch_results{yield->getOperandTypes(), description}; |
| if (failed(VerifyTypeRangesAreCompatible(op, branch_results, results))) |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| namespace { |
| // Eliminate values that pass through the CaseRegionOp or IfRegionOp branches. |
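| // For example (an illustrative sketch): if every branch of a tf.IfRegion |
| // yields the same value %v that is defined above the op, uses of the |
| // corresponding result are replaced with %v, the operand is dropped from |
| // every terminator, and the op is rebuilt without that result. |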
| template <class CaseOrIfRegionOp> |
| class CaseOrIfRegionEliminatePassThrough |
| : public OpRewritePattern<CaseOrIfRegionOp> { |
| using OpRewritePattern<CaseOrIfRegionOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(CaseOrIfRegionOp op, |
| PatternRewriter &rewriter) const override { |
| RegionRange branches = op.getRegions(); |
| SmallVector<Type, 4> new_result_types; |
| // Maps pass through results to extern values. |
| llvm::SmallDenseMap<Value, Value, 4> result_to_extern_value; |
| |
| for (auto result : op.getResults()) { |
| unsigned index = result.getResultNumber(); |
| Region *first_branch = *branches.begin(); |
| Operation *first_terminator = first_branch->front().getTerminator(); |
| Value returned_val = first_terminator->getOperand(index); |
| |
| // Pass-through values would be defined outside the branch region. Keep |
| // the types of non-pass-through results to create a new op later, if |
| // required. |
| if (returned_val.getParentBlock() == &first_branch->front()) { |
| new_result_types.push_back(result.getType()); |
| continue; |
| } |
| // Check if the same extern value is returned in each branch. |
| for (Region *region : branches.drop_front()) { |
| Operation *terminator = region->front().getTerminator(); |
| if (terminator->getOperand(index) != returned_val) return failure(); |
| } |
| result_to_extern_value[result] = returned_val; |
| } |
| |
| // If no pass through values are found, no change is required. |
| if (result_to_extern_value.empty()) return failure(); |
| |
| // Create new case/if region op. |
| auto new_op = rewriter.create<CaseOrIfRegionOp>( |
| op.getLoc(), new_result_types, op.getOperand(), op->getAttrs(), |
| op.getNumRegions()); |
| |
| int next_index = 0; |
| for (auto result : op.getResults()) { |
| if (!result_to_extern_value.count(result)) { |
| result.replaceAllUsesWith(new_op.getResult(next_index++)); |
| continue; |
| } |
| result.replaceAllUsesWith(result_to_extern_value[result]); |
| for (Region *branch : branches) |
| branch->front().getTerminator()->eraseOperand(next_index); |
| } |
| |
| // Move region bodies to the new op. |
| for (auto region_index : llvm::seq<int>(0, branches.size())) |
| new_op.getRegion(region_index).takeBody(op.getRegion(region_index)); |
| |
| op.erase(); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void CaseRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<CaseOrIfRegionEliminatePassThrough<TF::CaseRegionOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CastOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { |
| // Cast with the same type is a no-op. |
| Value operand = getOperand(); |
| if (getType() == operand.getType()) return operand; |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConcatOp and ConcatV2Op |
| //===----------------------------------------------------------------------===// |
| |
| template <typename OpT, |
| typename std::enable_if<llvm::is_one_of< |
| OpT, ConcatOp, ConcatV2Op>::value>::type * = nullptr> |
| static LogicalResult Verify(OpT op) { |
| // TODO(hinsu): Convert variadic length attributes to derived attributes. |
| Operation::operand_range values = op.values(); |
| |
| int axis_idx = std::is_same<OpT, ConcatOp>() ? 0 : 1; |
| Value axis = *op.getODSOperands(axis_idx).begin(); |
| if (!HasRankAtMost(axis, 1)) { |
| return op.emitOpError( |
| "requires axis to be of scalar type (or vector type for older " |
| "versions)"); |
| } |
| |
| return VerifyTypesCompatibility(values, |
| /*mask_one_dim=*/true, op.getOperation()); |
| } |
| |
| LogicalResult ConcatOp::verify() { return Verify(*this); } |
| LogicalResult ConcatV2Op::verify() { return Verify(*this); } |
| |
| void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<ConvertToConcatV2>(context); |
| } |
| |
| namespace { |
| |
| // Hoist coefficient-wise unary operation out of the Concat op: |
| // |
| // %0 = "tf.Log1p"(%arg_0) |
| // %1 = "tf.Log1p"(%arg_1) |
| // ... |
| // %n = "tf.Log1p"(%arg_n) |
| // %m = "tf.ConcatV2"(%0, %1, ..., %n, %axis) |
| // |
| // Rewrite it to: |
| // |
| // %0 = "tf.ConcatV2"(%arg_0, %arg_1, ..., %arg_n, %axis) |
| // %1 = "tf.Log1p"(%0) |
| class HoistCwiseUnaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> { |
| public: |
| explicit HoistCwiseUnaryOutOfConcat(MLIRContext *context) |
| : OpRewritePattern<TF::ConcatV2Op>(context) {} |
| LogicalResult matchAndRewrite(TF::ConcatV2Op op, |
| PatternRewriter &rewriter) const override; |
| }; |
| |
| LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite( |
| TF::ConcatV2Op op, PatternRewriter &rewriter) const { |
| auto loc = op.getLoc(); |
| |
| // All concat operands must be defined by ops. |
| Operation *first_arg_op = op.values().front().getDefiningOp(); |
| if (first_arg_op == nullptr) return failure(); |
| |
| // All concat operands must be produced by the coeff-wise unary operation. |
| if (!first_arg_op->hasTrait<OpTrait::TF::CwiseUnary>()) return failure(); |
| |
| // All concat operands must be defined by the op of same kind. |
| bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool { |
| Operation *arg_op = arg.getDefiningOp(); |
| return arg_op && arg_op->getName() == first_arg_op->getName(); |
| }); |
| if (!args_same_op) return failure(); |
| |
| // Collect unary operations operands. |
| auto unary_operands = llvm::map_range(op.values(), [](Value arg) -> Value { |
| return arg.getDefiningOp()->getOperand(0); |
| }); |
| SmallVector<Value, 8> unary_ops_args(unary_operands); |
| |
| // Concatenate unary ops operands. |
| auto concat_unary_operands = |
| rewriter.create<ConcatV2Op>(loc, op.getType(), unary_ops_args, op.axis()); |
| |
| // Replace the original concat with a unary op. |
| OperationState new_unary_op_state(loc, first_arg_op->getName().getStringRef(), |
| concat_unary_operands.getResult(), |
| op.getResult().getType(), |
| ArrayRef<NamedAttribute>()); |
| Operation *new_unary_op = rewriter.create(new_unary_op_state); |
| |
| rewriter.replaceOp(op, new_unary_op->getResults()); |
| |
| return success(); |
| } |
| |
| // Hoist coefficient-wise binary operation out of the Concat op: |
| // |
| // %0 = tf.Mul(%lhs_0, %rhs_0) |
| // %1 = tf.Mul(%lhs_1, %rhs_1) |
| // ... |
| // %n = tf.Mul(%lhs_n, %rhs_n) |
| // %m = tf.ConcatV2(%0, %1, ..., %n, %axis) |
| // |
| // Rewrite it to: |
| // |
| // %0 = tf.ConcatV2(%lhs0, %lhs1, ..., %lhs_n, %lhs_concat_axis) |
| // %1 = tf.ConcatV2(%rhs0, %rhs1, ..., %rhs_n, %rhs_concat_axis) |
| // %2 = tf.Mul(%0, %1) |
| // |
| // If a minor fraction of the Concat inputs is not of the same binary op kind |
| // (tf.Mul in the above example), we will synthesize the binary ops for those |
| // inputs, e.g. if we instead have %1 = %lhs_1, then we would synthesize a |
| // tf.Mul op over it and a scalar const tensor 1.0. For now this only applies |
| // to float32 tensors. |
| // TODO(hongm): Implement this op synthesis optimization for other dtypes if |
| // needed. |
| // |
| // Because coefficient-wise binary operations support implicit broadcasting, we |
| // should be very careful with this optimization, and do not accidentally |
| // produce incorrect concat operations. |
| class HoistCwiseBinaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> { |
| public: |
| explicit HoistCwiseBinaryOutOfConcat(MLIRContext *context) |
| : OpRewritePattern<TF::ConcatV2Op>(context) {} |
| LogicalResult matchAndRewrite(TF::ConcatV2Op op, |
| PatternRewriter &rewriter) const override; |
| |
| private: |
| struct HoistParams { |
| SmallVector<Value, 8> lhs_args; |
| SmallVector<Value, 8> rhs_args; |
| int64_t lhs_axis; |
| int64_t rhs_axis; |
| Type lhs_concat_type; |
| Type rhs_concat_type; |
| int scalar_operand_idx; // can be 0 or 1 for the binary op's operands. |
| }; |
| |
| // Returns parameters of a binary op hoisting out of concatenation if all of |
| // the operands are in one of the compatible configurations. |
| // All inputs of `op` should be of the same binary op kind (e.g. tf.Mul), |
| // except from the ones in `exceptions`. In that case, we can synthesize that |
| // binary op kind for the values in `exceptions`. |
| Optional<HoistParams> GetHoistParams( |
| TF::ConcatV2Op op, int64_t axis, |
| const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const; |
| }; |
| |
| LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( |
| TF::ConcatV2Op op, PatternRewriter &rewriter) const { |
| auto loc = op.getLoc(); |
| |
| // Axis must be a constant scalar value. |
| DenseIntElementsAttr axis_attr; |
| if (!matchPattern(op.axis(), m_Constant(&axis_attr))) return failure(); |
| if (axis_attr.getNumElements() != 1) return failure(); |
| int64_t axis = |
| axis_attr.getSplatValue<IntegerAttr>().getValue().getSExtValue(); |
| // TODO(ezhulenev): Compute axis from rank. e.g. It might be common to concat |
| // on the channels dim for NCHW layout as axis=-2. |
| if (axis < 0) return failure(); |
| |
| // All concat operands must be defined by ops of the same kind (e.g. tf.Mul), |
| // or some other ops that we might convert to using the same op kind above |
| // (e.g. converting op A to tf.Mul(A, 1.0)) |
| // TODO(hongm): generalize the code here to support cases where the first arg |
| // has no defining op (e.g. might be a block arg). |
| Operation *first_arg_op = op.values().front().getDefiningOp(); |
| if (first_arg_op == nullptr) return failure(); |
| |
| // All concat operands must be produced by the coeff-wise binary operation. |
| if (!first_arg_op->hasTrait<OpTrait::TF::CwiseBinary>()) return failure(); |
| |
| // All concat operands must be defined by the op of same kind, except for a |
| // minor portion which we track in `exceptions`. |
| // Map from the operands to operand indices. |
| llvm::SmallDenseMap<Value, unsigned, 4> exceptions; |
| unsigned operand_idx = 0; |
| for (Value arg : op.values()) { |
| Operation *arg_op = arg.getDefiningOp(); |
| if (arg_op && arg_op->getName() == first_arg_op->getName()) { |
| ++operand_idx; |
| continue; |
| } |
| exceptions[arg] = operand_idx++; |
| } |
| // Recall that inputs to the concat op that are not produced by a binary op |
| // of the `first_arg_op` kind (e.g. tf.Mul) are stored in `exceptions`. If |
| // there are too many exceptions, it might not be cost effective to apply the |
| // concat hoisting optimization here. |
| // Setting the threshold to be 50% as a simple cost model heuristic. e.g. If 1 |
| // out of 2 concat inputs is an exception, we don't apply the hoist. If it's 1 |
| // out of 3, we do. |
| const float exception_pct_threshold = 0.5; |
| if (static_cast<float>(op.values().size()) * exception_pct_threshold <= |
| exceptions.size()) |
| return failure(); |
| |
| // Compute binary operands hoist parameters. |
| auto hoist_params = GetHoistParams(op, axis, exceptions); |
| if (!hoist_params.has_value()) return failure(); |
| |
| // Process `exceptions`: For each value there, synthesize a binary op of the |
| // above kind, so that the concat hoisting optimization can still apply. |
| if (!exceptions.empty()) { |
| int identity_val; |
| if (isa<AddOp>(first_arg_op) || isa<SubOp>(first_arg_op)) |
| identity_val = 0; |
| else if (isa<MulOp>(first_arg_op) || isa<DivOp>(first_arg_op) || |
| isa<RealDivOp>(first_arg_op)) |
| identity_val = 1; |
| else |
| return failure(); |
| DenseElementsAttr const_attr; |
| auto scalar_tensor_type = |
| first_arg_op->getOperand(hoist_params->scalar_operand_idx) |
| .getType() |
| .dyn_cast<ShapedType>(); |
| Type scalar_dtype = scalar_tensor_type.getElementType(); |
| if (scalar_dtype.isa<FloatType>()) |
| const_attr = DenseElementsAttr::get(scalar_tensor_type, |
| static_cast<float>(identity_val)); |
| else |
| return failure(); |
| |
| // All checks have passed; prepare for the rewrite. |
| auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr); |
| for (const auto &kv : exceptions) { |
| assert(!hoist_params->lhs_args[kv.second]); |
| assert(!hoist_params->rhs_args[kv.second]); |
| |
| if (hoist_params->scalar_operand_idx == 1) { |
| hoist_params->lhs_args[kv.second] = kv.first; |
| hoist_params->rhs_args[kv.second] = identity_const; |
| } else { |
| assert(hoist_params->scalar_operand_idx == 0); |
| hoist_params->lhs_args[kv.second] = identity_const; |
| hoist_params->rhs_args[kv.second] = kv.first; |
| } |
| } |
| } |
| |
| // Concatenates `args` along `axis`. |
| auto pack_or_concat = [&](bool is_scalar, Type result_type, ValueRange args, |
| int64_t axis) { |
| // Use `PackOp` for scalar concatenation because `ConcatV2Op` doesn't |
| // support scalar concatenation. |
| if (is_scalar) { |
| auto pack = rewriter.create<PackOp>(loc, result_type, args, |
| rewriter.getI64IntegerAttr(axis)); |
| return pack.getResult(); |
| } |
| |
| // New concatenation axis. |
| auto axis_type = RankedTensorType::get({}, getElementTypeOrSelf(axis_attr)); |
| DenseIntElementsAttr attr; |
| if (axis_type.getElementType().isInteger(32)) { |
| attr = DenseIntElementsAttr::get(axis_type, static_cast<int32_t>(axis)); |
| } else { |
| assert(axis_type.getElementType().isInteger(64)); |
| attr = DenseIntElementsAttr::get(axis_type, axis); |
| } |
| auto axis_const = rewriter.create<TF::ConstOp>(loc, attr); |
| |
| auto concat = |
| rewriter.create<ConcatV2Op>(loc, result_type, args, axis_const); |
| return concat.getResult(); |
| }; |
| |
| // Concatenate binary ops operands on the new axis. |
| Value lhs_concat = pack_or_concat( |
| hoist_params->scalar_operand_idx == 0, hoist_params->lhs_concat_type, |
| hoist_params->lhs_args, hoist_params->lhs_axis); |
| Value rhs_concat = pack_or_concat( |
| hoist_params->scalar_operand_idx == 1, hoist_params->rhs_concat_type, |
| hoist_params->rhs_args, hoist_params->rhs_axis); |
| |
| // Replace original concat with a binary op. |
| OperationState new_binary_op_state( |
| loc, first_arg_op->getName().getStringRef(), {lhs_concat, rhs_concat}, |
| op.getResult().getType(), ArrayRef<NamedAttribute>()); |
| Operation *new_binary_op = rewriter.create(new_binary_op_state); |
| |
| rewriter.replaceOp(op, new_binary_op->getResults()); |
| |
| return success(); |
| } |
| |
| Optional<HoistCwiseBinaryOutOfConcat::HoistParams> |
| HoistCwiseBinaryOutOfConcat::GetHoistParams( |
| TF::ConcatV2Op op, int64_t axis, |
| const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const { |
| assert(axis >= 0); |
| // Collects lhs or rhs arguments of concat op operands. |
| auto args = [&](int operand_idx) -> SmallVector<Value, 8> { |
| auto range = llvm::map_range(op.values(), [&](Value arg) { |
| if (exceptions.count(arg)) return Value(); |
| return arg.getDefiningOp()->getOperand(operand_idx); |
| }); |
| return {range.begin(), range.end()}; |
| }; |
| |
| // Returns true if all binary ops' operands at index `operand_idx` are |
| // tensors of rank `axis + 1` whose axis dim has size `1`. |
| auto is_all_tensors = [&](int operand_idx, int axis) -> bool { |
| return llvm::all_of(op.values(), [&](Value arg) -> bool { |
| mlir::Value operand; |
| if (exceptions.count(arg)) { |
| // For exceptions, since we are going to synthesize a binary op that |
| // produces the identity value, the exception value is also required to |
| // be a ranked tensor with rank `axis + 1` whose axis dim has size `1`. |
| operand = arg; |
| } else { |
| operand = arg.getDefiningOp()->getOperand(operand_idx); |
| } |
| auto ranked = operand.getType().dyn_cast<RankedTensorType>(); |
| return ranked && ranked.getRank() == (axis + 1) && |
| ranked.getShape()[axis] == 1; |
| }); |
| }; |
| |
| // Returns true if all binary ops' operands at index `operand_idx` are scalars. |
| auto is_all_scalars = [&](int operand_idx) -> bool { |
| return llvm::all_of(op.values(), [&](Value arg) -> bool { |
| if (exceptions.count(arg)) return true; |
| auto operand = arg.getDefiningOp()->getOperand(operand_idx); |
| auto ranked = operand.getType().dyn_cast<RankedTensorType>(); |
| return ranked && ranked.getRank() == 0; |
| }); |
| }; |
| |
| // Concat result type must be a ranked tensor. |
| auto ranked = op.getType().dyn_cast<RankedTensorType>(); |
| if (!ranked) return None; |
| |
| // TODO(ezhulenev): Add support for more valid concat patterns. |
| |
| // Tensor + Scalar: [..., 1] + [] <- scalar |
| // ^ |
| // \- axis is the innermost dimension. |
| // |
| // Concatenate tensor arguments on the same axis as the original operation, |
| // and concatenate scalars into the vector. |
| if (is_all_tensors(0, axis) && is_all_scalars(1)) { |
| std::array<int64_t, 1> rhs_dims{static_cast<int64_t>(op.values().size())}; |
| auto rhs_type = RankedTensorType::get(rhs_dims, ranked.getElementType()); |
| return HoistParams{args(0), |
| args(1), |
| axis, |
| 0, |
| op.getType(), |
| rhs_type, |
| /*scalar_operand_idx=*/1}; |
| } else if (is_all_tensors(1, axis) && is_all_scalars(0)) { |
| std::array<int64_t, 1> lhs_dims{static_cast<int64_t>(op.values().size())}; |
| auto lhs_type = RankedTensorType::get(lhs_dims, ranked.getElementType()); |
| return HoistParams{args(0), |
| args(1), |
| 0, |
| axis, |
| lhs_type, |
| op.getType(), |
| /*scalar_operand_idx=*/0}; |
| } |
| return None; |
| } |
| |
| } // namespace |
| |
| void ConcatV2Op::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<HoistCwiseBinaryOutOfConcat, HoistCwiseUnaryOutOfConcat>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CumsumOp and CumprodOp |
| //===----------------------------------------------------------------------===// |
| |
| template <typename OpT, typename std::enable_if<llvm::is_one_of< |
| OpT, CumsumOp, CumprodOp>::value>::type * = nullptr> |
| static LogicalResult Verify(OpT op) { |
| if (!IsOfRankOrUnranked(op.axis(), 0)) |
| return op.emitOpError("requires scalar axis operand"); |
| |
| DenseIntElementsAttr axis_attr; |
| if (matchPattern(op.axis(), m_Constant(&axis_attr))) { |
| auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>(); |
| if (input_ty) { |
| int64_t rank = input_ty.getRank(); |
| assert(axis_attr.getNumElements() == 1 && |
| "scalar attribute should have exactly one element"); |
| int64_t axis = (*axis_attr.begin()).getSExtValue(); |
| if (axis < -rank || axis >= rank) { |
| return op.emitError() |
| << "axis operand should be within range [" << -rank << ", " |
| << rank << "); actual value: " << axis; |
| } |
| } |
| } |
| |
| return success(); |
| } |
| LogicalResult CumprodOp::verify() { return Verify(*this); } |
| LogicalResult CumsumOp::verify() { return Verify(*this); } |
| |
| //===----------------------------------------------------------------------===// |
| // ConcatOffsetOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ConcatOffsetOp::verify() { |
| ConcatOffsetOp op = *this; |
| if (op.N() < 2) |
| return op.emitOpError() << "requires N to be at least 2, got " << op.N(); |
| |
| if (op.shape().size() != op.offset().size()) |
| return op.emitOpError() |
| << "requires sizes of shapes and offsets to be the same, got sizes " |
| << op.shape().size() << " and " << op.offset().size(); |
| |
| auto ranked_dim = op.concat_dim().getType().dyn_cast<RankedTensorType>(); |
| if (ranked_dim && ranked_dim.getRank() != 0) |
| return op.emitOpError() |
| << "requires concat_dim to be a scalar, got tensor of rank " |
| << ranked_dim.getRank(); |
| |
| int64_t num_dims = -1; |
| for (auto shape_offset_idx : |
| llvm::enumerate(llvm::zip(op.shape(), op.offset()))) { |
| Value shape = std::get<0>(shape_offset_idx.value()); |
| Value offset = std::get<1>(shape_offset_idx.value()); |
| const size_t idx = shape_offset_idx.index(); |
| |
| if (failed(verifyCompatibleShape(shape.getType(), offset.getType()))) |
| return op.emitOpError() << "requires operand and result " << idx |
| << " to have compatible shapes"; |
| |
| auto ranked_shape = shape.getType().dyn_cast<RankedTensorType>(); |
| if (!ranked_shape) continue; |
| |
| if (ranked_shape.getRank() != 1) |
| return op.emitOpError() << "requires shape tensor operand " << idx |
| << " to be of rank 1, got tensor of rank " |
| << ranked_shape.getRank(); |
| |
| if (!ranked_shape.hasStaticShape()) continue; |
| |
| int64_t ranked_shape_dim = ranked_shape.getDimSize(0); |
| if (num_dims == -1) |
| num_dims = ranked_shape_dim; |
| else if (ranked_shape_dim != num_dims) |
| return op.emitOpError() |
| << "requires shape tensor (rank 1) operand " << idx |
| << " to be of length " << num_dims |
| << ", got tensor (rank 1) of length " << ranked_shape_dim; |
| } |
| |
| return success(); |
| } |
| |
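| // Folds tf.ConcatOffset with constant operands into constant offsets. For |
| // example (a sketch): with concat_dim = 1 and shapes [2, 3, 5] and |
| // [2, 7, 5], the folded offsets are [0, 0, 0] and [0, 3, 0], i.e. an |
| // exclusive cumulative sum of the sizes along concat_dim. |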
| LogicalResult ConcatOffsetOp::fold(ArrayRef<Attribute> operands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| // ConcatOffset must have its first operand be concat_dim and at least two |
| // shape tensors in the variadic shapes operand. |
| if (operands.size() < 3) return failure(); |
| |
| // Check concat_dim is a scalar. |
| auto concat_dim_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); |
| if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0) |
| return failure(); |
| |
| llvm::SmallVector<DenseIntElementsAttr, 4> shapes; |
| shapes.reserve(operands.size() - 1); |
| for (Attribute shape : llvm::drop_begin(operands, 1)) |
| if (auto shape_attr = shape.dyn_cast_or_null<DenseIntElementsAttr>()) |
| shapes.push_back(shape_attr); |
| else |
| return failure(); |
| |
| // Check all shapes are vectors of the same length. |
| if (shapes.front().getType().getRank() != 1) return failure(); |
| const int64_t num_dims = shapes.front().getNumElements(); |
| for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) |
| if (shape.getType().getRank() != 1 || shape.getNumElements() != num_dims) |
| return failure(); |
| |
| // Check concat_dim is within [-num_dims, num_dims). |
| int32_t concat_dim = (*concat_dim_attr.getValues<int32_t>().begin()); |
| if (concat_dim < 0) concat_dim += num_dims; |
| if (concat_dim >= num_dims || concat_dim < 0) return failure(); |
| |
| // Check that all elements other than the one at concat_dim match across |
| // all shape tensors. |
| SmallVector<int32_t, 4> shape0; |
| shape0.reserve(num_dims); |
| for (int32_t dim : shapes.front().getValues<int32_t>()) shape0.push_back(dim); |
| |
| for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) { |
| for (auto dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) { |
| if (dims_and_idx.index() == concat_dim) continue; |
| |
| if (std::get<0>(dims_and_idx.value()) != |
| std::get<1>(dims_and_idx.value()).getSExtValue()) |
| return failure(); |
| } |
| } |
| |
| // Compute an exclusive cumulative sum of elements at concat_dim. |
| results.reserve(shapes.size()); |
| SmallVector<int32_t, 4> cumulative_sum(num_dims, 0); |
| RankedTensorType offset_type = |
| RankedTensorType::get({num_dims}, IntegerType::get(getContext(), 32)); |
| for (DenseIntElementsAttr shape : shapes) { |
| results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum)); |
| cumulative_sum[concat_dim] += shape.getValues<int32_t>()[concat_dim]; |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConstOp |
| //===----------------------------------------------------------------------===// |
| |
| void ConstOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "cst"); |
| } |
| |
| OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.empty() && "constant has no operands"); |
| |
| // Return the held attribute value. |
| return value(); |
| } |
| |
| // Builds a constant op with the specified attribute `value`. The result |
| // op's type is deduced from `value`; if `value` is of scalar type, it is |
| // wrapped in a tensor type with empty shape. |
| // TODO(jpienaar): This one differs from the autogenerated one as it takes an |
| // attribute but always creates an ElementsAttr internally. |
| void ConstOp::build(OpBuilder &builder, OperationState &result, |
| Attribute value) { |
| ShapedType type; |
| if (auto elem_attr = value.dyn_cast<ElementsAttr>()) { |
| return ConstOp::build(builder, result, elem_attr); |
| } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) { |
| // All TensorFlow types must be tensor types. In the build() method, |
| // we want to provide more flexibility by allowing attributes of scalar |
| // types, but they must be wrapped in an ElementsAttr to construct |
| // valid TensorFlow constants. |
| auto typed_attr = value.cast<TypedAttr>(); |
| type = RankedTensorType::get(/*shape=*/{}, typed_attr.getType()); |
| return ConstOp::build(builder, result, DenseElementsAttr::get(type, value)); |
| } |
| // TODO(jpienaar): support other TensorFlow specific types. |
| llvm_unreachable("unsupported attribute type for building tf.Const"); |
| } |
| |
| void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, |
| Attribute value) { |
| // Handle the case where the type and value are already tensors. |
| if (type.isa<TensorType>() && value.isa<ElementsAttr>()) { |
| result.addTypes(type); |
| result.addAttribute("value", value); |
| return; |
| } |
| |
| // Otherwise, default to the attribute builder. |
| ConstOp::build(builder, result, value); |
| assert(type == result.types[0] && "type mismatch in construction"); |
| } |
| |
| LogicalResult ConstOp::inferReturnTypes( |
| MLIRContext *context, Optional<Location> location, ValueRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<Type> &inferredReturnTypes) { |
| auto value = attributes.get("value"); |
| if (!value) return emitOptionalError(location, "missing attribute 'value'"); |
| if (auto elem_attr = value.dyn_cast<ElementsAttr>()) { |
| inferredReturnTypes.assign({elem_attr.getType()}); |
| return success(); |
| } |
| return emitOptionalError(location, |
| "attribute 'value' failed to satisfy constraint: " |
| "constant vector/tensor"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Conv2DOp and Conv3DOp |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult VerifyConvOpAttributes( |
| int num_dims, ArrayRef<Attribute> strides, ArrayRef<Attribute> dilations, |
| llvm::Optional<mlir::Location> location) { |
| int64_t strides_size = strides.size(); |
| if (strides_size != num_dims) |
| return emitOptionalError( |
| location, "requires strides attribute length to be ", num_dims); |
| auto is_not_positive = [](Attribute val) { |
| return val.cast<IntegerAttr>().getValue().getSExtValue() <= 0; |
| }; |
| if (llvm::any_of(strides, is_not_positive)) |
| return emitOptionalError(location, "requires positive strides"); |
| |
| int64_t dilations_size = dilations.size(); |
| if (dilations_size != num_dims) |
| return emitOptionalError( |
| location, "requires dilations attribute length to be ", num_dims); |
| if (llvm::any_of(dilations, is_not_positive)) |
| return emitOptionalError(location, "requires positive dilations"); |
| |
| return success(); |
| } |
| |
| // Verifies that, |
| // * Number of input channels is divisible by the number of filter input |
| // channels |
| template <typename OpT, typename std::enable_if<llvm::is_one_of< |
| OpT, Conv2DOp, Conv3DOp>::value>::type * = nullptr> |
| static LogicalResult Verify(OpT op) { |
| int num_spatial_dims = std::is_same<OpT, Conv2DOp>() ? 2 : 3; |
| int num_dims = 2 + num_spatial_dims; |
| |
| StringRef data_format = op.data_format(); |
| tensorflow::TensorFormat format; |
| auto data_format_is_valid = FormatFromString(data_format.str(), &format); |
| if (!data_format_is_valid) { |
| return emitOptionalError(op.getLoc(), "Invalid data format provided"); |
| } |
| |
| const StringRef paddings = op.padding(); |
| tensorflow::Padding padding; |
| auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding); |
| if (!padding_is_valid.ok()) { |
| return emitOptionalError(op.getLoc(), "Invalid padding format provided"); |
| } |
| |
| // Verifies that, |
| // * Ranks of operands and result are valid |
| // * Length of explicit_paddings attribute is valid and has non-negative |
| // elements |
| // * strides and dilations attributes have positive elements |
| if (!IsOfRankOrUnranked(op.input(), num_dims) || |
| !IsOfRankOrUnranked(op.filter(), num_dims)) |
| return emitOptionalError(op.getLoc(), "requires operands to be ", num_dims, |
| "D tensor"); |
| |
| if (padding == tensorflow::Padding::EXPLICIT) { |
| ArrayRef<Attribute> explicit_padding; |
| ArrayAttr explicit_pad = |
| op->getAttr("explicit_paddings") |
| .template dyn_cast_or_null<::mlir::ArrayAttr>(); |
| if (!explicit_pad) { |
| explicit_pad = ::mlir::Builder(op->getContext()).getI64ArrayAttr({}); |
| } |
| explicit_padding = explicit_pad.getValue(); |
| |
| if (explicit_padding.empty()) { |
| return emitOptionalError(op.getLoc(), |
| "requires attribute 'explicit_paddings' with " |
| "'EXPLICIT' padding mode"); |
| } |
| if (explicit_padding.size() != num_dims * 2) { |
| return emitOptionalError( |
| op.getLoc(), "requires explicit_paddings attribute length to be ", |
| num_dims * 2); |
| } |
| auto is_negative = [](Attribute val) { |
| return val.cast<IntegerAttr>().getValue().getSExtValue() < 0; |
| }; |
| if (llvm::any_of(explicit_padding, is_negative)) |
| return emitOptionalError(op.getLoc(), |
| "requires non-negative explicit paddings"); |
| } |
| |
| ArrayRef<Attribute> strides = op.strides().getValue(); |
| ArrayRef<Attribute> dilations = op.dilations().getValue(); |
| if (failed( |
| VerifyConvOpAttributes(num_dims, strides, dilations, op.getLoc()))) { |
| return failure(); |
| } |
| |
| int64_t input_channels = -1; |
| if (auto ty = op.input().getType().template dyn_cast<RankedTensorType>()) { |
| absl::string_view data_format(op.data_format().data(), |
| op.data_format().size()); |
| tensorflow::TensorFormat format; |
| auto is_valid = FormatFromString(data_format, &format); |
| DCHECK(is_valid) << data_format; |
| int idx = tensorflow::GetTensorFeatureDimIndex(num_dims, format); |
| input_channels = ty.getDimSize(idx); |
| } |
| |
| int64_t filter_channels = -1; |
| if (auto ty = op.filter().getType().template dyn_cast<RankedTensorType>()) { |
| int idx = tensorflow::GetFilterTensorInputChannelsDimIndex( |
| num_dims, tensorflow::FORMAT_HWIO); |
| filter_channels = ty.getDimSize(idx); |
| } |
| |
| if (input_channels != -1 && filter_channels != -1 && |
| input_channels % filter_channels != 0) |
| return op.emitOpError() |
| << "requires the number of input channels to be divisible by the " |
| "number of filter input channels; found " |
| << input_channels << " and " << filter_channels << ", respectively"; |
| |
| return success(); |
| } |
| |
| LogicalResult Conv2DOp::verify() { return Verify(*this); } |
| LogicalResult Conv3DOp::verify() { return Verify(*this); } |
| |
| LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) { |
| auto perm = GetDataFormatPermutation(this->data_format(), data_format); |
| if (perm.empty()) return failure(); |
| |
| // Update data_format attribute and result types. |
| if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); |
| |
| // Update convolution attributes. |
| (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); |
| (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm)); |
| (*this)->setAttr("explicit_paddings", |
| ShuffleArrayAttr(explicit_paddings(), perm, 2)); |
| |
| return success(); |
| } |
| |
| // Infers the return type components of the given convolution operation. |
| template <typename OpT, |
| typename std::enable_if<llvm::is_one_of< |
| OpT, Conv2DOpAdaptor, Conv3DOpAdaptor>::value>::type * = nullptr> |
| static LogicalResult inferConvReturnTypeComponents( |
| llvm::Optional<mlir::Location> location, OpT op, |
| ArrayRef<Attribute> explicit_padding, |
| llvm::SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
| const int64_t num_spatial_dims = std::is_same<OpT, Conv2DOpAdaptor>() ? 2 : 3; |
| const int64_t num_dims = 2 + num_spatial_dims; |
| const Value input = op.input(); |
| const Value filter = op.filter(); |
| const TensorType input_ty = input.getType().template cast<TensorType>(); |
| const TensorType filter_ty = filter.getType().template cast<TensorType>(); |
| |
| ArrayRef<Attribute> strides = op.strides().getValue(); |
| StringRef data_format = op.data_format(); |
| ArrayRef<Attribute> dilations = op.dilations().getValue(); |
| |
| tensorflow::TensorFormat format; |
| auto data_format_is_valid = FormatFromString(data_format.str(), &format); |
| assert(data_format_is_valid); |
| (void)data_format_is_valid; |
| |
| tensorflow::Padding padding; |
| const StringRef paddings = op.padding(); |
| auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding); |
| assert(padding_is_valid.ok()); |
| (void)padding_is_valid; |
| |
| auto get_int = [](Attribute attr) { |
| return attr.template cast<IntegerAttr>().getInt(); |
| }; |
| |
| // The output always has rank `num_dims`. All dimensions are initialized |
| // to dynamic size and can be partially inferred. |
| SmallVector<int64_t, 4> return_shape(num_dims, ShapedType::kDynamicSize); |
| // The output batch and channel dimensions can be obtained using utilities |
| // from tensorflow/core/util/tensor_format.h. |
| if (input_ty.hasRank()) { |
| return_shape[GetTensorBatchDimIndex(num_dims, format)] = |
| input_ty.getDimSize(GetTensorBatchDimIndex(num_dims, format)); |
| } |
| if (filter_ty.hasRank()) { |
| return_shape[GetTensorFeatureDimIndex(num_dims, format)] = |
| filter_ty.getDimSize(GetFilterTensorOutputChannelsDimIndex( |
| num_dims, tensorflow::FORMAT_HWIO)); |
| } |
| // Spatial dimensions can be inferred only when both input and filter are |
| // ranked because we need to get their spatial dimensions. |
| if (input_ty.hasRank() && filter_ty.hasRank()) { |
| // Checks the size of each of the output spatial dimensions. |
| for (auto i : llvm::seq<int>(0, num_spatial_dims)) { |
| const int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i); |
| int64_t stride = get_int(strides[dim]); |
| int64_t expected_output_size; |
| int64_t pad_low; |
| int64_t pad_high; |
| // Retrieve padding, if defined explicitly. |
| if (padding == tensorflow::Padding::EXPLICIT) { |
| pad_low = get_int(explicit_padding[2 * dim]); |
| pad_high = get_int(explicit_padding[2 * dim + 1]); |
| } |
| // Skip if input or filter size is dynamic. |
| if (input_ty.isDynamicDim(dim) || filter_ty.isDynamicDim(i)) continue; |
| // Calculate the expected_output_size. |
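| // As a reference (an assumption matching the TF shape functions), with |
| // effective_filter_size = (filter_size - 1) * dilation + 1: |
| //   VALID:    output = ceil((input - effective_filter_size + 1) / stride) |
| //   SAME:     output = ceil(input / stride) |
| //   EXPLICIT: output = ceil((input + pad_low + pad_high - |
| //                            effective_filter_size + 1) / stride) |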
| tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( |
| input_ty.getDimSize(dim), filter_ty.getDimSize(i), |
| get_int(dilations[dim]), stride, padding, &expected_output_size, |
| &pad_low, &pad_high); |
| // Return failure if expected_output_size could not be calculated. |
| if (!status.ok()) return failure(); |
| return_shape[dim] = expected_output_size; |
| } |
| } |
| |
| inferredReturnShapes.emplace_back(return_shape, input_ty.getElementType()); |
| return success(); |
| } |
| |
| LogicalResult Conv2DOp::inferReturnTypeComponents( |
| MLIRContext *context, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
| Conv2DOpAdaptor op(operands.getValues(), attributes); |
| ArrayRef<Attribute> explicit_padding; |
| ArrayAttr explicit_pad = |
| attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>(); |
| if (!explicit_pad) { |
| explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({}); |
| } |
| explicit_padding = explicit_pad.getValue(); |
| |
| return inferConvReturnTypeComponents(location, op, explicit_padding, |
| inferredReturnShapes); |
| } |
| |
| StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) { |
| // Keep the current data format if no GPUs are available or if explicit |
| // placement does not allow using the GPU for this operation. |
| if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) |
| return data_format(); |
| |
| // Input must be a tensor. |
| auto input_ty = input().getType().dyn_cast<TensorType>(); |
| if (!input_ty) return data_format(); |
| |
| // For the f16 data type, on devices with Tensor Core support the NHWC |
| // data format is up to ~2x faster. |
| const bool is_f16 = input_ty.getElementType().isF16(); |
| if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; |
| |
| // For f32/f16 data types the decision depends on the filter size in the |
| // spatial dimensions; for other data types we keep the current data format. |
| if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16()) |
| return data_format(); |
| |
| // Keep current data format if filter rank is unknown or not equal to 4. |
| auto filter_ty = filter().getType().dyn_cast<RankedTensorType>(); |
| if (!filter_ty || filter_ty.getRank() != 4) return data_format(); |
| |
| const int64_t d0 = filter_ty.getDimSize(0); |
| const int64_t d1 = filter_ty.getDimSize(1); |
| |
| auto all_ones = [](ArrayAttr arr) -> bool { |
| return llvm::all_of(arr, [](Attribute attr) -> bool { |
| return attr.cast<IntegerAttr>().getInt() == 1; |
| }); |
| }; |
| |
| // Convolutions with a 1x1 filter and all-ones strides and dilations can be |
| // computed as a GEMM in NHWC data format, which can be up to ~2x faster |
| // than convolution in NCHW. |
| const bool one_by_one = d0 == 1 && d1 == 1; |
| const bool trivial_strides = all_ones(strides()); |
| const bool trivial_dilations = all_ones(dilations()); |
| |
| // TODO(ezhulenev): This might lead to excessive transposes in the final IR |
| // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1. |
| // Also, FusedBatchNorm in training mode prefers NCHW data format. Check |
| // whether all users can efficiently use the NHWC data format. |
| if (one_by_one && trivial_strides && trivial_dilations) { |
| return "NHWC"; |
| } |
| |
| // If the filter's spatial dimensions are unknown or not 1x1, we prefer |
| // NCHW because it's the fastest option on NVIDIA GPUs with cuDNN library |
| // support. |
| return "NCHW"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Conv2dBackpropFilterOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) { |
| StringRef src_data_format = this->data_format(); |
| |
| auto perm = GetDataFormatPermutation(src_data_format, data_format); |
| if (perm.empty()) return failure(); |
| |
| // Update data_format attribute and result types. |
| if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); |
| |
| // Update convolution attributes. |
| (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); |
| (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm)); |
| (*this)->setAttr("explicit_paddings", |
| ShuffleArrayAttr(explicit_paddings(), perm, 2)); |
| |
| // Permute filter sizes operand. |
| OpBuilder builder(getOperation()); |
| auto filter_sizes_permuted = builder.create<TF::DataFormatVecPermuteOp>( |
| getLoc(), filter_sizes(), StringAttr::get(getContext(), src_data_format), |
| StringAttr::get(getContext(), data_format)); |
| setOperand(1, filter_sizes_permuted); |
| |
| return success(); |
| } |
| |
| StringRef Conv2DBackpropFilterOp::GetOptimalLayout( |
| const RuntimeDevices &devices) { |
| // Keep the current data format if no GPUs are available or if explicit |
| // placement does not allow using the GPU for this operation. |
| if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) |
| return data_format(); |
| |
| // Input must be a tensor. |
| auto input_ty = input().getType().dyn_cast<TensorType>(); |
| if (!input_ty) return data_format(); |
| |
| // For the f16 data type, on devices with Tensor Core support the NHWC |
| // data format is up to ~2x faster. |
| const bool is_f16 = input_ty.getElementType().isF16(); |
| if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; |
| |
| // Otherwise always use "NCHW". |
| return "NCHW"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Conv2DBackpropInputOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult Conv2DBackpropInputOp::verify() { |
| Conv2DBackpropInputOp op = *this; |
| int num_spatial_dims = 2; |
| int num_dims = 2 + num_spatial_dims; |
| |
| if (!IsOfRankOrUnranked(op.out_backprop(), num_dims) || |
| !IsOfRankOrUnranked(op.filter(), num_dims)) |
| return op.emitOpError() |
| << "requires operands to be " << num_dims << "D tensor"; |
| if (!IsOfRankOrUnranked(op.getResult(), num_dims)) |
| return op.emitOpError() |
| << "requires result to be " << num_dims << "D tensor"; |
| |
| llvm::Optional<mlir::Location> location = op.getLoc(); |
| ArrayRef<Attribute> strides = op.strides().getValue(); |
| ArrayRef<Attribute> dilations = op.dilations().getValue(); |
| LogicalResult verify_result = |
| VerifyConvOpAttributes(num_dims, strides, dilations, location); |
| if (failed(verify_result)) { |
| return verify_result; |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) { |
| StringRef src_data_format = this->data_format(); |
| |
| auto perm = GetDataFormatPermutation(src_data_format, data_format); |
| if (perm.empty()) return failure(); |
| |
| // Update data_format attribute and result types. |
| if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); |
| |
| // Update convolution attributes. |
| (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); |
| (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm)); |
| (*this)->setAttr("explicit_paddings", |
| ShuffleArrayAttr(explicit_paddings(), perm, 2)); |
| |
| // Permute input sizes operand. |
| OpBuilder builder(getOperation()); |
| auto input_sizes_permuted = builder.create<TF::DataFormatVecPermuteOp>( |
| getLoc(), input_sizes(), StringAttr::get(getContext(), src_data_format), |
| StringAttr::get(getContext(), data_format)); |
| setOperand(0, input_sizes_permuted); |
| |
| return success(); |
| } |
| |
| StringRef Conv2DBackpropInputOp::GetOptimalLayout( |
| const RuntimeDevices &devices) { |
| // Keep the current data format if no GPUs are available or if explicit |
| // placement does not allow using the GPU for this operation. |
| if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) |
| return data_format(); |
| |
| // Filter must be a tensor. |
| auto filter_ty = filter().getType().dyn_cast<TensorType>(); |
| if (!filter_ty) return data_format(); |
| |
| // For the f16 data type, on devices with Tensor Core support the NHWC |
| // data format is up to ~2x faster. |
| const bool is_f16 = filter_ty.getElementType().isF16(); |
| if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; |
| |
| // Otherwise always use "NCHW". |
| return "NCHW"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Conv3DOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult Conv3DOp::inferReturnTypeComponents( |
| MLIRContext *context, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
| Conv3DOpAdaptor op(operands.getValues(), attributes); |
| ArrayRef<Attribute> explicit_padding; |
| ArrayAttr explicit_pad = |
| attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>(); |
| if (!explicit_pad) { |
| explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({}); |
| } |
| explicit_padding = explicit_pad.getValue(); |
| |
| return inferConvReturnTypeComponents(location, op, explicit_padding, |
| inferredReturnShapes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DataFormatVecPermuteOp |
| //===----------------------------------------------------------------------===// |
| |
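| // As an illustration of the op's semantics (a sketch): permuting the 1-D |
| // vector [b, h, w, c] from "NHWC" to "NCHW" yields [b, c, h, w]; a 4x2 |
| // input permutes its rows the same way. The verifier below only checks the |
| // ranks and sizes for which such a permutation is meaningful. |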
| LogicalResult DataFormatVecPermuteOp::verify() { |
| DataFormatVecPermuteOp op = *this; |
| auto input_ty = op.x().getType().dyn_cast<RankedTensorType>(); |
| if (!input_ty) return success(); |
| |
| int rank = input_ty.getRank(); |
| if (rank != 1 && rank != 2) |
| return op.emitOpError("requires input of rank 1 or 2"); |
| |
| if (rank == 1) { |
| int64_t dim0 = input_ty.getDimSize(0); |
| if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2) |
| return op.emitOpError("requires 1D input of size 4 or size 2"); |
| } |
| |
| if (rank == 2) { |
| int64_t dim0 = input_ty.getDimSize(0); |
| if (dim0 != ShapedType::kDynamicSize && dim0 != 4) |
| return op.emitOpError( |
| "requires first dimension of 2D input to be of size 4"); |
| |
| int64_t dim1 = input_ty.getDimSize(1); |
| if (dim1 != ShapedType::kDynamicSize && dim1 != 2) |
| return op.emitOpError( |
| "requires second dimension of 2D input to be of size 2"); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DivNoNanOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| /// Canonicalization template for tf.DivNoNan and tf.MulNoNan: |
| /// If the op is tf.DivNoNan and the divisor is a constant tensor (with |
| /// elements of any allowed type: float or complex), rewrite the op to the |
| /// divisor if all elements of the divisor are zero, and to tf.Div if all |
| /// elements of the divisor are non-zero. |
| |
| /// Similarly, if the op is tf.MulNoNan and the multiplier is a constant |
| /// tensor, rewrite the op to the multiplier if all elements of the |
| /// multiplier are zero, and to tf.Mul if all elements of the multiplier |
| /// are non-zero. |
| |
| /// The matched op is replaced with an op of type `RetT`: when canonicalizing |
| /// tf.DivNoNan, tf.DivOp is passed as the second template argument, and when |
| /// canonicalizing tf.MulNoNan, tf.MulOp is passed as the second template |
| /// argument. |
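| /// |
| /// For example (a sketch in textual IR, types abbreviated): |
| ///   %y = "tf.Const"() {value = dense<0.0> : tensor<2xf32>} |
| ///   %0 = "tf.DivNoNan"(%x, %y) |
| /// is replaced by %y, while for a divisor with no zero elements the op is |
| /// replaced by "tf.Div"(%x, %y). |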
| template <typename OpT, typename RetT> |
| class DivNoNanOrMulNoNanConstantY : public OpRewritePattern<OpT> { |
| using OpRewritePattern<OpT>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(OpT op, |
| PatternRewriter &rewriter) const override { |
| static_assert( |
| llvm::is_one_of<OpT, DivNoNanOp, MulNoNanOp>::value, |
| "only canonicalization of tf.DivNoNan and tf.MulNoNan is supported"); |
| |
| // Returns true iff `val` (a complex constant with float real and imaginary |
| // parts) is zero. |
| auto complexIsZero = [](const std::complex<APFloat> val) { |
| // A complex value is zero iff both its real and imaginary parts are zero. |
| return val.real().isZero() && val.imag().isZero(); |
| }; |
| |
| // Returns true iff `attr` contains both zero and non-zero (float or |
| // complex) elements. |
| auto hasBothZeroAndNonzeroElements = |
| [&complexIsZero](ElementsAttr attr, bool hasComplexElements) { |
| bool foundZero = false, foundNonzero = false; |
| if (!hasComplexElements) { |
| for (const auto val : attr.getValues<APFloat>()) { |
| if (val.isZero()) |
| foundZero = true; |
| else |
| foundNonzero = true; |
| if (foundZero && foundNonzero) return true; |
| } |
| } else { |
| for (const auto val : attr.getValues<std::complex<APFloat>>()) { |
| if (complexIsZero(val)) |
| foundZero = true; |
| else |
| foundNonzero = true; |
| if (foundZero && foundNonzero) return true; |
| } |
| } |
| return false; |
| }; |
| |
| // Note that `y` is the divisor if the op is tf.DivNoNan and it is the |
| // multiplier if the op is tf.MulNoNan. |
| Value y = op.y(); |
| // The condition below is true iff `y` is defined by a tf.Const op, in |
| // which case `yDefOp` stores that op. Note that if `y` is a block |
| // argument, `y.getDefiningOp()` returns null, which dyn_cast_or_null |
| // propagates to `yDefOp`. Further, if `y` is defined by an op other than |
| // tf.Const, `y.getDefiningOp()` is non-null but dyn_cast_or_null still |
| // yields null. |
| if (auto yDefOp = dyn_cast_or_null<TF::ConstOp>(y.getDefiningOp())) { |
| Type typeOfElementsInY = getElementTypeOrSelf(y.getType()); |
| ElementsAttr attr = yDefOp.value(); |
| bool yHasComplexElements = typeOfElementsInY.isa<ComplexType>(); |
| |
| // If `y` is a splat constant, then the op will definitely get replaced. |
| // We check for a splat constant first to keep this canonicalization fast, |
| // since that check is O(1). |
| if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) { |
| bool splatAttrIsZero = false; |
| if (!yHasComplexElements) { |
| if (splatAttr.getSplatValue<APFloat>().isZero()) |
| splatAttrIsZero = true; |
| } else { |
| if (complexIsZero(splatAttr.getSplatValue<std::complex<APFloat>>())) |
| splatAttrIsZero = true; |
| } |
| if (splatAttrIsZero) { |
| // When `y` is a zero splat constant (i.e., all the elements in `y` |
| // are zero), replace the op (tf.DivNoNan or tf.MulNoNan) with `y`. |
| rewriter.replaceOp(op, y); |
| } else { |
| // When `y` is a non-zero splat constant, replace tf.DivNoNan with |
| // tf.Div and tf.MulNoNan with tf.Mul. |
| rewriter.replaceOpWithNewOp<RetT>(op, op->getResult(0).getType(), |
| op->getOperand(0), |
| op->getOperand(1)); |
| } |
| return success(); |
| } |
| |
| // If `y` has both zero and non-zero elements, do nothing. |
| if (hasBothZeroAndNonzeroElements(attr, yHasComplexElements)) { |
| return failure(); |
| } else { |
| // When `y` is a non-splat constant with no zero elements, replace |
| // tf.DivNoNan with tf.Div and tf.MulNoNan with tf.Mul. |
| rewriter.replaceOpWithNewOp<RetT>(op, op->getResult(0).getType(), |
| op->getOperand(0), op->getOperand(1)); |
| return success(); |
| } |
| } |
| return failure(); |
| } |
| }; |
| } // namespace |
| |
| void DivNoNanOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<DivNoNanOrMulNoNanConstantY<TF::DivNoNanOp, TF::DivOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DivOp |
| //===----------------------------------------------------------------------===// |
| |
| void DivOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<DivWithSqrtDivisor>(context); |
| } |
| |
| OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { |
| return IdentityArithmeticOpFolder<DivOp>(*this, operands); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DynamicStitchOp |
| //===----------------------------------------------------------------------===// |
| |
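| // For example (a sketch): with constant indices [0] and [1, 2] and data of |
| // shapes [1, 2] and [2, 2], the inferred item shape is [2], max_index is 2, |
| // and the output type must be compatible with tensor<3x2xf32>. |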
| LogicalResult DynamicStitchOp::verify() { |
| DynamicStitchOp op = *this; |
| if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1"); |
| |
| if (RankedTensorType out_ty = op.getType().dyn_cast<RankedTensorType>()) { |
| if (out_ty.getRank() == 0) { |
| return op.emitOpError("requires non scalar output"); |
| } |
| } |
| |
| llvm::SmallDenseSet<int64_t, 8> index_values; |
| bool all_indices_const = true; |
| int32_t max_index = -1; |
| llvm::Optional<SmallVector<int64_t, 4>> inferred_item_shape; |
| for (auto it : llvm::zip(op.indices(), op.data())) { |
| Value index = std::get<0>(it); |
| |
| DenseIntElementsAttr index_attr; |
| if (matchPattern(index, m_Constant(&index_attr))) { |
| for (int32_t index : index_attr.getValues<int32_t>()) { |
| if (index < 0) |
| return op.emitOpError() |
| << "requires non-negative index values; found " << index; |
| max_index = std::max(index, max_index); |
| index_values.insert(index); |
| } |
| } else { |
| all_indices_const = false; |
| } |
| |
| Value data = std::get<1>(it); |
| RankedTensorType index_ty = index.getType().dyn_cast<RankedTensorType>(); |
| RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>(); |
| if (!index_ty || !data_ty) continue; |
| |
| int64_t index_rank = index_ty.getRank(); |
| ArrayRef<int64_t> data_shape = data_ty.getShape(); |
| ArrayRef<int64_t> index_shape = index_ty.getShape(); |
| if (failed(mlir::verifyCompatibleShape(index_shape, |
| data_shape.take_front(index_rank)))) |
| return op.emitOpError() << "requires shape of data with type " << data_ty |
| << " to have prefix matching with shape of the " |
| "corresponding index type " |
| << index_ty; |
| |
| ArrayRef<int64_t> item_shape = data_shape.drop_front(index_rank); |
| if (!inferred_item_shape) { |
| inferred_item_shape = llvm::to_vector<4>(item_shape); |
| continue; |
| } |
| |
| if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape))) |
| return op.emitOpError() << "has inconsistent shaped data and index " |
| "pairs; inferred item shapes [" |
| << llvm::makeArrayRef(*inferred_item_shape) |
| << "] and [" << item_shape << "] don't match"; |
| for (int i = 0, e = item_shape.size(); i < e; ++i) { |
| int64_t &inferred_dim = (*inferred_item_shape)[i]; |
| int64_t dim = item_shape[i]; |
| if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim; |
| } |
| } |
| |
| // If all indices are constants, then verify that they cover all indices in |
| // the range [0, max_index] and the output type is legal. |
| if (all_indices_const) { |
| for (int32_t i = 0; i <= max_index; i++) { |
| if (!index_values.count(i)) |
| return op.emitOpError() << "missing index " << i; |
| } |
| |
| if (inferred_item_shape) { |
| SmallVector<int64_t, 4> expected_shape; |
| expected_shape.push_back(max_index + 1); |
| expected_shape.append(inferred_item_shape->begin(), |
| inferred_item_shape->end()); |
| |
| auto out_ty = op.getType().cast<TensorType>(); |
| auto expected_out_ty = |
| RankedTensorType::get(expected_shape, out_ty.getElementType()); |
| |
| if (!AreCastCompatible({out_ty, expected_out_ty})) { |
| return op.emitOpError() << "has invalid output type; should be " |
| "compatible with inferred type " |
| << expected_out_ty; |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // EinsumOp |
| //===----------------------------------------------------------------------===// |
| |
| // Verifies that, |
| // * Arity of the op is at most two. |
| // |
| // TODO(hinsu): Verify einsum equation attribute. |
| LogicalResult EinsumOp::verify() { |
| EinsumOp op = *this; |
| if (op.N() > 2) { |
| return op.emitOpError("supports at most two operands"); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // EmptyOp |
| //===----------------------------------------------------------------------===// |
| |
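| // Folds tf.Empty with a constant shape operand into a zero-initialized |
| // constant. For example (a sketch): a shape operand of dense<[2, 3]> with |
| // an f32 result folds to dense<0.0> : tensor<2x3xf32>. |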
| OpFoldResult EmptyOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.size() == 1 && "empty op has one operand"); |
| |
| Attribute attr = operands.front(); |
| if (!attr) return {}; |
| |
| auto int_attr = attr.cast<DenseIntElementsAttr>(); |
| SmallVector<int64_t, 6> out_shape; |
| for (const auto val : int_attr.getValues<int32_t>()) { |
| out_shape.push_back(val); |
| } |
| |
| auto type = getResult().getType().cast<ShapedType>(); |
| auto etype = type.getElementType(); |
| |
| // We can not fold if the result is not static. |
| if (!type.hasStaticShape()) return {}; |
| |
| if (auto float_type = etype.dyn_cast<FloatType>()) { |
| auto out_type = RankedTensorType::get(out_shape, float_type); |
| return DenseElementsAttr::get(out_type, |
| {APFloat(float_type.getFloatSemantics())}); |
| } |
| |
| if (auto int_type = etype.dyn_cast<IntegerType>()) { |
| auto out_type = RankedTensorType::get(out_shape, etype); |
| APInt val(int_type.getWidth(), 0, int_type.getSignedness()); |
| return DenseElementsAttr::get(out_type, val); |
| } |
| |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // EmptyTensorListOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult EmptyTensorListOp::verify() { |
| EmptyTensorListOp op = *this; |
| // This is required to populate derived attributes during export in a |
| // meaningful way. Otherwise, the element_type() query during export to |
| // GraphDef will result in an out-of-bounds access/assert. |
| if (handle_dtype().getSubtypes().size() != 1) { |
| return emitOpError( |
| "must have exactly one subtype in the result variant type"); |
| } |
| |
| if (!IsOfRankOrUnranked(op.element_shape(), 0) && |
| !IsOfRankOrUnranked(op.element_shape(), 1)) { |
| return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); |
| } |
| |
| if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) { |
| return op.emitOpError("requires max_num_elements operand to be 0D tensor"); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // EnqueueTPUEmbedding ops |
| //===----------------------------------------------------------------------===// |
| |
| // For EnqueueTPUEmbedding ops the device ordinal corresponds to the resource |
| // instance. |
| |
| std::string |
| EnqueueTPUEmbeddingArbitraryTensorBatchOp::GetResourceInstanceStr() { |
| return std::to_string(device_ordinal()); |
| } |
| |
| std::string EnqueueTPUEmbeddingBatchOp::GetResourceInstanceStr() { |
| return std::to_string(device_ordinal()); |
| } |
| |
| std::string EnqueueTPUEmbeddingIntegerBatchOp::GetResourceInstanceStr() { |
| return std::to_string(device_ordinal()); |
| } |
| |
| std::string EnqueueTPUEmbeddingRaggedTensorBatchOp::GetResourceInstanceStr() { |
| return std::to_string(device_ordinal()); |
| } |
| |
| std::string EnqueueTPUEmbeddingSparseBatchOp::GetResourceInstanceStr() { |
| return std::to_string(device_ordinal()); |
| } |
| |
| std::string EnqueueTPUEmbeddingSparseTensorBatchOp::GetResourceInstanceStr() { |
| return std::to_string(device_ordinal()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // EnsureShapeOp |
| //===----------------------------------------------------------------------===// |
| |
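| // For example (a sketch): an input of type tensor<2x3xf32> with shape |
| // attribute [2, -1] (i.e. [2, ?]) always satisfies the constraint, so the |
| // op folds to its input; with shape attribute [4, 3] the op is kept so the |
| // mismatch can fail at runtime. |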
| OpFoldResult EnsureShapeOp::fold(llvm::ArrayRef<mlir::Attribute>) { |
| ShapedType type = input().getType().dyn_cast<ShapedType>(); |
| if (!type || !type.hasRank()) return {}; |
| // If the shape attribute equals the input type's shape, fold to the input. |
| llvm::Optional<llvm::ArrayRef<int64_t>> shape_constraint = shape(); |
| if (type.getShape() == shape_constraint) return input(); |
| |
| // If the input type's shape always satisfies the shape attribute, fold to |
| // the input. |
| if (shape_constraint.has_value() && |
| shape_constraint->size() == type.getShape().size()) { |
| for (int i = 0; i < shape_constraint->size(); ++i) { |
| if (!ShapedType::isDynamic(shape_constraint.getValue()[i]) && |
| type.getDimSize(i) != shape_constraint.getValue()[i]) { |
| return {}; |
| } |
| } |
| return input(); |
| } |
| // Otherwise keep the op so the constraint can fail at runtime. |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // EqualOp/NotEqualOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult EqualOp::verify() { |
| EqualOp op = *this; |
| // If we allow inputs to have incompatible type, then nothing to do. |
| if (!op.incompatible_shape_error()) return success(); |
| |
| // Otherwise, check inputs are broadcastable. |
| return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( |
| op.getOperation()); |
| } |
| |
| void EqualOp::build(OpBuilder &builder, OperationState &result, Value x, |
| Value y, BoolAttr incompatible_shape_error) { |
| auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, |
| incompatible_shape_error); |
| return build(builder, result, result_type, x, y, incompatible_shape_error); |
| } |
| |
| namespace { |
| |
| // Flips the incompatible_shape_error attribute to true if the shapes are known |
| // to be compatible. |
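| // |
| // For example (a sketch): |
| //   %0 = "tf.Equal"(%x, %y) {incompatible_shape_error = false} |
| // with statically broadcast-compatible ranked operands and a ranked result |
| // is rewritten to the same op with incompatible_shape_error = true. |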
| template <typename Ty> |
| static LogicalResult flipCompatibleShapeError(Ty op, PatternRewriter &rewriter) { |
| if (op.incompatible_shape_error()) { |
| return rewriter.notifyMatchFailure(op, "the attribute is already true"); |
| } |
| |
| // incompatible_shape_error=false implies that the op will either return a |
| // valid result or a scalar boolean indicating the error. For unranked outputs |
| // we don't know which one it is. TF shape inference turns unranked outputs |
| // into ranked ones if it can statically evaluate the broadcast, see the shape |
| // function of tf.Equal. |
| auto ty = op.getType().template dyn_cast<RankedTensorType>(); |
| if (!ty) { |
| return rewriter.notifyMatchFailure(op, "requires a ranked output shape"); |
| } |
| |
| // Unless this is a scalar compare, a scalar output indicates that this will |
| // always fail. |
| auto x_ty = op.x().getType().template dyn_cast<RankedTensorType>(); |
| auto y_ty = op.y().getType().template dyn_cast<RankedTensorType>(); |
| if (ty.getRank() == 0 && |
| (!x_ty || x_ty.getRank() != 0 || !y_ty || y_ty.getRank() != 0)) { |
| return rewriter.notifyMatchFailure(op, "output rank must match input rank"); |
| } |
| |
| // Shapes are known to be compatible. |
| rewriter.template replaceOpWithNewOp<Ty>(op, op.x(), op.y(), |
| rewriter.getBoolAttr(true)); |
| return success(); |
| } |
| } // namespace |
| |
| void EqualOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add(flipCompatibleShapeError<EqualOp>); |
| } |
| |
| void NotEqualOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add(flipCompatibleShapeError<NotEqualOp>); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ExpandDimsOp |
| //===----------------------------------------------------------------------===// |
| |
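| // Infers the result type of tf.ExpandDims. For example (a sketch): an input |
| // of type tensor<2x3xf32> with a constant dim of -1 produces |
| // tensor<2x3x1xf32>, while a non-constant dim yields an unranked result. |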
| Type InferExpandDimsOpType(Value input, Value dim) { |
| Type element_ty = input.getType().cast<TensorType>().getElementType(); |
| auto unranked_ty = UnrankedTensorType::get(element_ty); |
| |
| auto input_ty = input.getType().dyn_cast<RankedTensorType>(); |
| if (!input_ty) return unranked_ty; |
| |
| DenseIntElementsAttr dim_attr; |
| if (!matchPattern(dim, m_Constant(&dim_attr)) || |
| dim_attr.getNumElements() != 1) |
| return unranked_ty; |
| int64_t dim_val = (*dim_attr.begin()).getSExtValue(); |
| int64_t input_rank = input_ty.getRank(); |
| |
| if (dim_val < -input_rank - 1 || dim_val > input_rank) return unranked_ty; |
| if (dim_val < 0) dim_val += input_rank + 1; |
| |
| SmallVector<int64_t, 4> shape = llvm::to_vector<4>(input_ty.getShape()); |
| shape.insert(shape.begin() + dim_val, 1); |
| return RankedTensorType::get(shape, element_ty); |
| } |
| |
| void ExpandDimsOp::build(OpBuilder &builder, OperationState &result, |
| Value input, Value dim) { |
| return build(builder, result, InferExpandDimsOpType(input, dim), input, dim); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Expm1Op |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult Expm1Op::inferReturnTypeComponents( |
| MLIRContext *context, Optional<Location> location, ValueShapeRange operands, |
| DictionaryAttr attributes, RegionRange regions, |
| SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { |
| ShapeAdaptor adaptor = operands.getShape(0); |
| ShapedTypeComponents component(adaptor.getElementType()); |
| if (adaptor.hasRank()) adaptor.getDims(component); |
| inferredReturnShapes.push_back(component); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FakeQuantWithMinMaxArgsOp |
| //===----------------------------------------------------------------------===// |
| LogicalResult FakeQuantWithMinMaxArgsOp::verify() { |
| FakeQuantWithMinMaxArgsOp op = *this; |
| // TODO(fengliuai): move the following to a utility method. |
| const llvm::fltSemantics &semantics = op.min().getSemantics(); |
| float rmin, rmax; |
| if (&semantics == &APFloat::IEEEsingle()) { |
| rmin = op.min().convertToFloat(); |
| rmax = op.max().convertToFloat(); |
| } else { |
| rmin = op.min().convertToDouble(); |
| rmax = op.max().convertToDouble(); |
| } |
| // Range boundaries must be valid. |
| if (rmin >= rmax) { |
| return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) + |
| "," + Twine(std::to_string(rmax)) + "]"); |
| } |
| int64_t num_bits = op.num_bits(); |
| if (num_bits < 2 || num_bits > 16) { |
| return op.emitOpError( |
| "requires num_bits to be between 2 and 16, inclusive"); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FakeQuantWithMinMaxVarsOp |
| //===----------------------------------------------------------------------===// |
| LogicalResult FakeQuantWithMinMaxVarsOp::verify() { |
| FakeQuantWithMinMaxVarsOp op = *this; |
| auto min = GetRankedTensorTypeForOperand(op.min()); |
| if (min && !IsOfRankedFloatTensorType(min, 0)) |
| return op.emitOpError("requires min to be a 0d float tensor"); |
| |
| auto max = GetRankedTensorTypeForOperand(op.max()); |
| if (max && !IsOfRankedFloatTensorType(max, 0)) |
| return op.emitOpError("requires max to be a 0d float tensor"); |
| |
| int64_t num_bits = op.num_bits(); |
| if (num_bits < 2 || num_bits > 16) { |
| return op.emitOpError( |
| "requires num_bits to be between 2 and 16, inclusive"); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FakeQuantWithMinMaxVarsPerChannelOp |
| //===----------------------------------------------------------------------===// |
| LogicalResult FakeQuantWithMinMaxVarsPerChannelOp::verify() { |
| FakeQuantWithMinMaxVarsPerChannelOp op = *this; |
| auto min = GetRankedTensorTypeForOperand(op.min()); |
| if (min && !IsOfRankedFloatTensorType(min, 1)) |
| return op.emitOpError("requires min to be a 1d float tensor"); |
| |
| auto max = GetRankedTensorTypeForOperand(op.max()); |
| if (max && !IsOfRankedFloatTensorType(max, 1)) |
| return op.emitOpError("requires max to be a 1d float tensor"); |
| |
| Value inputs = op.inputs(); |
| if (!HasRankAtLeast(inputs, 1)) |
| return op.emitError("requires inputs to be at least 1d float tensor"); |
| |
| int64_t num_bits = op.num_bits(); |
| if (num_bits < 2 || num_bits > 16) { |
| return op.emitOpError( |
| "requires num_bits to be between 2 and 16, inclusive"); |
| } |
| |
| auto inputs_type = inputs.getType().dyn_cast<RankedTensorType>(); |
| if (!inputs_type) return success(); |
| int depth = inputs_type.getDimSize(inputs_type.getRank() - 1); |
| if ((min && min.getDimSize(0) != depth) || |
| (max && max.getDimSize(0) != depth)) { |
| return op.emitOpError( |
| "requires min and max to have same size as last dimension of inputs"); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FillOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult FillOp::verify() { |
| FillOp op = *this; |
| if (!IsOfRankOrUnranked(op.dims(), 1)) |
| return op.emitOpError() << "requires dims to be a 1D tensor"; |
| if (!IsOfRankOrUnranked(op.value(), 0)) |
| return op.emitOpError() << "requires value to be a scalar"; |
| |
| return success(); |
| } |
| |
| static ShapedType InferFillOpType(Value dims, Value value) { |
| Type etype = value.getType().cast<ShapedType>().getElementType(); |
| |
| DenseIntElementsAttr dims_attr; |
| if (matchPattern(dims, m_Constant(&dims_attr))) { |
| llvm::SmallVector<int64_t, 4> shape; |
| shape.reserve(dims_attr.getNumElements()); |
| for (const APInt dim : dims_attr.getValues<APInt>()) { |
| shape.push_back(dim.getSExtValue()); |
| } |
| return RankedTensorType::get(shape, etype); |
| } |
| |
| if (auto shape_op = dims.getDefiningOp<ShapeOp>()) { |
| if (auto t = shape_op.input().getType().dyn_cast<ShapedType>()) { |
| return t; |
| } |
| } |
| |
| return UnrankedTensorType::get(etype); |
| } |
| |
| void FillOp::build(OpBuilder &builder, OperationState &result, Value dims, |
| Value value) { |
| FillOp::build(builder, result, InferFillOpType(dims, value), dims, value); |
| } |
| |
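| // Folds tf.Fill with constant operands into a splat constant. For example |
| // (a sketch): dims = dense<[2, 3]> and value = dense<4.0> : tensor<f32> |
| // fold to dense<4.0> : tensor<2x3xf32>. |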
| OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.size() == 2 && "fill op has two operands"); |
| |
| auto type = getType().cast<ShapedType>(); |
| // DenseElementsAttr that is used in this folder only supports int and float |
| // types. |
| // TODO(hinsu): Handle complex types once there is an attribute kind for |
| // complex. |
| if (!type.getElementType().isIntOrFloat()) return {}; |
| |
| auto value = operands[1].dyn_cast_or_null<ElementsAttr>(); |
| if (!value) return {}; |
| |
| if (type.hasStaticShape()) |
| return DenseElementsAttr::get(type, value.getValues<Attribute>()[0]); |
| |
| auto dims = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); |
| if (!dims) return {}; |
| |
| llvm::SmallVector<int64_t, 4> shape; |
| shape.reserve(dims.getNumElements()); |
| for (const APInt dim : dims.getValues<APInt>()) { |
| shape.push_back(dim.getSExtValue()); |
| } |
| type = RankedTensorType::get(shape, type.getElementType()); |
| |
| return DenseElementsAttr::get(type, value.getValues<Attribute>()[0]); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FusedBatchNormGradOp |
| //===----------------------------------------------------------------------===// |
| |
| // TODO(b/150954845): Add benchmarks to verify that layout preference didn't |
| // change in the latest GPU generations. |
| |
| LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) { |
| return ::mlir::TF::UpdateDataFormat(data_format, this); |
| } |
| |
| StringRef FusedBatchNormGradV3Op::GetOptimalLayout( |
| const RuntimeDevices &devices) { |
| // Keep the current data format if no GPUs are available or if explicit |
| // placement does not allow using the GPU for this operation. |
| if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) |
| return data_format(); |
| |
| // For the f16 data type, on devices with Tensor Core support the NHWC |
| // data format is up to ~2x faster. |
| auto x_ty = x().getType().cast<TensorType>(); |
| const bool is_f16 = x_ty.getElementType().isF16(); |
| if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; |
| |
| // For all other data types prefer NCHW. |
| return "NCHW"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FusedBatchNormOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult FusedBatchNormOp::verify() { |
| FusedBatchNormOp op = *this; |
| auto x = GetRankedTensorTypeForOperand(op.x()); |
| if (x && !IsOfRankedFloatTensorType(x, 4)) |
| return op.emitOpError("requires x to be a 4D float tensor"); |
| |
| auto scale = GetRankedTensorTypeForOperand(op.scale()); |
| if (scale && !IsOfRankedFloatTensorType(scale, 1)) |
| return op.emitOpError("requires scale to be a 1D float tensor"); |
| |
| auto offset = GetRankedTensorTypeForOperand(op.offset()); |
| if (offset && !IsOfRankedFloatTensorType(offset, 1)) |
| return op.emitOpError("requires offset to be a 1D float tensor"); |
| |
| auto mean = GetRankedTensorTypeForOperand(op.mean()); |
| if (mean && !IsOfRankedFloatTensorType(mean, 1)) |
| return op.emitOpError("requires mean to be a 1D float tensor"); |
| |
| auto variance = GetRankedTensorTypeForOperand(op.variance()); |
| if (variance && !IsOfRankedFloatTensorType(variance, 1)) |
| return op.emitOpError("requires variance to be a 1D float tensor"); |
| |
| // TODO(antiagainst): check attributes |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FusedBatchNormV2Op / FusedBatchNormV3Op |
| //===----------------------------------------------------------------------===// |
| |
| template <class Op> |
| static LogicalResult InferenceFoldOperandsPermutation( |
| ArrayRef<int64_t> permutation, Op *op) { |
| // FusedBatchNorm in training mode is a layout-sensitive operation and |
| // should already have an optimal data format assigned. |
| if (op->is_training()) return failure(); |
| return ::mlir::TF::FoldOperandsPermutation(permutation, op); |
| } |
| |
| template <class Op> |
| static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) { |
| // In inference mode FusedBatchNorm is not sensitive to data layout. |
| if (!op->is_training()) return op->data_format(); |
| |
| // Keep the current data format if no GPUs are available or if explicit |
| // placement does not allow using the GPU for this operation. |
| if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation())) |
| return op->data_format(); |
| |
| // For the f16 data type, on devices with Tensor Core support the NHWC |
| // data format is up to ~2x faster. |
| auto x_ty = op->x().getType().template cast<TensorType>(); |
| const bool is_f16 = x_ty.getElementType().isF16(); |
| if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; |
| |
| // For all other data types prefer NCHW. |
| return "NCHW"; |
| } |
| |
| LogicalResult FusedBatchNormV2Op::FoldOperandsPermutation( |
| ArrayRef<int64_t> permutation) { |
| return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this); |
| } |
| |
| LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) { |
| return ::mlir::TF::UpdateDataFormat(data_format, this); |
| } |
| |
| StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) { |
| return ::mlir::TF::GetOptimalLayout(devices, this); |
| } |
| |
| LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation( |
| ArrayRef<int64_t> permutation) { |
| return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this); |
| } |
| |
| LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) { |
| return ::mlir::TF::UpdateDataFormat(data_format, this); |
| } |
| |
| StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) { |
| return ::mlir::TF::GetOptimalLayout(devices, this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // GatherV2Op |
| //===----------------------------------------------------------------------===// |
| |
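| // For example (a sketch): with indices of rank 2, batch_dims must lie in |
| // [-2, 2]; with params of rank 3 and a constant axis, axis must lie in |
| // [-3, 3) and, after normalization, satisfy axis >= batch_dims. |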
| LogicalResult GatherV2Op::verify() { |
| GatherV2Op op = *this; |
| int64_t batch_dims = op.batch_dims(); |
| if (auto ty = op.indices().getType().dyn_cast<RankedTensorType>()) { |
| int64_t rank = ty.getRank(); |
| if (batch_dims > rank || batch_dims < -rank) |
| return op.emitOpError() |
| << "batch_dims (" << batch_dims << ") must be in range [" << -rank |
| << ", " << rank + 1 << ")"; |
| if (batch_dims < 0) batch_dims += rank; |
| } |
| |
| if (!HasRankAtMost(op.axis(), 1)) |
| return op.emitOpError("requires axis to have rank at most 1"); |
| |
| DenseIntElementsAttr axis_attr; |
| if (matchPattern(op.axis(), m_Constant(&axis_attr))) { |
| int64_t axis = (*axis_attr.begin()).getSExtValue(); |
| if (auto ty = op.params().getType().dyn_cast<RankedTensorType>()) { |
| int64_t rank = ty.getRank(); |
| if (axis >= rank || axis < -rank) |
| return op.emitOpError() << "axis (" << axis << ") must be in range [" |
| << -rank << ", " << rank << ")"; |
| if (axis < 0) axis += rank; |
| } |
| |
| if (batch_dims >= 0 && axis >= 0 && axis < batch_dims) { |
| return op.emitOpError() << "requires axis (" << axis |
| << ") to be greater than or equal to batch_dims (" |
| << batch_dims << ")"; |
| } |
| } |
| return success(); |
| } |
| |
| void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<GatherToV2>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // IfOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult IfOp::verifySymbolUses(SymbolTableCollection &symbol_table) { |
| auto branch_name = [](unsigned index) -> std::string { |
| return index == 0 ? "'then_branch'" : "'else_branch'"; |
| }; |
| return VerifyCaseOrIfOpBranchFunctions( |
| symbol_table, *this, {then_branchAttr(), else_branchAttr()}, branch_name); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // IfOp canonicalization. |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> { |
| public: |
| explicit FoldConstantIfOp(MLIRContext *context) |
| : OpRewritePattern<TF::IfOp>(context) {} |
| LogicalResult matchAndRewrite(TF::IfOp op, |
| PatternRewriter &rewriter) const override; |
| |
| private: |
| template <typename T> |
| struct CallOpType { |
| using CallOp = T; |
| }; |
| }; |
| |
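| // Example (illustrative; @then_fn and @else_fn are hypothetical functions): |
| //   %r = "tf.If"(%cond, %x) {then_branch = @then_fn, else_branch = @else_fn, |
| //                            is_stateless = true} ... |
| // with %cond constant-folded to true is rewritten to: |
| //   %r = "tf.PartitionedCall"(%x) {f = @then_fn, config = "", |
| //                                  config_proto = "", executor_type = ""} ... |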
| LogicalResult FoldConstantIfOp::matchAndRewrite( |
| TF::IfOp op, PatternRewriter &rewriter) const { |
| // Extract the constant cond value. |
| DenseIntElementsAttr cond_attr; |
| if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure(); |
| |
| // Cond value must be a scalar. |
| if (cond_attr.getNumElements() != 1) return failure(); |
| |
| // Select a branch function. |
| bool cond = cond_attr.getSplatValue<BoolAttr>().getValue(); |
| FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr(); |
| |
| // Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp. |
| auto rewrite = [&](auto op_type) { |
| auto empty = rewriter.getStringAttr(""); |
| ReplaceTfOpWithNewOp<typename decltype(op_type)::CallOp>( |
| rewriter, op, op.getResultTypes(), op.input(), func, |
| /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); |
| }; |
| |
| if (op.is_stateless()) |
| rewrite(CallOpType<PartitionedCallOp>{}); |
| else |
| rewrite(CallOpType<StatefulPartitionedCallOp>{}); |
| |
| return success(); |
| } |
| } // anonymous namespace |
| |
| void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<FoldConstantIfOp, DropAttributes<IfOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // IfRegionOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult IfRegionOp::verifyRegions() { |
| IfRegionOp op = *this; |
| TypeRange then_types = |
| op.then_branch().front().getTerminator()->getOperandTypes(); |
| TypeRange else_types = |
| op.else_branch().front().getTerminator()->getOperandTypes(); |
| |
| TypeRangeWithDesc results{op.getResultTypes(), "result"}; |
| TypeRangeWithDesc then_results{then_types, "then result"}; |
| TypeRangeWithDesc else_results{else_types, "else result"}; |
| |
| if (failed(VerifyTypeRangesAreCompatible(op, then_results, results))) |
| return failure(); |
| if (failed(VerifyTypeRangesAreCompatible(op, else_results, results))) |
| return failure(); |
| return success(); |
| } |
| |
| namespace { |
| class FoldConstantIfRegionOp : public OpRewritePattern<TF::IfRegionOp> { |
| public: |
| explicit FoldConstantIfRegionOp(MLIRContext *context) |
| : OpRewritePattern<TF::IfRegionOp>(context) {} |
| LogicalResult matchAndRewrite(TF::IfRegionOp op, |
| PatternRewriter &rewriter) const override; |
| }; |
| |
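| // Example (illustrative): with %cond constant-folded to true, |
| //   %r = "tf.IfRegion"(%cond) ({... "tf.Yield"(%t)}, {...}) ... |
| // is replaced by the inlined body of the then region, with %r forwarded from |
| // %t (a tf.Cast is inserted if the yield type differs from the result type). |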
| LogicalResult FoldConstantIfRegionOp::matchAndRewrite( |
| TF::IfRegionOp op, PatternRewriter &rewriter) const { |
| // Extract the constant cond value. |
| DenseIntElementsAttr cond_attr; |
| if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure(); |
| |
| // IfRegion condition should always be a scalar. Select the region to fold to. |
| bool cond = cond_attr.getSplatValue<BoolAttr>().getValue(); |
| Region ®ion = cond ? op.then_branch() : op.else_branch(); |
| |
| // If the IfRegion is stateless but the region being inlined itself is not |
| // stateless, then inlining the region could cause a loss of information. |
| // However, it's probably better to fold the IfRegion than to keep the dead |
| // branch around. |
| |
| // Inline the region in place of the IfRegion op, and forward the yield |
| // inputs to the IfRegion op results. This is possible only if the yield |
| // types match the result types. |
| auto yield = cast<YieldOp>(region.front().getTerminator()); |
| auto updated_results = llvm::to_vector<4>(yield.getOperands()); |
| |
| // If the yield types do not match the IfRegion result types, add appropriate |
| // casts. |
| rewriter.setInsertionPoint(yield); |
| for (auto it : llvm::zip(op.getResultTypes(), updated_results)) { |
| auto &updated_result = std::get<1>(it); |
| Type result_type = std::get<0>(it); |
| if (result_type != updated_result.getType()) { |
| updated_result = |
| rewriter.create<TF::CastOp>(op.getLoc(), result_type, updated_result, |
| /*Truncate=*/rewriter.getBoolAttr(false)); |
| } |
| } |
| // Inline the region into the block containing the IfRegion. |
| rewriter.mergeBlockBefore(®ion.front(), op); |
| rewriter.eraseOp(yield); |
| rewriter.replaceOp(op, updated_results); |
| return success(); |
| } |
| } // anonymous namespace |
| |
| void IfRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<FoldConstantIfRegionOp, |
| CaseOrIfRegionEliminatePassThrough<TF::IfRegionOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // InvertPermutationOp |
| //===----------------------------------------------------------------------===// |
| |
| // Verifies that the input is 1D. |
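| // For example (illustrative), an input of type tensor<5xi32> is accepted, |
| // tensor<2x3xi32> is rejected, and unranked inputs pass conservatively. |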
| LogicalResult InvertPermutationOp::verify() { |
| InvertPermutationOp op = *this; |
| auto x_type = op.x().getType().cast<TensorType>(); |
| if (!x_type.hasRank()) return success(); |
| if (x_type.getShape().size() != 1) |
| return op.emitOpError() << "requires input x to be 1-dimensional"; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LeakyReluOp |
| //===----------------------------------------------------------------------===// |
| |
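| // Folding examples (illustrative): |
| //   leaky_relu(2.0,  alpha = 0.2) -> 2.0   (non-negative inputs unchanged) |
| //   leaky_relu(-2.0, alpha = 0.2) -> -0.4  (negative inputs scaled by alpha) |
| //   leaky_relu(x,    alpha = 1.0) -> x     (folds to the operand itself) |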
| OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) { |
| assert(operands.size() == 1 && "leaky relu has one operand"); |
| |
| // leaky_relu(x, alpha: 1) -> x |
| if (alpha().convertToFloat() == 1.0f) return getOperand(); |
| |
| auto calculate = [&](FloatAttr arg) { |
| APFloat val = arg.getValue(); |
| if (val.isNegative()) val = alpha() * val; |
| return FloatAttr::get(arg.getType(), val); |
| }; |
| |
| if (auto arg = operands[0].dyn_cast_or_null<FloatAttr>()) { |
| return calculate(arg); |
| } else if (auto arg = operands[0].dyn_cast_or_null<SplatElementsAttr>()) { |
| if (auto elementAttr = arg.getSplatValue<Attribute>().dyn_cast<FloatAttr>()) |
| return DenseElementsAttr::get(arg.getType(), calculate(elementAttr)); |
| } |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LogOp |
| //===----------------------------------------------------------------------===// |
| |
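| // Illustrative rewrites performed by these patterns (as defined in the |
| // canonicalization TableGen file): |
| //   Log(Softmax(x)) => LogSoftmax(x) |
| //   Log(x + 1)      => Log1p(x) |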
| void LogOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<LogOfSoftmax, LogToLog1p>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LogicalNotOp |
| //===----------------------------------------------------------------------===// |
| |
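| // Illustrative rewrites: the negation of a comparison is replaced by the |
| // complementary comparison, e.g.: |
| //   LogicalNot(Equal(x, y))   => NotEqual(x, y) |
| //   LogicalNot(Greater(x, y)) => LessEqual(x, y) |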
| void LogicalNotOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results |
| .add<LogicalNotOfEqual, LogicalNotOfNotEqual, LogicalNotOfGreater, |
| LogicalNotOfGreaterEqual, LogicalNotOfLess, LogicalNotOfLessEqual>( |
| context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MatrixBandPartOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult MatrixBandPartOp::verify() { |
| MatrixBandPartOp op = *this; |
| if (!HasRankAtLeast(op.input(), 2)) { |
| return op.emitOpError() |
| << "requires `input` to have rank of at least 2, but found " |
| << op.input().getType(); |
| } |
| if (!IsOfRankOrUnranked(op.num_lower(), 0)) { |
| return op.emitOpError() |
| << "requires `num_lower` to have 0 dimensions, but found " |
| << op.num_lower().getType(); |
| } |
| if (!IsOfRankOrUnranked(op.num_upper(), 0)) { |
| return op.emitOpError() |
| << "requires `num_upper` to have 0 dimensions, but found " |
| << op.num_upper().getType(); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MatrixDiag Ops |
| //===----------------------------------------------------------------------===// |
| |
| void MatrixDiagOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<MatrixDiagToV3>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MatrixSetDiagOp |
| //===----------------------------------------------------------------------===// |
| |
| void MatrixSetDiagOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<MatrixSetDiagToV3>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MatrixSetDiagV2Op |
| //===----------------------------------------------------------------------===// |
| |
| void MatrixSetDiagV2Op::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<MatrixSetDiagV2ToV3>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MaxOp |
| //===----------------------------------------------------------------------===// |
| |
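| // Builds a MaxOp whose result type is inferred from the reduction. For |
| // example (illustrative, with constant reduction indices), reducing a |
| // tensor<2x3xf32> over index 0 yields tensor<3xf32> with keep_dims = false |
| // and tensor<1x3xf32> with keep_dims = true. |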
| void MaxOp::build(OpBuilder &builder, OperationState &result, Value input, |
| Value reduction_indices, BoolAttr keep_dims) { |
| Type out_ty = InferReductionOpType(input, reduction_indices, keep_dims); |
| build(builder, result, out_ty, input, reduction_indices, keep_dims); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MaximumOp |
| //===----------------------------------------------------------------------===// |
| |
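| // Illustrative rewrite (per the MaximumOfZeroToRelu pattern): |
| //   Maximum(x, 0) => Relu(x) |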
| void MaximumOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<MaximumOfZeroToRelu>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MaxPoolOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult MaxPoolOp::FoldOperandsPermutation( |
| ArrayRef<int64_t> permutation) { |
| return ::mlir::TF::FoldOperandsPermutation( |
| permutation, this, {{"strides", strides()}, {"ksize", ksize()}}); |
| } |
| |
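| // Example (illustrative): updating a NHWC MaxPool with ksize = [1, 3, 3, 1] |
| // and strides = [1, 2, 2, 1] to NCHW also shuffles the attributes to |
| // ksize = [1, 1, 3, 3] and strides = [1, 1, 2, 2]. |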
| LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) { |
| StringRef src_data_format = data_format(); |
| |
| auto perm = GetDataFormatPermutation(src_data_format, new_data_format); |
| if (perm.empty()) return failure(); |
| |
| // Update data_format attribute and result types. |
| if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this))) |
| return failure(); |
| |
| stridesAttr(ShuffleArrayAttr(strides(), perm)); |
| explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2)); |
| ksizeAttr(ShuffleArrayAttr(ksize(), perm)); |
| |
| return success(); |
| } |
| |
| StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) { |
| // Keep the current data format if no GPUs are available or if explicit |
| // placement does not allow using a GPU for this operation. |
| if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) |
| return data_format(); |
| |
| // Otherwise, default to NCHW. |
| return "NCHW"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MaxPoolGradOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult MaxPoolGradOp::verify() { |
| MaxPoolGradOp op = *this; |
| if (!IsOfRankOrUnranked(op.orig_input(), 4)) { |
| return op.emitOpError() << "requires orig_input to be rank 4"; |
| } |
| if (!IsOfRankOrUnranked(op.orig_output(), 4)) { |
| return op.emitOpError() << "requires orig_output to be rank 4"; |
| } |
| if (!IsOfRankOrUnranked(op.grad(), 4)) { |
| return op.emitOpError() << "requires grad to be rank 4"; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MeanOp |
| //===----------------------------------------------------------------------===// |
| |
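| // Folds an operand permutation into the reduction indices: each old index |
| // `idx` is replaced by `permutation[idx]`. For example (illustrative), |
| // indices [1, 2] under permutation [0, 2, 3, 1] become [2, 3]. |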
| LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) { |
| // Reduction indices must be defined by a constant operation. |
| auto reduction_op = |
| dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp()); |
| if (!reduction_op) return failure(); |
| |
| auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>(); |
| if (!reductions_value) return failure(); |
| |
| // Prepare new reduction indices according to operand permutation. |
| SmallVector<int32_t, 4> shuffled_reduction; |
| llvm::transform(reductions_value.getValues<APInt>(), |
| std::back_inserter(shuffled_reduction), |
| [&](APInt idx) { return permutation[idx.getSExtValue()]; }); |
| |
| // Add a constant operation with the new reduction indices. |
| OpBuilder builder(getOperation()); |
| auto type = mlir::RankedTensorType::get(shuffled_reduction.size(), |
| builder.getIntegerType(32)); |
| auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction); |
| auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values); |
| |
| // Use new reduction indices. |
| setOperand(1, shuffled_reduction_op); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MulNoNanOp |
| //===----------------------------------------------------------------------===// |
| |
| void MulNoNanOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<DivNoNanOrMulNoNanConstantY<TF::MulNoNanOp, TF::MulOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MulOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { |
| return IdentityArithmeticOpFolder<MulOp>(*this, operands); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // HashTableOp |
| //===----------------------------------------------------------------------===// |
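| // Illustrative rewrites: a "tf.HashTable" whose handle feeds |
| // "tf.InitializeTable", "tf.LookupTableSize", or "tf.LookupTableFind" is |
| // upgraded to "tf.HashTableV2" together with the corresponding V2 table op. |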
| void HashTableOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<HashTableAndInitializeTableToV2>(context); |
| results.add<HashTableAndLookupTableSizeToV2>(context); |
| results.add<HashTableAndLookupTableFindToV2>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BitcastOp |
| //===----------------------------------------------------------------------===// |
| |
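| // Shape examples (illustrative): |
| //   tensor<3x4xi8> -> tensor<3xi32>  : valid; four i8 values pack into one |
| //                                      i32, so the input needs a trailing |
| //                                      dimension of size 4 |
| //   tensor<3xi32>  -> tensor<3x4xi8> : valid; each i32 unpacks into four i8 |
| //                                      values |
| //   tensor<3xi32>  -> tensor<3xf32>  : valid; equal bitwidths require equal |
| //                                      shapes |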
| LogicalResult BitcastOp::verify() { |
| BitcastOp op = *this; |
| auto input_type = op.input().getType().cast<ShapedType>(); |
| auto output_type = op.output().getType().cast<ShapedType>(); |
| auto input_element_type = input_type.getElementType(); |
| auto output_element_type = output_type.getElementType(); |
| |
| // We currently only handle float and int element types in the verifier. |
| // TODO(hanxiongwang): Handle more element types besides int and float in the |
| // verifier. |
| if (input_type.hasStaticShape() && output_type.hasStaticShape() && |
| input_element_type.isIntOrFloat() && output_element_type.isIntOrFloat()) { |
| const auto input_element_type_bitwidth = |
| input_element_type.getIntOrFloatBitWidth(); |
| const auto output_element_type_bitwidth = |
| output_element_type.getIntOrFloatBitWidth(); |
| |
| auto is_output_shape_valid_with_small_input_element_type_bitwidth = [&]() { |
| if (output_element_type_bitwidth % input_element_type_bitwidth != 0) { |
| op.emitOpError() << "output element bitwidth is not multiple " |
| << "of input element bitwidth"; |
| return failure(); |
| } |
| if (input_type.getShape().size() != output_type.getShape().size() + 1) { |
| op.emitOpError() << "rank of input tensor is " |
| << input_type.getShape().size() |
| << ". rank of output tensor is expected to be " |
| << input_type.getShape().size() - 1 << ", instead of " |
| << output_type.getShape().size() << "."; |
| return failure(); |
| } |
| const auto rightmost_dim_size_divisor = |
| output_element_type_bitwidth / input_element_type_bitwidth; |
| if (input_type.getShape().empty() || |
| input_type.getShape().back() != rightmost_dim_size_divisor) { |
| op.emitOpError() |
| << "input rightmost dimension size is not equal to the divisor. " |
| << "the last dimension of input is expected to be " |
| << rightmost_dim_size_divisor; |
| return failure(); |
| } |
| for (size_t idx = 0; idx < output_type.getShape().size(); idx++) { |
| if (input_type.getShape()[idx] != output_type.getShape()[idx]) { |
| op.emitOpError() |
| << "the " << idx << "th dim of output tensor is " |
| << output_type.getShape()[idx] |
| << ". It is not equal to the one in input tensor, which is " |
| << input_type.getShape()[idx]; |
| return failure(); |
| } |
| } |
| return success(); |
| }; |
| |
| auto is_output_shape_valid_with_small_output_element_type_bitwidth = [&]() { |
| if (input_element_type_bitwidth % output_element_type_bitwidth != 0) { |
| op.emitOpError() << "input element bitwidth is not multiple " |
| << "of output element bitwidth"; |
| return failure(); |
| } |
| if (input_type.getShape().size() + 1 != output_type.getShape().size()) { |
| op.emitOpError() << "rank of input tensor is " |
| << input_type.getShape().size() |
| << ". rank of output tensor is expected to be " |
| << input_type.getShape().size() + 1 << ", instead of " |
| << output_type.getShape().size() << "."; |
| return failure(); |
| } |
| const auto rightmost_dim_size_divisor = |
| input_element_type_bitwidth / output_element_type_bitwidth; |
| if (output_type.getShape().back() != rightmost_dim_size_divisor) { |
| op.emitOpError() |
| << "output rightmost dimension size is not equal to the divisor. " |
| << "the last dimension of output is expected to be " |
| << rightmost_dim_size_divisor; |
| return failure(); |
| } |
| for (size_t idx = 0; idx < input_type.getShape().size(); idx++) { |
| if (input_type.getShape()[idx] != output_type.getShape()[idx]) { |
| op.emitOpError() |
| << "the " << idx << "th dim of output tensor is " |
| << output_type.getShape()[idx] |
| << ". It is not equal to the one in input tensor, which is " |
| << input_type.getShape()[idx]; |
| return failure(); |
| } |
| } |
| return success(); |
| }; |
| |
| auto is_output_shape_valid_with_equal_bitwidth = [&]() { |
| if (input_type.getShape().equals(output_type.getShape())) { |
| return success(); |
| } |
| op.emitOpError() |
| << "output tensor shape shall be equal to input tensor shape"; |
| return failure(); |
| }; |
| |
| if (input_element_type_bitwidth < output_element_type_bitwidth) { |
| return is_output_shape_valid_with_small_input_element_type_bitwidth(); |
| } else if (input_element_type_bitwidth > output_element_type_bitwidth) { |
| return is_output_shape_valid_with_small_output_element_type_bitwidth(); |
| } else { |
| return is_output_shape_valid_with_equal_bitwidth(); |
| } |
| } |
| return success(); |
| } |
| |
| } // namespace TF |
| } // namespace mlir |
| |
| //===----------------------------------------------------------------------===// |
| // TableGen'd op method definitions |
| //===----------------------------------------------------------------------===// |
| |
| #define GET_OP_CLASSES |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc.inc" |