/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
namespace mlir {
namespace TF {
namespace {
// Returns the equivalent Value skipping through identity nodes.
Value LookThroughIdentity(Value result) {
while (isa_and_nonnull<IdentityOp, IdentityNOp>(result.getDefiningOp())) {
auto op_result = result.cast<OpResult>();
result = op_result.getOwner()->getOperand(op_result.getResultNumber());
}
return result;
}
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
} // namespace
//===----------------------------------------------------------------------===//
// NotEqualOp
//===----------------------------------------------------------------------===//
LogicalResult NotEqualOp::verify() {
NotEqualOp op = *this;
  // If inputs are allowed to have incompatible types, there is nothing to do.
if (!op.incompatible_shape_error()) return success();
// Otherwise, check inputs are broadcastable.
return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
op.getOperation());
}
void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x,
Value y, BoolAttr incompatible_shape_error) {
auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
incompatible_shape_error);
return build(builder, result, result_type, x, y, incompatible_shape_error);
}
//===----------------------------------------------------------------------===//
// OneHotOp
//===----------------------------------------------------------------------===//
LogicalResult OneHotOp::verify() {
OneHotOp op = *this;
int64_t axis = op.axis();
auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
if (indices_ty &&
!(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
return op.emitOpError()
<< "expected axis (" << axis << ") to be -1 or between [0, "
<< indices_ty.getShape().size() << "]";
}
if (axis < -1) {
return op.emitOpError() << "expected axis (" << axis
<< ") to be -1 or between [0, rank(indices()))";
}
if (!IsOfRankOrUnranked(op.depth(), 0)) {
return op.emitOpError() << "requires depth to be a scalar";
}
if (!IsOfRankOrUnranked(op.on_value(), 0)) {
return op.emitOpError() << "requires on_value to be a scalar";
}
if (!IsOfRankOrUnranked(op.off_value(), 0)) {
return op.emitOpError() << "requires off_value to be a scalar";
}
DenseIntElementsAttr depth_attr;
if (matchPattern(op.depth(), m_Constant(&depth_attr))) {
if (depth_attr.getType().getRank() != 0)
return op.emitOpError() << "requires depth to be a scalar";
int64_t depth = depth_attr.getValues<APInt>()[0].getSExtValue();
if (depth < 0) {
return op.emitOpError() << "depth must be non-negative, got: " << depth;
}
}
return success();
}
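// Illustrative sketch of the type inference performed by InferOneHotOpType
// below (example values only): with indices of type tensor<2x3xi32>, f32
// on/off values, axis = -1 and a constant depth of 5, the inferred result
// type is tensor<2x3x5xf32>; if depth is not a constant, the inserted
// dimension is dynamic, i.e. tensor<2x3x?xf32>.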
static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
Value off_value, IntegerAttr axis) {
int64_t axis_val = axis.getInt();
Type element_ty = on_value.getType().cast<TensorType>().getElementType();
auto unranked_ty = UnrankedTensorType::get(element_ty);
if (axis_val < -1) return unranked_ty;
auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
if (!indices_ty) return unranked_ty;
auto shape = llvm::to_vector<2>(indices_ty.getShape());
if (axis_val == -1) axis_val = shape.size();
int64_t depth_val = ShapedType::kDynamicSize;
DenseIntElementsAttr depth_attr;
if (matchPattern(depth, m_Constant(&depth_attr)) &&
depth_attr.getNumElements() == 1)
depth_val = (*depth_attr.begin()).getSExtValue();
shape.insert(shape.begin() + axis_val, depth_val);
return RankedTensorType::get(shape, element_ty);
}
void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices,
Value depth, Value on_value, Value off_value,
IntegerAttr axis) {
build(builder, result,
InferOneHotOpType(indices, depth, on_value, off_value, axis), indices,
depth, on_value, off_value, axis);
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
LogicalResult PackOp::verify() {
PackOp op = *this;
// TODO(hinsu): Convert variadic length attributes to derived attributes.
Operation::operand_range values = op.values();
if (failed(VerifyTypesCompatibility(values,
/*mask_one_dim=*/false,
op.getOperation()))) {
return failure();
}
int64_t inputs_rank = -1;
for (Value value : values) {
if (auto ty = value.getType().dyn_cast<RankedTensorType>()) {
// Exit early as input types are verified to be compatible so all ranked
// tensors have the same rank.
inputs_rank = ty.getRank();
break;
}
}
if (inputs_rank == -1) return success();
  // The values can be packed along any of the dimensions between 0 and the
  // inputs' rank, inclusive. Negative axis values wrap around, so the valid
  // axis range is [-(R+1), R+1). For example, packing rank-2 inputs (R = 2)
  // accepts axis values in [-3, 3).
int64_t range_begin = -inputs_rank - 1; // Inclusive
int64_t range_end = inputs_rank + 1; // Exclusive
int64_t axis = op.axis();
if (axis < range_begin || axis >= range_end) {
return op.emitError() << "attribute 'axis' should be within range ["
<< range_begin << ", " << range_end
<< "); actual value: " << axis;
}
return success();
}
OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
// Fold pack operation if it computes the input tensor shape:
//
// %shape = tf.Shape(%arg) // [? x ...]
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value
// %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
//
  // Where `...` are some statically known dimensions. In this case %pack can be
  // replaced with %shape. This is a common pattern in models with a dynamic
  // batch size.
// Pack operation should pack at least two values.
if (values().size() < 2) return {};
// Dimensions packed along axis = 0 (pack scalars into vector).
if (axis() != 0) return {};
// First packed value is defined by a strided slice operation.
auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
if (!slice_op) return {};
// Input to the slice op is defined by shape operation.
auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
if (!shape_op) return {};
  // Input tensor whose shape is reconstructed by the pack operation.
Value tensor = shape_op.input();
  // All masks are `0` except `shrink_axis_mask`, which is equal to `1` (slicing
  // a scalar value from the input vector).
if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 ||
slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 ||
slice_op.shrink_axis_mask() != 1)
return {};
// Returns a value if the `value` is defined by a ConstOp with a single
// integer element in it and has an expected rank.
auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
if (!const_op) return None;
auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
if (!value_attr || value_attr.getNumElements() != 1) return None;
auto value_ty = value_attr.getType();
if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
auto splat = value_attr.getSplatValue<IntegerAttr>();
return splat.getValue().getSExtValue();
};
// All other packed values are scalar constants.
SmallVector<int64_t, 4> packed_dims;
packed_dims.reserve(values().size() - 1);
for (Value operand : llvm::drop_begin(values(), 1)) {
if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
packed_dims.push_back(*dim);
} else {
return {};
}
}
  // Slice exactly the first shape dimension:
  //   begin = [0], end = [1], strides = [1]
auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
*begin != 0 || *end != 1 || *strides != 1)
return {};
// First tensor dimension is dynamic.
auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
!arg_ty.isDynamicDim(0))
return {};
// Argument tensor rank is equal to the number of packed dimensions.
if (arg_ty.getRank() != values().size()) return {};
// All other dimensions are statically known and equal to packed dims.
auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
return {};
// Replace %pack with %shape.
return slice_op.input();
}
// Convert Pack to Reshape when there is only one operand to be packed.
// For example,
//
// %0 = tf.Pack(%input) {axis = 0} // %input : tensor<2x3xf32>
//
// can be canonicalized to
//
// %shape = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>}
// %0 = tf.Reshape(%input, %shape)
struct ConvertPackToReshape : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PackOp pack_op,
PatternRewriter &rewriter) const override {
// Check if there is only one operand to be packed.
if (pack_op.N() != 1) {
return failure();
}
// Check if input and output are static.
auto input_ty = pack_op.getOperand(0).getType().cast<ShapedType>();
auto output_ty = pack_op.output().getType().cast<ShapedType>();
if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) {
return failure();
}
// Create constant shape for reshape.
auto type =
RankedTensorType::get(output_ty.getRank(), rewriter.getIntegerType(64));
auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape());
auto shape = rewriter.create<ConstOp>(pack_op.getLoc(), shape_attr);
// TODO(b/173622615): Remove after fixed.
ReplaceTfOpWithNewOp<ReshapeOp>(rewriter, pack_op, output_ty,
pack_op.getOperand(0), shape);
return success();
}
};
void PackOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConvertPackToReshape>(context);
}
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//
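// Illustrative sketch of the paddings shuffle performed below (abstract
// permutation, not tied to a specific data format): with
// permutation = {2, 0, 1} and paddings rows [p0, p1, p2], each row i is
// scattered to position permutation[i], producing [p1, p2, p0]; a new i32
// paddings constant is created and the result type is shuffled accordingly.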
LogicalResult PadOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
// Paddings must be defined by a constant operation.
auto paddings_op = dyn_cast_or_null<TF::ConstOp>(paddings().getDefiningOp());
if (!paddings_op) return failure();
auto paddings_value = paddings_op.value().dyn_cast<DenseElementsAttr>();
if (!paddings_value ||
paddings_value.getNumElements() != permutation.size() * 2)
return failure();
SmallVector<int32_t, 8> shuffled_paddings(paddings_value.getNumElements());
for (auto index_pair : llvm::enumerate(paddings_value.getValues<APInt>())) {
size_t outer_idx = index_pair.index() / 2;
size_t inner_idx = index_pair.index() % 2;
shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] =
index_pair.value().getSExtValue();
}
  // Add a constant operation with the new paddings.
OpBuilder builder(getOperation());
auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(),
builder.getIntegerType(32));
auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings);
auto shuffled_paddings_op = builder.create<TF::ConstOp>(getLoc(), values);
// Use new paddings.
setOperand(1, shuffled_paddings_op);
// Change the result type.
getResult().setType(ShuffleRankedTensorType(getResult().getType(),
ReversePermutation(permutation)));
return success();
}
//===----------------------------------------------------------------------===//
// ParseExampleV2Op
//===----------------------------------------------------------------------===//
LogicalResult ParseExampleV2Op::verify() {
ParseExampleV2Op op = *this;
// NOTE(mrry): This validates properties of an op that would previously be
// validated by the TensorFlow OpDef type checker. In addition to these
// checks, the shape inference function for ParseExampleV2 validates the
// consistency of the argument and result types.
// Validate dense variadic input and output lengths.
// NOTE(mrry): The Tdense attr is derived from dense_defaults, so we
// do not need to validate dense_defaults.
auto dense_types_count =
std::distance(op.Tdense().begin(), op.Tdense().end());
auto dense_values_count =
std::distance(op.dense_values().begin(), op.dense_values().end());
if (dense_values_count != dense_types_count) {
return op.emitError() << "output 'dense_values' should have same length "
<< "as attribute 'Tdense'";
}
// Validate sparse variadic output lengths.
// NOTE(mrry): The sparse_types attr is derived from sparse_values, so we
// do not need to validate sparse_values.
auto sparse_types_count =
std::distance(op.sparse_types().begin(), op.sparse_types().end());
if (op.num_sparse() != sparse_types_count) {
return op.emitError() << "attribute 'num_sparse' should be the same as "
<< "the length of attribute 'sparse_types'";
}
if (op.sparse_indices().size() != sparse_types_count) {
return op.emitError() << "output 'sparse_indices' should have same length "
<< "as attribute 'sparse_types'";
}
if (op.sparse_shapes().size() != sparse_types_count) {
return op.emitError() << "output 'sparse_shapes' should have same length "
<< "as attribute 'sparse_types'";
}
// Validate ragged variadic output lengths.
auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(),
op.ragged_value_types().end());
auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(),
op.ragged_split_types().end());
if (ragged_value_types_count != ragged_split_types_count) {
return op.emitError() << "attribute 'ragged_value_types' should have same "
<< "length as attribute 'ragged_split_types'";
}
return success();
}
//===----------------------------------------------------------------------===//
// PartitionedCallOp
//===----------------------------------------------------------------------===//
template <class OpClass>
static LogicalResult VerifyPartitionedCall(OpClass op) {
auto module = op->template getParentOfType<ModuleOp>();
SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();
auto function =
dyn_cast_or_null<func::FuncOp>(SymbolTable::lookupSymbolIn(module, func));
if (!function) {
return op.emitError("'f' attribute refers to an undefined function: ")
<< func;
}
FunctionType function_ty = function.getFunctionType();
int func_arg_count = function_ty.getNumInputs();
int arg_count = op.args().size();
if (arg_count != func_arg_count) {
return op.emitError() << "argument count mismatch: 'args' has " << arg_count
<< " arguments, but '" << func << "' expects "
<< func_arg_count;
}
return success();
}
LogicalResult PartitionedCallOp::verify() {
return VerifyPartitionedCall(*this);
}
LogicalResult StatefulPartitionedCallOp::verify() {
return VerifyPartitionedCall(*this);
}
LogicalResult TPUPartitionedCallOp::verify() {
return VerifyPartitionedCall(*this);
}
//===----------------------------------------------------------------------===//
// PowOp
//===----------------------------------------------------------------------===//
OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
auto constant_y = operands[1].dyn_cast_or_null<DenseFPElementsAttr>();
if (constant_y && constant_y.isSplat()) {
APFloat y_value = constant_y.getSplatValue<APFloat>();
auto output_type = getType().cast<ShapedType>();
if (y_value.isZero() && output_type.hasStaticShape()) {
return DenseElementsAttr::get(
output_type,
FloatAttr::get(output_type.getElementType(), /*value=*/1.0));
}
if (y_value.isExactlyValue(1.0)) {
return x();
}
}
return {};
}
//===----------------------------------------------------------------------===//
// QuantizeAndDequantizeV2Op
//===----------------------------------------------------------------------===//
void QuantizeAndDequantizeV2Op::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<QuantizeAndDequantizeV2ToQuantizeAndDequantizeV4>(context);
}
//===----------------------------------------------------------------------===//
// QrOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// * Input type, if ranked, must have at least 2 dimensions and at most
// INT32_MAX dimensions.
//
LogicalResult QrOp::verify() {
QrOp op = *this;
auto ttype = op.input().getType().cast<TensorType>();
if (!ttype.hasRank()) return success();
if (!HasRankAtLeast(op.input(), 2))
return op.emitOpError(
"requires ranked input tensor to be of rank 2 or more");
if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
return op.emitOpError(
"requires ranked input tensor to be of rank INT32_MAX or less");
return success();
}
//===----------------------------------------------------------------------===//
// ReadVariableOp
//===----------------------------------------------------------------------===//
void ReadVariableOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ReadVariableOfCast>(context);
}
//===----------------------------------------------------------------------===//
// RandomUniformOp
//===----------------------------------------------------------------------===//
LogicalResult RandomUniformOp::verify() {
RandomUniformOp op = *this;
if (!IsOfRankOrUnranked(op.shape(), 1))
return op.emitOpError("shape must be 1D tensor");
return success();
}
//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//
namespace {
// Compute the length of a range (1-D) tensor given `start`, `limit`, `delta`.
// Template parameter `FloatOrInt` must be a standard C integer or
// floating-point type.
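// Worked example (illustrative values only): for integers,
// GetLengthOfRange(0, 10, 3) = (|10| + |3| - 1) / |3| = 4, i.e. the range
// {0, 3, 6, 9}; for floats, GetLengthOfRange(0.0f, 1.0f, 0.3f) =
// ceil(|1.0 / 0.3|) = 4, i.e. {0.0, 0.3, 0.6, 0.9}.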
template <typename FloatOrInt>
int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
// Refer to the implementation in
// tensorflow/lite/kernels/range.cc.
FloatOrInt diff = limit - start;
if (std::is_integral<FloatOrInt>::value) {
return ((std::abs(diff) + std::abs(delta) - 1) / std::abs(delta));
}
return std::ceil(std::abs(diff / delta));
}
// Builds a constant range tensor of `result_elem_type` elements.
// Template parameter `FloatOrIntAttr` must be mlir::IntegerAttr or
// mlir::FloatAttr.
template <typename FloatOrIntAttr>
DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
                                        FloatOrIntAttr start_attr,
                                        FloatOrIntAttr delta_attr) {
  using ValueType = typename FloatOrIntAttr::ValueType;  // APInt or APFloat
ValueType start = start_attr.getValue();
ValueType delta = delta_attr.getValue();
SmallVector<ValueType, 16> new_values;
new_values.reserve(num_elements);
ValueType new_value = start;
for (int i = 0; i < num_elements; ++i) {
new_values.push_back(new_value);
new_value = new_value + delta;
}
// Result is always a 1-D tensor.
auto new_result_type =
RankedTensorType::get({num_elements}, result_elem_type);
return DenseElementsAttr::get(new_result_type, new_values);
}
} // namespace
void RangeOp::build(OpBuilder &builder, OperationState &result, Value start,
Value limit, Value delta) {
assert(start.getType() == limit.getType());
assert(start.getType() == delta.getType());
DenseIntElementsAttr start_val;
DenseIntElementsAttr limit_val;
DenseIntElementsAttr delta_val;
if (matchPattern(start, m_Constant(&start_val)) &&
matchPattern(limit, m_Constant(&limit_val)) &&
matchPattern(delta, m_Constant(&delta_val))) {
auto size = llvm::APIntOps::RoundingSDiv(
*limit_val.begin() - *start_val.begin(), *delta_val.begin(),
llvm::APInt::Rounding::DOWN);
return RangeOp::build(
builder, result,
RankedTensorType::get(
size.getSExtValue(),
start.getType().cast<TensorType>().getElementType()),
start, limit, delta);
}
return RangeOp::build(
builder, result,
RankedTensorType::get(
{-1}, start.getType().cast<TensorType>().getElementType()),
start, limit, delta);
}
OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 3);
auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
if (!(start_tensor && limit_tensor && delta_tensor)) return nullptr;
// Operands should all be scalars
assert(start_tensor.getType().getRank() == 0 &&
limit_tensor.getType().getRank() == 0 &&
delta_tensor.getType().getRank() == 0);
Type elem_type = getType().cast<ShapedType>().getElementType();
if (elem_type.isSignlessInteger() || elem_type.isUnsignedInteger()) {
auto start_attr = start_tensor.getValues<IntegerAttr>()[0];
auto limit_attr = limit_tensor.getValues<IntegerAttr>()[0];
auto delta_attr = delta_tensor.getValues<IntegerAttr>()[0];
int num_elements;
if (elem_type.isUnsignedInteger()) {
uint64_t start = start_attr.getUInt();
uint64_t limit = limit_attr.getUInt();
uint64_t delta = delta_attr.getUInt();
assert(start <= (uint64_t)INT_MAX);
assert(limit <= (uint64_t)INT_MAX);
assert(delta <= (uint64_t)INT_MAX);
num_elements =
GetLengthOfRange(static_cast<int>(start), static_cast<int>(limit),
static_cast<int>(delta));
} else {
num_elements = GetLengthOfRange(start_attr.getInt(), limit_attr.getInt(),
delta_attr.getInt());
}
return BuildConstRangeTensor(elem_type, num_elements, start_attr,
delta_attr);
} else if (elem_type.isa<FloatType>()) {
auto start_attr = start_tensor.getValues<FloatAttr>()[0];
auto limit_attr = limit_tensor.getValues<FloatAttr>()[0];
auto delta_attr = delta_tensor.getValues<FloatAttr>()[0];
const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
limit_attr.getValueAsDouble(),
delta_attr.getValueAsDouble());
return BuildConstRangeTensor(elem_type, num_elements, start_attr,
delta_attr);
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
void RankOp::build(OpBuilder &builder, OperationState &result, Value input) {
return RankOp::build(builder, result,
RankedTensorType::get({}, builder.getIntegerType(32)),
input);
}
// This will create a constant value for RankOp of a ranked tensor.
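// For example (illustrative): a tensor<2x?x4xf32> operand folds to
// dense<3> : tensor<i32>, since only the rank, not the full shape, needs to
// be static.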
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
auto type = input().getType();
auto ranked_type = type.dyn_cast<RankedTensorType>();
if (!ranked_type) return {};
// DenseIntElementsAttr::get requires the output type be ranked with static
// shape.
auto output_type = getType().dyn_cast<RankedTensorType>();
if (!output_type || !output_type.hasStaticShape()) return {};
int32_t rank = ranked_type.getRank();
return DenseIntElementsAttr::get(output_type, rank);
}
//===----------------------------------------------------------------------===//
// RealDivOp
//===----------------------------------------------------------------------===//
void RealDivOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
}
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
namespace {
using ReshapeErrorHandler =
llvm::function_ref<LogicalResult(const llvm::Twine &)>;
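// Computes the output type of a reshape. Illustrative sketch (example values
// only): reshaping a tensor<2x3x4xf32> with a constant shape of [6, -1]
// yields tensor<6x4xf32>, since the single unknown dimension is inferred as
// 24 / 6 = 4; with a non-constant shape operand of type tensor<2xi32>, only
// the output rank is known and the result is tensor<?x?xf32>.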
LogicalResult GetReshapeOutputType(Value tensor, Value shape,
ReshapeErrorHandler error_handler,
TensorType &output_ty) {
auto tensor_ty = tensor.getType().cast<TensorType>();
auto element_ty = tensor_ty.getElementType();
output_ty = UnrankedTensorType::get(element_ty);
auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
if (!shape_ty) return success();
if (shape_ty.getRank() != 1)
return error_handler(llvm::formatv(
"requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));
DenseIntElementsAttr shape_attr;
if (!matchPattern(shape, m_Constant(&shape_attr))) {
    // If only the shape of `shape` is known, return a ranked but dynamic
    // output shape.
if (shape_ty.hasStaticShape()) {
llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
ShapedType::kDynamicSize);
output_ty = RankedTensorType::get(dynamic_shape, element_ty);
}
return success();
}
// Detect if reshape output shape is folded.
bool shape_ty_zero_dim = false;
int unknown_index = -1;
  // The product of the constant shape argument, excluding the unknown dimension.
int64_t shape_ty_size = 1;
llvm::SmallVector<int64_t, 8> output_ty_shape;
output_ty_shape.reserve(shape_attr.getNumElements());
for (const auto &dim : llvm::enumerate(shape_attr.getValues<APInt>())) {
const int64_t size = dim.value().getSExtValue();
if (ShapedType::isDynamic(size)) {
if (unknown_index != -1)
return error_handler(llvm::formatv(
"requires 'shape' to have at most one dynamic dimension, but got "
"multiple dynamic dimensions at indices {0} and {1}",
unknown_index, dim.index()));
unknown_index = dim.index();
} else if (size == 0) {
shape_ty_zero_dim = true;
} else if (size > 0) {
shape_ty_size *= size;
} else {
return error_handler(
llvm::formatv("requires 'shape' to have dimensions greater than -1, "
"but got {0} at index {1}",
size, dim.index()));
}
output_ty_shape.push_back(size);
}
if (!tensor_ty.hasStaticShape()) {
output_ty = RankedTensorType::get(output_ty_shape, element_ty);
return success();
}
// Compute the value of the unknown dimension.
if (unknown_index != -1) {
// Compute number of elements in tensor shape.
int64_t tensor_ty_size = 1;
bool tensor_ty_zero_dim = false;
for (const auto &dim : tensor_ty.getShape()) {
if (dim > 0 || !shape_ty_zero_dim) {
tensor_ty_size *= dim;
} else {
tensor_ty_zero_dim = true;
}
}
const int64_t missing_dim = tensor_ty_size / shape_ty_size;
if (!tensor_ty_zero_dim && shape_ty_size * missing_dim != tensor_ty_size)
return error_handler(
llvm::formatv("requires 'tensor' number of elements be a multiple of "
"{0}, but got {1}",
shape_ty_size, tensor_ty_size));
    // Set the unknown dimension such that the total number of elements remains
    // constant.
output_ty_shape[unknown_index] = missing_dim;
}
output_ty = RankedTensorType::get(output_ty_shape, element_ty);
return success();
}
} // namespace
LogicalResult ReshapeOp::verify() {
ReshapeOp op = *this;
auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
return op.emitOpError() << message;
};
TensorType expected_ty;
if (failed(GetReshapeOutputType(op.tensor(), op.shape(), error_handler,
expected_ty)))
return failure();
auto output_ty = op.getType().dyn_cast<RankedTensorType>();
if (!output_ty) return success();
auto tensor_ty = op.tensor().getType().cast<TensorType>();
if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) {
const int64_t output_ty_size = output_ty.getNumElements();
const int64_t tensor_ty_size = tensor_ty.getNumElements();
if (tensor_ty_size != output_ty_size)
return op.emitOpError() << "requires 'output' number of elements to "
"match 'tensor' number of elements, but got "
<< output_ty_size << " and " << tensor_ty_size;
}
if (!AreCastCompatible({output_ty, expected_ty}))
return op.emitOpError()
<< "requires 'output' type " << output_ty
<< " to be cast compatible with expected type " << expected_ty;
return success();
}
// Currently there are use cases that rely on partial evaluation of the `shape`
// operand, so InferTypeOpInterface is not used (along with the generated
// builder of the same signature).
void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
Value shape) {
auto error_handler = [&result](const llvm::Twine &message) {
return mlir::emitError(result.location) << message;
};
TensorType output_ty;
if (failed(GetReshapeOutputType(tensor, shape, error_handler, output_ty)))
return;
return ReshapeOp::build(builder, result, output_ty, tensor, shape);
}
void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RedundantReshape, ReshapeToSelfShape>(context);
}
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
Value tensor = this->tensor();
// Fold reshape if operand and result types are the same and all dimensions
// are statically known (no-op reshape).
auto result_ty = getType().dyn_cast<ShapedType>();
if (result_ty && result_ty.hasStaticShape() &&
result_ty == tensor.getType()) {
return tensor;
}
return {};
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
// Verifies a few extra requirements on SelectOp:
// (1) `then` and `else` must have same shape
// (2) At least one of the following must be true:
// (a) `cond` has the same rank as `then` and `else`
// (b) `cond` is a scalar
// (c) `cond` is a vector AND `then` and `else` are non-scalar with their
//     first dimension equal to the number of elements in `cond`.
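// For illustration (hypothetical shapes): a `cond` of type tensor<2xi1> with
// `then`/`else` of type tensor<2x3xf32> satisfies case (2c); a `cond` of type
// tensor<3xi1> with the same `then`/`else` is rejected because the vector
// length does not match their first dimension.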
LogicalResult SelectOp::verify() {
SelectOp op = *this;
auto then_tensor = op.t().getType().cast<TensorType>();
auto else_tensor = op.e().getType().cast<TensorType>();
// Check (1).
if (!AreCastCompatible({then_tensor, else_tensor}))
return op.emitOpError() << "requires t and e have compatible shapes";
// Get data rank (if exists).
int data_rank;
  // If data is unranked or data_rank is 0, this will remain -2. Otherwise it
  // refers to the first dimension of `then` and/or `else`.
int data_first_dim = -2;
bool then_has_rank = then_tensor.hasRank();
bool else_has_rank = else_tensor.hasRank();
if (then_has_rank && else_has_rank) {
data_rank = then_tensor.getRank();
if (then_tensor.getRank() > 0)
data_first_dim = then_tensor.getShape().front();
if (else_tensor.getRank() > 0)
data_first_dim = std::max(
static_cast<int>(else_tensor.getShape().front()), data_first_dim);
} else if (then_has_rank) {
data_rank = then_tensor.getRank();
if (then_tensor.getRank() > 0)
data_first_dim = then_tensor.getShape().front();
} else if (else_has_rank) {
data_rank = else_tensor.getRank();
if (else_tensor.getRank() > 0)
data_first_dim = else_tensor.getShape().front();
} else {
// Neither has a rank.
return success();
}
auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
if (!cond_tensor) return success();
auto cond_rank = cond_tensor.getRank();
// Check (2a) and (2b).
if (cond_rank == 0 || cond_rank == data_rank) return success();
// Check (2c).
if (cond_rank == 1) {
auto cond_shape = cond_tensor.getShape().front();
if (data_rank == 0) {
return op.emitOpError()
<< "requires that t and e are nonscalar when pred is a vector";
}
// We know `data` tensor has a rank of at least 1.
if (data_first_dim != -1 && cond_shape != -1 &&
data_first_dim != cond_shape) {
return op.emitOpError() << "requires that, when pred is a vector, the "
"shape matches the first dimension of t and e";
}
return success();
}
// None of (2a,b,c) were true; fail.
return op.emitOpError() << "requires that pred is a scalar OR has the same "
"rank as t and e OR is a vector";
}
//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//
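// Illustrative sketch of the type inference performed by InferSelectV2OpType
// below (example shapes only): with %e : tensor<2x1xf32>,
// %t : tensor<1x3xf32> and %cond : tensor<4x1x1xi1>, the e/t broadcast is
// tensor<2x3xf32>, which broadcasts with the condition shape to give a result
// type of tensor<4x2x3xf32>.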
static Type InferSelectV2OpType(Value condition, Value e, Value t) {
Type element_ty = e.getType().cast<TensorType>().getElementType();
auto unranked_ty = UnrankedTensorType::get(element_ty);
Type broadcasted_ty =
OpTrait::util::getBroadcastedType(e.getType(), t.getType());
if (!broadcasted_ty) return unranked_ty;
auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;
  // Explicitly compute the broadcasted output shape, as the element type of
  // the condition may not be the same as the broadcasted type's element type.
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(),
broadcasted_ranked_ty.getShape(),
result_shape))
return unranked_ty;
return RankedTensorType::get(result_shape, element_ty);
}
void SelectV2Op::build(OpBuilder &builder, OperationState &result,
Value condition, Value e, Value t) {
build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t);
}
//===----------------------------------------------------------------------===//
// ShapeOp
//===----------------------------------------------------------------------===//
namespace {
// Validates Shape/ShapeN/VariableShape operand and associated result types.
LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
Type result_type,
int variadic_idx = -1) {
std::string variadic_idx_str =
variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str();
auto result_ranked_type = result_type.dyn_cast<RankedTensorType>();
if (!result_ranked_type) return success();
if (result_ranked_type.getShape().size() != 1)
return op->emitOpError("requires 1D type for result") << variadic_idx_str;
auto operand_ranked_type = operand_type.dyn_cast_or_null<RankedTensorType>();
if (operand_ranked_type) {
// The operand is a ranked tensor.
if (result_ranked_type.hasStaticShape() &&
!operand_ranked_type.getShape().empty() &&
result_ranked_type.getDimSize(0) !=
operand_ranked_type.getShape().size())
return op->emitOpError("requires dimension size of result")
<< variadic_idx_str << " to match rank of operand"
<< variadic_idx_str;
} else if (result_ranked_type.hasStaticShape()) {
    // The operand is an unranked tensor; print a warning if the result
    // is static.
    // Note: We do not treat this situation as an error because that would be
    // too restrictive given the incompleteness of shape inference at this
    // point.
mlir::InFlightDiagnostic diag =
mlir::emitWarning(op->getLoc(), "has static shape result");
if (op->getContext()->shouldPrintOpOnDiagnostic()) {
diag.attachNote(op->getLoc())
.append("see current operation: ")
.appendOp(*op, OpPrintingFlags().printGenericOpForm());
}
diag << variadic_idx_str << " for unranked operand" << variadic_idx_str;
}
Type element_type = result_ranked_type.getElementType();
if (!element_type.isSignlessInteger(32) &&
!element_type.isSignlessInteger(64))
return op->emitOpError("requires int32 or int64 return type for result")
<< variadic_idx_str;
return success();
}
} // anonymous namespace
LogicalResult ShapeOp::verify() {
ShapeOp op = *this;
return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
}
// Converts the shape of the given type to an attribute if it is of ranked
// tensor type. The returned attribute has integer elements of the given width.
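// For example (illustrative only): ConvertShapeToAttr(tensor<2x3x5xf32>, 32)
// returns dense<[2, 3, 5]> : tensor<3xi32>, and a null attribute is returned
// for unranked or non-statically-shaped input types.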
static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
if (!ranked_ty || !ranked_ty.hasStaticShape()) return {};
auto shape = ranked_ty.getShape();
int rank = shape.size();
SmallVector<APInt, 4> dimensions;
dimensions.reserve(rank);
for (int i = 0; i < rank; ++i)
dimensions.push_back(APInt(out_width, shape[i]));
auto result_type = RankedTensorType::get(
{rank}, IntegerType::get(input_ty.getContext(), out_width));
return DenseElementsAttr::get(result_type, dimensions);
}
OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
int width =
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
return ConvertShapeToAttr(getOperand().getType(), width);
}
void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input,
BoolAttr use32Bit) {
auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
auto out_type = use32Bit.getValue() ? builder.getIntegerType(32)
: builder.getIntegerType(64);
return ShapeOp::build(builder, result,
RankedTensorType::get({rank}, out_type), input);
}
//===----------------------------------------------------------------------===//
// ShapeNOp
//===----------------------------------------------------------------------===//
LogicalResult ShapeNOp::verify() {
ShapeNOp op = *this;
const size_t num_tensors = op.N();
if (op.getNumOperands() != num_tensors)
return op.emitOpError() << "requires " << num_tensors << " operand(s), got "
<< op.getNumOperands() << " operand(s)";
if (op.getNumResults() != num_tensors)
return op.emitOpError() << "requires " << num_tensors << " result(s), got "
<< op.getNumResults() << " result(s)";
for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
auto verification = VerifyShapeOperandAndResult(
op, op.getOperand(i).getType(), op.getResult(i).getType(), i);
if (failed(verification)) return verification;
}
return success();
}
namespace {
// Canonicalization pattern for ShapeNOp operations that don't have all static
// input shapes. Replacing output values corresponding to static input types
// may enable optimizations in users of the values.
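// For illustration (hypothetical operands, assuming i32 results): given
//
//   %0:2 = "tf.ShapeN"(%arg0, %arg1)
//       : (tensor<2x3xf32>, tensor<*xf32>) -> (tensor<2xi32>, tensor<?xi32>)
//
// the pattern below replaces the first result with a "tf.Const" holding
// dense<[2, 3]> and produces the second result from a new single-operand
// "tf.ShapeN" on %arg1.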
class ShapeNPartialStaticInputShape : public OpRewritePattern<ShapeNOp> {
using OpRewritePattern<ShapeNOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ShapeNOp op,
PatternRewriter &rewriter) const override {
if (op.getNumOperands() == 0) {
rewriter.eraseOp(op);
return success();
}
int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth();
SmallVector<Value, 4> results(op.getNumOperands());
SmallVector<int64_t, 4> dynamic_indices;
SmallVector<Value, 4> dynamic_inputs;
SmallVector<Type, 4> result_types;
for (auto e : llvm::enumerate(op.getOperands())) {
if (Attribute result = ConvertShapeToAttr(e.value().getType(), width)) {
results[e.index()] = rewriter.create<TF::ConstOp>(op.getLoc(), result);
} else {
dynamic_indices.push_back(e.index());
dynamic_inputs.push_back(e.value());
result_types.push_back(op.getType(e.index()));
}
}
if (dynamic_inputs.size() == op.getNumOperands()) {
// Cannot canonicalize ShapeN if all inputs are dynamic.
return failure();
}
// Create a ShapeNOp for all dynamic inputs.
if (!dynamic_inputs.empty()) {
auto dynamic_shape_n = rewriter.create<TF::ShapeNOp>(
op.getLoc(), result_types, dynamic_inputs);
for (auto index_result :
llvm::zip(dynamic_indices, dynamic_shape_n.getResults())) {
results[std::get<0>(index_result)] = std::get<1>(index_result);
}
}
rewriter.replaceOp(op, results);
return success();
}
};
// Canonicalize ShapeNOp to ShapeOp if there is only one operand.
class ShapeNToShape : public OpRewritePattern<ShapeNOp> {
using OpRewritePattern<ShapeNOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ShapeNOp op,
PatternRewriter &rewriter) const override {
if (op.getNumOperands() != 1) {
return failure();
}
auto shape = rewriter.create<TF::ShapeOp>(op.getLoc(), op.getType(0),
op.getOperand(0));
rewriter.replaceOp(op, {shape});
return success();
}
};
} // namespace
void ShapeNOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
}
//===----------------------------------------------------------------------===//
// SizeOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// * Input type, if it is a ranked tensor, has at most INT32_MAX dimensions.
//
LogicalResult SizeOp::verify() {
SizeOp op = *this;
if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
return op.emitOpError(
"requires ranked input tensor to be of rank INT32_MAX or less");
// Output type needs to be scalar.
if (!IsOfRankOrUnranked(op.output(), /*rank=*/0))
return op.emitOpError("requires scalar output");
return success();
}
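// Folds tf.Size to a constant when the operand shape is fully static. For
// example (illustrative): a tensor<3x5x7xf32> operand folds to a scalar
// constant holding 3 * 5 * 7 = 105.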
OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
ShapedType output_type = getType().cast<ShapedType>();
if (!output_type.hasRank()) return {};
ShapedType input_type = getOperand().getType().cast<ShapedType>();
if (!input_type.hasStaticShape()) return {};
int size = input_type.getNumElements();
return DenseElementsAttr::get(
output_type,
IntegerAttr::get(output_type.getElementType(), /*value=*/size));
}
//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
// Verifies that:
//
// - operands begin and size are 1D with the same number of elements.
// - if the input is a ranked tensor, the rank of the input equals the number
// of elements in operands begin and size.
// - if begin is constant, that
// 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i]
// and
// size[i] == output_ty.getShape()[i]
// - if begin isn't constant but the input is a ranked tensor, that
// size[i] <= input_ty.getShape()[i]
// - output rank is the same as input rank
//
LogicalResult SliceOp::verify() {
SliceOp op = *this;
RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
if (begin_ty && begin_ty.getRank() != 1) {
return op.emitOpError() << "requires begin operand to be 1D tensor";
}
RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size());
if (size_ty && size_ty.getRank() != 1) {
return op.emitOpError() << "requires size operand to be 1D tensor";
}
if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() ||
!size_ty.hasStaticShape())
return success();
if (begin_ty.getNumElements() != size_ty.getNumElements()) {
return op.emitOpError() << "requires begin and size operands to have the"
" same number of elements";
}
auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
return op.emitOpError() << "requires number of elements in begin and size "
"are equal to input rank";
}
auto output_ty = op.output().getType().dyn_cast<RankedTensorType>();
if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) {
return op.emitOpError()
<< "requires output to have the same rank as input, but got input "
"rank "
<< input_ty.getRank() << " and output rank " << output_ty.getRank();
}
DenseIntElementsAttr begin_indices;
if (matchPattern(op.begin(), m_Constant(&begin_indices))) {
DenseIntElementsAttr slice_sizes;
bool constant_slice_sizes =
matchPattern(op.size(), m_Constant(&slice_sizes));
int dim = 0;
// TODO(jpienaar): Reformulate the shape verification below to not use magic
// constants.
for (const APInt &raw_begin_index : begin_indices.getValues<APInt>()) {
int64_t begin_index = raw_begin_index.getSExtValue();
int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
int64_t slice_size =
constant_slice_sizes
? slice_sizes.getValues<APInt>()[dim].getSExtValue()
: 0;
int64_t output_size = output_ty ? output_ty.getShape()[dim] : -1;
if (slice_size == -1 && input_size != -1) {
slice_size = input_size - begin_index;
}
if (output_size != -1 && constant_slice_sizes &&
output_size != slice_size) {
return op.emitOpError()
<< "requires output size to have the same size of slice, got "
"slice size "
<< slice_size << " and output size " << output_size;
}
if (begin_index < 0 ||
(input_size != -1 && begin_index + slice_size > input_size)) {
return op.emitOpError()
<< "requires 0 <= begin[i] <= begin[i] + size[i] <= Di";
}
++dim;
}
} else if (input_ty) {
// If the inputs are ranked, we can do a few more sanity checks.
DenseIntElementsAttr slice_sizes;
if (matchPattern(op.size(), m_Constant(&slice_sizes))) {
auto input_shape = input_ty.getShape();
for (int64_t i = 0; i < input_ty.getRank(); ++i) {
int64_t slice_size = slice_sizes.getValues<APInt>()[i].getSExtValue();
int64_t input_size = input_shape[i];
if (slice_size != -1 && input_size != -1 && slice_size > input_size) {
return op.emitOpError() << "requires size[i] <= Di, even if begin[i] "
"is unknown at compile time";
}
}
}
}
return success();
}
//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//
LogicalResult SoftmaxOp::verify() {
SoftmaxOp op = *this;
if (!HasRankAtLeast(op.logits(), 1)) {
return op.emitOpError("requires operand to have rank at least 1");
}
return success();
}
//===----------------------------------------------------------------------===//
// SoftmaxCrossEntropyWithLogitsOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// * Input types are broadcast compatible and the broadcasted type has rank two.
//
LogicalResult SoftmaxCrossEntropyWithLogitsOp::verify() {
SoftmaxCrossEntropyWithLogitsOp op = *this;
auto broadcasted_ty = OpTrait::util::getBroadcastedType(
op.features().getType(), op.labels().getType())
.dyn_cast_or_null<ShapedType>();
if (!broadcasted_ty ||
(broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
return op.emitOpError(
"requires features and labels to be broadcast compatible to rank two");
return success();
}
//===----------------------------------------------------------------------===//
// SpaceToBatchNDOp
//===----------------------------------------------------------------------===//
int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type,
const TensorType paddings_type) {
if (block_shape_type.hasStaticShape()) {
return block_shape_type.getShape()[0];
} else if (paddings_type.hasStaticShape()) {
return paddings_type.getShape()[0];
} else {
return -1;
}
}
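// Worked example for the checks in SpaceToBatchNDOp::verify below
// (illustrative shapes): with an input of type tensor<1x5x5x3xf32>,
// block_shape = dense<[2, 2]> and paddings = dense<[[1, 0], [0, 1]]>, the
// padded spatial dimensions are 5 + 1 + 0 = 6 and 5 + 0 + 1 = 6, both
// divisible by the block length 2, so verification succeeds; with zero
// paddings it would fail because 5 % 2 != 0.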
LogicalResult SpaceToBatchNDOp::verify() {
SpaceToBatchNDOp op = *this;
const auto input_type = op.input().getType().cast<TensorType>();
const auto block_shape_type = op.block_shape().getType().cast<TensorType>();
const auto paddings_type = op.paddings().getType().cast<TensorType>();
// Check that block_shape has rank 1.
if (!IsOfRankOrUnranked(op.block_shape(), 1)) {
return op.emitOpError() << "requires rank of block_shape = 1; got "
<< block_shape_type.getRank();
}
// Check that paddings has rank 2.
if (!IsOfRankOrUnranked(op.paddings(), 2)) {
return op.emitOpError()
<< "requires rank of paddings = 2; got " << paddings_type.getRank();
}
  // Check that paddings.shape[1] == 2.
if (paddings_type.hasStaticShape() && paddings_type.getShape()[1] != 2) {
return op.emitOpError() << "requires paddings.shape[1] to be 2; got "
<< paddings_type.getShape()[1];
}
  // Check that block_shape and paddings agree on the number of block dimensions.
if (block_shape_type.hasStaticShape() && paddings_type.hasStaticShape() &&
block_shape_type.getShape()[0] != paddings_type.getShape()[0]) {
return op.emitOpError()
<< "requires block_shape.shape[0] must equal paddings.shape[0]";
}
const int64_t block_rank =
SpaceToBatchNDBlockRank(block_shape_type, paddings_type);
// Further checks require block_rank to be known.
if (block_rank == -1) {
return success();
}
  // Check that the rank of input_type >= block_rank + 1.
if (input_type.hasRank() && input_type.getRank() < 1 + block_rank) {
return op.emitOpError() << "requires rank of input >= 1 + rank of block";
}
ElementsAttr block_shape_attr = nullptr;
ElementsAttr paddings_attr = nullptr;
// Check that block_shape[*] >= 1.
if (matchPattern(op.block_shape(), m_Constant(&block_shape_attr))) {
uint64_t i = 0;
for (auto block_len : block_shape_attr.getValues<APInt>()) {
if (block_len.getSExtValue() < 1) {
return op.emitOpError()
<< "requires all values of block_shape to be >= 1; "
"failed for dimension "
<< i;
}
++i;
}
}
// Check that paddings[*] >= 0.
if (matchPattern(op.paddings(), m_Constant(&paddings_attr))) {
for (uint64_t i = 0; i < block_rank; ++i) {
const int64_t pad_start =
paddings_attr.getValues<APInt>()[{i, 0}].getSExtValue();
const int64_t pad_end =
paddings_attr.getValues<APInt>()[{i, 1}].getSExtValue();
if (pad_start < 0 || pad_end < 0) {
return op.emitOpError()
<< "requires all values of paddings to be >= 0; "
"failed for dimension "
<< i;
}
}
}
// Check that block_shape divides the padded input.
if (input_type.hasStaticShape() && block_shape_attr && paddings_attr) {
for (uint64_t i = 0; i < block_rank; ++i) {
const int64_t input_len = input_type.getShape()[1 + i];
const int64_t pad_start =
paddings_attr.getValues<APInt>()[{i, 0}].getSExtValue();
const int64_t pad_end =
paddings_attr.getValues<APInt>()[{i, 1}].getSExtValue();
const int64_t block_len =
block_shape_attr.getValues<APInt>()[i].getSExtValue();
if ((input_len + pad_start + pad_end) % block_len != 0) {
return op.emitOpError()
<< "requires block_shape[i] divides "
"input_shape[i + 1] + paddings[i, 0] + paddings[i, 1]; "
"failed for i="
<< i;
}
}
}
return success();
}
//===----------------------------------------------------------------------===//
// SparseSoftmaxCrossEntropyWithLogitsOp
//===----------------------------------------------------------------------===//
LogicalResult SparseSoftmaxCrossEntropyWithLogitsOp::verify() {
SparseSoftmaxCrossEntropyWithLogitsOp op = *this;
if (!IsOfRankOrUnranked(op.features(), 2)) {
return op.emitOpError("requires features operand of rank two");
}
if (!IsOfRankOrUnranked(op.labels(), 1)) {
return op.emitOpError("requires labels operand of rank one");
}
auto features_ty = op.features().getType().dyn_cast<RankedTensorType>();
auto labels_ty = op.labels().getType().dyn_cast<RankedTensorType>();
if (features_ty && labels_ty) {
int64_t features_batches = features_ty.getDimSize(0);
int64_t labels_batches = labels_ty.getDimSize(0);
if (!ShapedType::isDynamic(features_batches) &&
!ShapedType::isDynamic(labels_batches) &&
features_batches != labels_batches)
return op.emitOpError(
"requires features and labels with matching first dimension");
}
return success();
}
//===----------------------------------------------------------------------===//
// SplitOp
//===----------------------------------------------------------------------===//
// Verifies the input and split dimension operands for tf.Split/tf.SplitV.
// Writes the split dimension's index (adjusted with input rank) via `dim_index`
// if it's a constant.
template <class Op>
LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
*dim_index = llvm::None;
Value split_dim = op.split_dim();
if (auto split_dim_type = split_dim.getType().dyn_cast<RankedTensorType>())
if (split_dim_type.getRank() != 0)
return op.emitOpError(
"split dimension should be an integer scalar tensor");
// We can perform further verification if the input tensor to be split has
// known rank and the split dimension tensor is a constant.
auto input_type = op.value().getType().template dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t input_rank = input_type.getRank();
if (input_rank == 0)
return op.emitOpError("cannot split scalar input tensor");
DenseIntElementsAttr split_dim_attr;
if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();
int64_t index = (*split_dim_attr.begin()).getSExtValue();
if (index + input_rank < 0 || index >= input_rank) {
return op.emitOpError("split dimension must be in range [-")
<< input_rank << ", " << input_rank << ")";
}
if (index < 0) index += input_rank;
*dim_index = index;
return success();
}
LogicalResult SplitOp::verify() {
SplitOp op = *this;
Optional<int64_t> dim_index;
if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
if (!dim_index) return success();
int64_t input_dim_size =
op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
if (ShapedType::isDynamic(input_dim_size)) return success();
if (op.getNumResults() == 0) return failure();
if (input_dim_size % op.getNumResults() != 0)
return op.emitOpError("dimension #")
<< *dim_index << " not divisible by the number of result tensors";
return success();
}
//===----------------------------------------------------------------------===//
// SplitVOp
//===----------------------------------------------------------------------===//
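// For illustration (hypothetical operands): splitting a tensor<10x4xf32>
// along dimension 0 into 3 results with size_splits = [2, -1, 3] verifies,
// since there is a single dynamic (-1) entry and 2 + 3 <= 10; with
// size_splits = [2, 3, 4] verification fails because the sizes sum to 9
// rather than 10.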
LogicalResult SplitVOp::verify() {
SplitVOp op = *this;
auto split_sizes_type =
op.size_splits().getType().dyn_cast<RankedTensorType>();
if (!split_sizes_type) return success();
if (split_sizes_type.getRank() != 1 ||
(!ShapedType::isDynamic(split_sizes_type.getDimSize(0)) &&
split_sizes_type.getDimSize(0) != op.getNumResults()))
return op.emitOpError("split sizes should be a 1D tensor of ")
<< op.getNumResults() << " elements";
Optional<int64_t> dim_index = 0;
if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
if (!dim_index) return success();
int64_t input_dim_size =
op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
if (ShapedType::isDynamic(input_dim_size)) return success();
// If split sizes come from a constant, they must sum to the dimension size
// along split_dim, and we can have no more than one dynamic dimension.
DenseIntElementsAttr split_sizes_attr;
if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
return success();
int64_t total_dim_size = 0; // Total dimension size assigned to splits
llvm::Optional<int> dynamic_dim_index;
SmallVector<int64_t, 4> split_sizes;
split_sizes.reserve(
split_sizes_attr.getType().cast<ShapedType>().getNumElements());
for (auto dim : llvm::enumerate(split_sizes_attr)) {
int64_t dim_val = dim.value().getSExtValue();
split_sizes.push_back(dim_val);
if (ShapedType::isDynamic(dim_val)) {
// We cannot have more than one dynamic dimension.
if (dynamic_dim_index)
return op.emitOpError(
"cannot have more than one dynamic dimension in split sizes");
dynamic_dim_index = dim.index();
} else {
total_dim_size += dim_val;
}
}
if (!dynamic_dim_index && total_dim_size != input_dim_size)
return op.emitOpError(
"split sizes must sum up to the dimension size along split "
"dimension, found ")
<< total_dim_size << " vs " << input_dim_size;
if (dynamic_dim_index && total_dim_size > input_dim_size)
return op.emitOpError(
"split sizes must sum up to be less than or equal to the "
"dimension size along split dimension, found ")
<< total_dim_size << " vs " << input_dim_size;
return success();
}
//===----------------------------------------------------------------------===//
// SquareOp
//===----------------------------------------------------------------------===//
void SquareOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SquareOfSub>(context);
}
//===----------------------------------------------------------------------===//
// SqueezeOp
//===----------------------------------------------------------------------===//
LogicalResult SqueezeOp::verify() {
SqueezeOp op = *this;
auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
if (!input_type) return success(); // Can't verify squeeze dims.
int64_t input_rank = input_type.getRank();
for (const auto &squeeze_dim_apint :
op.squeeze_dims().getAsValueRange<IntegerAttr>()) {
int64_t squeeze_dim = squeeze_dim_apint.getSExtValue();
if (squeeze_dim < -input_rank || squeeze_dim >= input_rank) {
return op.emitOpError()
<< "squeeze dimension " << squeeze_dim << " not in ["
<< -input_rank << ", " << input_rank << ")";
}
}
return success();
}
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SubOfNeg>(context);
}
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
return IdentityArithmeticOpFolder<SubOp>(*this, operands);
}
//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//
void SumOp::build(OpBuilder &builder, OperationState &result, Value input,
Value reduction_indices, BoolAttr keep_dims) {
Type out_ty = InferReductionOpType(input, reduction_indices, keep_dims);
build(builder, result, out_ty, input, reduction_indices, keep_dims);
}
// TODO: Templatize this fold for all reduction ops.
OpFoldResult SumOp::fold(ArrayRef<Attribute> operands) {
auto input_ty = input().getType().template dyn_cast<RankedTensorType>();
if (!input_ty) return {};
auto result_ty = getType().template dyn_cast<RankedTensorType>();
if (!result_ty) return {};
// Bypass this op if the result has the same shape and type. This can happen
// if the input tensor has size 0 or size 1.
if (!keep_dims() && input_ty == result_ty) {
return input();
}
return {};
}
//===----------------------------------------------------------------------===//
// StridedSliceOp
//===----------------------------------------------------------------------===//
// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to
// tf.SliceOp if both of the following are true:
// - All strides have a known value equal to 1
// - No masks are set (or masks can be applied by transforming the inputs to
// Slice)
// Verifies that,
//
// - begin, end and strides operands are 1D and they have the same number of
// elements. Here, the number of elements should be less than 32 to support
// 32-bit mask attributes.
// - None of the strides values are zero.
// - Ellipsis mask can have at most one bit set.
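// For illustration (hypothetical values): begin = [0, 1], end = [2, 3] and
// strides = [1, 1] give a consistent element count of 2, while an
// ellipsis_mask of 0b0110 (two bits set) would be rejected.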
template <class OpTy>
static LogicalResult VerifyStridedSliceBase(OpTy op) {
// Expected size for operands begin, end and strides vector operands.
int64_t expected_size = -1;
for (Value val : {op.begin(), op.end(), op.strides()}) {
auto operand_ty = val.getType().dyn_cast<ShapedType>();
if (!operand_ty || !operand_ty.hasStaticShape()) {
// TensorFlow constant ops may have non-static shape because the shape is
// not propagated during constant folding. If the defining op for this
// operand is a constant op, use the constant op's attribute to get the
// actual shape.
DenseIntElementsAttr attr;
if (!matchPattern(val, m_Constant(&attr))) continue;
operand_ty = attr.getType();
}
if (operand_ty.getRank() != 1)
return op.emitOpError()
<< "requires begin, end and strides to be 1D tensors";
int64_t length = operand_ty.getDimSize(0);
if (length == -1) continue;
if (expected_size == -1) {
// This op uses 32-bit masks.
if (length >= 32)
return op.emitOpError(
"requires begin, end and strides operands with less than 32 "
"elements");
expected_size = length;
} else if (length != expected_size) {
return op.emitOpError() << "requires begin, end and strides to have the "
"same number of elements";
}
}
// If strides are constants, verify that none of the element is zero.
DenseIntElementsAttr strides;
if (matchPattern(op.strides(), m_Constant(&strides))) {
if (llvm::is_contained(strides.getValues<APInt>(), 0))
return op.emitOpError("requires non-zero strides");
}
  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. at
  // most one ellipsis bit is set.
uint32_t ellipsis_mask = op.ellipsis_mask();
if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask))
return op.emitOpError("cannot have multiple ellipses");
return success();
}
LogicalResult StridedSliceOp::verify() { return VerifyStridedSliceBase(*this); }
// Clamps the given `val`: returns `low` if `val` is less than `low`; returns
// `high` if `high` is less than `val`; otherwise returns `val`.
template <class T>
constexpr const T &Clamp(const T &val, const T &low, const T &high) {
assert(!(high < low));
return (val < low) ? low : (high < val) ? high : val;
}
// Checks if the `index` bit of `val` is set.
template <class T>
constexpr bool IsSet(const T &val, unsigned index) {
return (val & (1 << index)) != 0;
}
// Sets the `index` bit of `val`.
template <class T>
constexpr void Set(T &val, unsigned index) {
val |= (1 << index);
}
// Unset the `index` bit of `val`.
template <class T>
constexpr void Unset(T &val, unsigned index) {
val &= ~(1 << index);
}
// Copy the `src_index` bit of `src` to `dst_index` bit of `dst`.
template <class T>
constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
unsigned dst_index) {
if (IsSet(src, src_index))
Set(dst, dst_index);
else
Unset(dst, dst_index);
}
// The sparse spec of strided slice does not correspond to the number of
// dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2,
// 4, 8) would have dims = 2.
struct SparseSliceSpec {
int64_t dims;
int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
const ArrayRef<int64_t> &begin;
const ArrayRef<int64_t> &end;
const ArrayRef<int64_t> &strides;
};
// The dense spec of strided slice is the canonicalized version of the sparse
// spec. The number of dimensions of the dense spec corresponds to the number
// of dimensions in the operand tensor.
struct DenseSliceSpec {
int64_t dims;
int32_t begin_mask, end_mask, shrink_axis_mask;
SmallVectorImpl<int64_t> &begin;
SmallVectorImpl<int64_t> &end;
SmallVectorImpl<int64_t> &strides;
};
// Makes a sparse spec into a dense index spec.
// The sparse spec does not correspond to the number of dimensions, so build a
// dense spec that does.
//
// For example, suppose foo[..., 3:, 2] on foo.shape=(2,2,3,4); then we need to
// produce the missing begin_mask and end_mask for the first two dimensions,
// i.e. foo[:, :, 3:, 2].
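// For illustration (informal sketch, assuming unit strides and that the
// omitted end of `3:` is encoded via end_mask): the sparse spec for the
// example above has dims = 3 with ellipsis_mask bit 0, end_mask bit 1 and
// shrink_axis_mask bit 2 set. The resulting dense spec has dims = 4,
// begin = [0, 0, 3, 2], strides = [1, 1, 1, 1], begin_mask = 0b0011,
// end_mask = 0b0111 and shrink_axis_mask = 0b1000 (end values of masked
// dimensions are ignored).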
static void BuildDenseSliceSpec(const SparseSliceSpec &sparse,
DenseSliceSpec *dense) {
// Build expanded dense begin, end, strides, begin_mask, end_mask, and
// shrink_axis_mask.
dense->begin.resize(dense->dims);
dense->end.resize(dense->dims);
dense->strides.resize(dense->dims);
dense->begin_mask = 0;
dense->end_mask = 0;
dense->shrink_axis_mask = 0;
// Count number of new_axis after ellipsis. This helps in calculating the
// number of dimensions ellipsis represents in the sparse spec.
bool ellipsis_seen = false;
int num_new_axis_after_ellipsis = 0;
for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index))
num_new_axis_after_ellipsis++;
if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true;
}
int dense_index = 0;
for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
if (IsSet(sparse.new_axis_mask, sparse_index)) continue;
if (IsSet(sparse.ellipsis_mask, sparse_index)) {
auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) +
1 + num_new_axis_after_ellipsis,
dense->dims);
// Expand ellipsis into the appropriate dense indices. From current index
// until next_index, all dimensions would have begin and end masks set and
// stride 1, i.e., get all elements in those dimensions.
for (; dense_index < next_index; ++dense_index) {
dense->begin[dense_index] = dense->end[dense_index] = 0;
dense->strides[dense_index] = 1;
Set(dense->begin_mask, dense_index);
Set(dense->end_mask, dense_index);
}
continue;
}
assert(dense_index < dense->dims);
// Copy over the sparse indices to dense indices if ellipsis_mask and
// new_axis_mask are not set.
dense->begin[dense_index] = sparse.begin[sparse_index];
dense->end[dense_index] = sparse.end[sparse_index];
dense->strides[dense_index] = sparse.strides[sparse_index];
CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index);
CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index);
CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask,
dense_index);
dense_index++;
}
}
// For the given `input_shape`, calculates the sliced shape using the given
// `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and
// `shrink_axis_mask` masks. Updates the result back to `input_shape`. If
// `shrink_axis_mask` is not zero, this function will not drop the corresponding
// dimensions in `input_shape`; it will turn them into 1s. At the same time,
// canonicalizes `begin`, `end`, and `strides`. The calculation follows
// tf.StridedSlice op semantics.
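// For illustration (hypothetical values): with input_shape = [10],
// begin = [-3], end = [10], stride = [1] and no masks set, begin is
// canonicalized to 7 and the sliced size is 3, so input_shape becomes [3].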
static void CalculateSlicedShapeFromDenseIndices(
MutableArrayRef<int64_t> input_shape, int32_t begin_mask, int32_t end_mask,
int32_t shrink_axis_mask, MutableArrayRef<int64_t> begin,
MutableArrayRef<int64_t> end, MutableArrayRef<int64_t> stride) {
assert(input_shape.size() <= 32); // Only 32-bit masks are supported.
// Make sure ranges' ranks are consistent with the input.
assert(input_shape.size() == begin.size());
assert(input_shape.size() == end.size());
assert(input_shape.size() == stride.size());
for (int i = 0, e = input_shape.size(); i < e; ++i) {
if (ShapedType::isDynamic(input_shape[i])) continue;
int64_t dim_i = input_shape[i];
int64_t begin_i = begin[i];
int64_t end_i = end[i];
int64_t stride_i = stride[i];
// [0]: mask for begin, [1]: mask for end
int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)};
// [0]: bound for begin, [1]: bound for end
int64_t bounds[] = {stride_i > 0 ? 0 : -1,
stride_i > 0 ? dim_i : dim_i - 1};
// Canonicalizes the given range `point` (begin/end) according to the
// current dimension. `c` means case: 0 for begin, 1 for end.
auto canonicalize = [&](int64_t point, int c) {
if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1];
// Add dim as offset to negative range point.
point = point < 0 ? dim_i + point : point;
return Clamp(point, bounds[0], bounds[1]);
};
begin_i = canonicalize(begin_i, 0);
end_i = canonicalize(end_i, 1);
int64_t interval_len = end_i - begin_i;
int64_t size_i = 0;
    // If the interval length is zero or has a different sign from the stride,
    // it's a degenerate case: we are slicing nothing. Otherwise, calculate the
    // sliced size.
if (interval_len != 0 && (interval_len < 0) == (stride_i < 0))
size_i = (interval_len / stride_i) + (interval_len % stride_i != 0);
begin[i] = begin_i;
if (IsSet(shrink_axis_mask, i)) {
// Shrink this dimension. It means we only take the element at begin_i.
input_shape[i] = 1;
end[i] = begin_i + 1;
stride[i] = 1;
} else {
input_shape[i] = size_i;
end[i] = end_i;
stride[i] = stride_i;
}
}
}
// For the given `input_shape`, calculates the sliced shape using the given
// `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`,
// `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks.
// Updates the result back to `input_shape`.
static void CalculateSlicedShapeFromSparseIndices(
MutableArrayRef<int64_t> input_shape, ArrayRef<int64_t> sparse_begin,
ArrayRef<int64_t> sparse_end, ArrayRef<int64_t> sparse_strides,
int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
int32_t new_axis_mask, int32_t shrink_axis_mask,
SmallVectorImpl<int64_t> *begin, SmallVectorImpl<int64_t> *end,
SmallVectorImpl<int64_t> *stride) {
int64_t num_sparse_indices = sparse_begin.size();
SparseSliceSpec sparse = {num_sparse_indices, begin_mask, end_mask,
ellipsis_mask, new_axis_mask, shrink_axis_mask,
sparse_begin, sparse_end, sparse_strides};
// If no ellipsis_mask exists then an implicit ellipsis_mask at the end is
// inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
// a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...].
if (sparse.ellipsis_mask == 0) {
Set(sparse.ellipsis_mask, sparse.dims);
sparse.dims++;
}
int64_t dims = input_shape.size();
DenseSliceSpec dense = {dims,
/*begin_mask = */ 0,
/*end_mask = */ 0,
/*shrink_axis_mask = */ 0,
*begin,
*end,
*stride};
BuildDenseSliceSpec(sparse, &dense);
CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask,
dense.end_mask, dense.shrink_axis_mask,
*begin, *end, *stride);
}
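// Computes the dense, canonicalized begin/end/stride ranges for this
// StridedSlice. Requires the begin, end and strides operands to be constants
// and the input type to have a static shape; returns false otherwise.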
bool StridedSliceOp::GetSlicedBoundRanges(
SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
SmallVectorImpl<int64_t> *slice_stride) {
// TODO(hinsu): Support lowering for ops with dynamic begin and end values
// when it is possible to derive indices based on mask attributes.
DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
!matchPattern(end(), m_Constant(&sparse_end_attr)) ||
!matchPattern(strides(), m_Constant(&sparse_strides_attr)))
return false;
auto input_ty = this->input().getType().dyn_cast<RankedTensorType>();
if (!input_ty || !input_ty.hasStaticShape()) return false;
auto input_shape = llvm::to_vector<4>(input_ty.getShape());
SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
for (const APInt &index : sparse_begin_attr)
sparse_begin.push_back(index.getSExtValue());
for (const APInt &index : sparse_end_attr)
sparse_end.push_back(index.getSExtValue());
for (const APInt &stride : sparse_strides_attr)
sparse_strides.push_back(stride.getSExtValue());
CalculateSlicedShapeFromSparseIndices(
input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
slice_begin, slice_end, slice_stride);
return true;
}
OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
// Fold StridedSlice operation if it extracts statically known dimensions.
//
// For example,
//
// %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
// %height = tf.StridedSlice(%shape, 1, 2, 1)
//
// In this case %height can be replaced with a constant 2.
//
// Or,
//
// %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
// %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
//
// In this case %spatial_shape can be replaced with a constant [2, 3].
// Input to strided slice op is defined by shape operation.
auto shape_op = input().getDefiningOp<ShapeOp>();
if (!shape_op) {
return {};
}
// `begin`, `end` and `strides` should be constant in order to infer static
// dimension.
DenseIntElementsAttr begin_attr, end_attr, strides_attr;
if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
!matchPattern(end(), m_Constant(&end_attr)) ||
!matchPattern(strides(), m_Constant(&strides_attr)) ||
begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
strides_attr.getNumElements() != 1) {
return {};
}
  // Do not fold when `new_axis_mask` is set. It's likely to break the shape
  // of the output. Typically, `new_axis_mask` is not set in this
  // canonicalization pattern.
if (new_axis_mask() != 0) return {};
auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
// Only ranked tensor can be folded.
if (!tensor_ty) return {};
int64_t rank = tensor_ty.getRank();
int64_t begin_int = begin_attr.getValues<APInt>()[0].getSExtValue();
int64_t end_int = end_attr.getValues<APInt>()[0].getSExtValue();
int64_t strides_int = strides_attr.getValues<APInt>()[0].getSExtValue();
// Canonicalize `begin` and `end` in case of negative index.
if (begin_int < 0) begin_int += rank;
if (end_int < 0) end_int += rank;
// Create `begin` and `end` from `*_mask`. Note that we don't care about
// `new_axis_mask` as it can be inferred from `output_ty`.
if (shrink_axis_mask() == 1) {
// When `shrink_axis_mask` is set, output is always a scalar so only
// one element is sliced.
end_int = begin_int + 1;
}
if (begin_mask() == 1) {
begin_int = (strides_int > 0) ? 0 : rank - 1;
}
if (end_mask() == 1) {
end_int = (strides_int > 0) ? rank : -1;
}
if (ellipsis_mask() == 1) {
begin_int = 0;
end_int = rank;
}
// It's possible that `begin` and `end` are out of bound. See
// https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
if (strides_int > 0) {
begin_int = std::min(begin_int, rank);
end_int = std::min(end_int, rank);
} else {
begin_int = std::min(begin_int, rank - 1);
end_int = std::min(end_int, rank - 1);
}
SmallVector<int64_t, 2> sub_shape;
// Only handle cases that have something to slice to avoid infinite for-loop.
if ((end_int > begin_int && strides_int > 0) ||
(end_int < begin_int && strides_int < 0)) {
// Extract sub-shape only if all of those dimensions are static.
for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
i += strides_int) {
if (tensor_ty.isDynamicDim(i)) {
return {};
}
sub_shape.push_back(tensor_ty.getDimSize(i));
}
}
// For unranked or dynamic output, we infer the output type to either a
// scalar or a vector based on `shrink_axis_mask` because we have rejected
// the case of `new_axis_mask` != 0.
auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
auto output_ty = output().getType().dyn_cast<RankedTensorType>();
if (!output_ty || !output_ty.hasStaticShape()) {
if (shrink_axis_mask() == 1) {
output_ty = RankedTensorType::get({}, output_elt_ty);
} else {
output_ty = RankedTensorType::get(
{static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
}
}
// Down-cast to 32 bit int if needed.
if (output_elt_ty.isInteger(32)) {
SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
[](int64_t d) { return static_cast<int32_t>(d); });
return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
}
return DenseIntElementsAttr::get(output_ty, sub_shape);
}
//===----------------------------------------------------------------------===//
// StridedSliceGradOp
//===----------------------------------------------------------------------===//
LogicalResult StridedSliceGradOp::verify() {
StridedSliceGradOp op = *this;
auto shape_type = op.shape().getType().dyn_cast<RankedTensorType>();
if (shape_type && shape_type.getRank() != 1)
return op.emitOpError("'shape' operand must be 1D tensor, but got ")
<< shape_type.getRank() << "D tensor";
if (failed(VerifyStridedSliceBase(op))) return failure();
// TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with
// the sliced type from StridedSlice.
return success();
}
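// Like StridedSliceOp::GetSlicedBoundRanges, but the original input shape is
// reconstructed from the constant `shape` operand rather than an operand
// type. Returns false if any of shape, begin, end or strides is not constant.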
bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
SmallVectorImpl<int64_t> *input_shape,
SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
SmallVectorImpl<int64_t> *slice_stride) {
DenseIntElementsAttr shape_attr;
DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
if (!matchPattern(shape(), m_Constant(&shape_attr)) ||
!matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
!matchPattern(end(), m_Constant(&sparse_end_attr)) ||
!matchPattern(strides(), m_Constant(&sparse_strides_attr)))
return false;
int rank = std::distance(shape_attr.begin(), shape_attr.end());
input_shape->clear();
input_shape->reserve(rank);
for (const APInt &dim : shape_attr)
input_shape->push_back(dim.getSExtValue());
SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
for (const APInt &index : sparse_begin_attr)
sparse_begin.push_back(index.getSExtValue());
for (const APInt &index : sparse_end_attr)
sparse_end.push_back(index.getSExtValue());
for (const APInt &stride : sparse_strides_attr)
sparse_strides.push_back(stride.getSExtValue());
CalculateSlicedShapeFromSparseIndices(
*input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
slice_begin, slice_end, slice_stride);
return true;
}
//===----------------------------------------------------------------------===//
// SummaryWriterOp
//===----------------------------------------------------------------------===//
llvm::SmallVector<ResourceHandleValueAndId, 4>
SummaryWriterOp::GetResourceHandleValueAndIdList(
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
int64_t &next_id) {
llvm::StringRef device = GetDeviceOrEmpty(getOperation());
return {GetResourceHandleValueAndIdBase(container(), shared_name(), device,
writer(), resource_handle_id_map,
next_id)};
}
//===----------------------------------------------------------------------===//
// TPUExecuteOp
//===----------------------------------------------------------------------===//
void TPUExecuteOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.reserve(args().size() + 1);
effects.emplace_back(MemoryEffects::Write::get(),
ResourceEffects::TPUExecute::get());
for (Value value : args()) {
if (value.getType()
.cast<TensorType>()
.getElementType()
.isa<ResourceType>()) {
// Conservatively mark resource handles as read and write, as without
// analyzing TPUCompile, there is not sufficient information to determine
// effects on resources. For the MLIR bridge, this op will never be
// populated with resource handles and tf.TPUExecuteAndUpdateVariables is
// used instead.
effects.emplace_back(MemoryEffects::Read::get(), value,
ResourceEffects::Variable::get());
effects.emplace_back(MemoryEffects::Write::get(), value,
ResourceEffects::Variable::get());
}
}
}
//===----------------------------------------------------------------------===//
// TPUExecuteAndUpdateVariablesOp
//===----------------------------------------------------------------------===//
LogicalResult TPUExecuteAndUpdateVariablesOp::verify() {
TPUExecuteAndUpdateVariablesOp op = *this;
int num_resource_args = 0;
for (Type arg_type : op.args().getTypes())
if (arg_type.cast<TensorType>().getElementType().isa<ResourceType>())
++num_resource_args;
auto check_attr = [&](ArrayAttr indices, llvm::StringRef name,
int min) -> LogicalResult {
if (indices.size() != num_resource_args)
return op.emitOpError()
<< "requires '" << name
<< "' to be the same size as number of resource handles in 'args' "
"("
<< num_resource_args << "), but got " << indices.size();
for (auto entry : llvm::enumerate(indices.getValue())) {
auto int_attr = entry.value().cast<IntegerAttr>();
if (int_attr.getInt() < min)
return op.emitOpError()
<< "requires '" << name << "' to contain values of at least "
<< min << ", but got " << int_attr.getInt() << " at index "
<< entry.index();
}
return success();
};
return failure(
failed(check_attr(op.device_var_reads_indices(),
/*name=*/"device_var_reads_indices", /*min=*/0)) ||
failed(check_attr(op.device_var_updates_indices(),
/*name=*/"device_var_updates_indices", /*min=*/-1)));
}
void TPUExecuteAndUpdateVariablesOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.reserve(device_var_reads_indices().size() + 1);
effects.emplace_back(MemoryEffects::Write::get(),
ResourceEffects::TPUExecute::get());
auto resource_handles = llvm::make_filter_range(args(), [](Value value) {
return value.getType()
.cast<TensorType>()
.getElementType()
.isa<ResourceType>();
});
for (auto &entry : llvm::enumerate(resource_handles)) {
Value value = entry.value();
effects.emplace_back(MemoryEffects::Read::get(), value,
ResourceEffects::Variable::get());
if (device_var_updates_indices()
.getValue()[entry.index()]
.cast<IntegerAttr>()
.getInt() >= 0)
effects.emplace_back(MemoryEffects::Write::get(), value,
ResourceEffects::Variable::get());
}
}
//===----------------------------------------------------------------------===//
// TensorListReserveOp
//===----------------------------------------------------------------------===//
LogicalResult TensorListReserveOp::verify() {
TensorListReserveOp op = *this;
  // This is required to populate derived attributes during export in a
  // meaningful way. Otherwise, exporting to GraphDef would hit an out-of-bounds
  // access/assert in the element_type() query.
if (handle_dtype().getSubtypes().size() != 1) {
return emitOpError(
"must have exactly one subtype in the result variant type");
}
if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
!IsOfRankOrUnranked(op.element_shape(), 1)) {
return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
}
if (!IsOfRankOrUnranked(op.num_elements(), 0)) {
return op.emitOpError("requires num_elements operand to be 0D tensor");
}
return success();
}
//===----------------------------------------------------------------------===//
// TensorListElementShapeOp
//===----------------------------------------------------------------------===//
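// Folds the element shape to a constant when the operand's variant type
// carries a known subtype; the constant's bit width (32 or 64) follows the
// op's result element type.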
OpFoldResult TensorListElementShapeOp::fold(ArrayRef<Attribute> operands) {
int width =
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
auto variant_type =
getElementTypeOrSelf(getOperand().getType()).cast<TF::VariantType>();
if (variant_type.getSubtypes().empty()) return {};
return ConvertShapeToAttr(variant_type.getSubtypes()[0], width);
}
//===----------------------------------------------------------------------===//
// TensorListStackOp
//===----------------------------------------------------------------------===//
LogicalResult TensorListStackOp::verify() {
TensorListStackOp op = *this;
if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
!IsOfRankOrUnranked(op.element_shape(), 1)) {
return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
}
return success();
}
//===----------------------------------------------------------------------===//
// TensorScatterUpdateOp
//===----------------------------------------------------------------------===//
LogicalResult TensorScatterUpdateOp::verify() {
TensorScatterUpdateOp op = *this;
if (!HasRankAtLeast(op.tensor(), 1))
return op.emitOpError(
"requires tensor operand to have at least 1 dimension");
if (!HasRankAtLeast(op.indices(), 1))
return op.emitOpError(
"requires indices operand to have at least 1 dimension");
if (!HasRankAtLeast(op.updates(), 1))
return op.emitOpError(
"requires updates operand to have at least 1 dimension");
auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
if (!tensor_ty || !indices_ty) return success();
int64_t num_index_dims = indices_ty.getShape().back();
if (ShapedType::isDynamic(num_index_dims)) return success();
if (num_index_dims > tensor_ty.getRank())
    return op.emitOpError(
        "requires tensor operand with rank greater than or equal to the "
        "indices operand's last dimension");
return success();
}
//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// - input has at least rank 1
// - multiples is rank 1
// - multiples.size() == input.rank()
// - input.rank() == output.rank()
// - Elements in multiples are non-negative
// - input.shape[i] * multiples[i] == output.shape[i]
// for i in [0, input.rank() - 1]
LogicalResult TileOp::verify() {
TileOp op = *this;
auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
if (multiples_type && multiples_type.getRank() != 1) {
return op.emitOpError() << "expected multiples to be rank 1, got rank = "
<< multiples_type.getRank();
}
if (input_type && multiples_type && multiples_type.hasStaticShape() &&
(input_type.getRank() != multiples_type.getNumElements() ||
(input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
return op.emitOpError()
<< "expected size of multiples equal to rank of input"
<< ", got multiples of size " << multiples_type.getNumElements()
<< ", and input of rank " << input_type.getRank();
}
if (input_type && output_type) {
if (input_type.getRank() != output_type.getRank()) {
return op.emitOpError()
<< "expected rank of input to equal to rank of output"
<< ", got input of rank " << input_type.getRank()
<< ", and output of rank " << output_type.getRank();
}
DenseIntElementsAttr multiples_attr;
if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
const int64_t input_dim = input_type.getDimSize(i);
const int64_t output_dim = output_type.getDimSize(i);
const int64_t m = multiples_attr.getValues<APInt>()[i].getSExtValue();
if (m < 0) {
return op.emitOpError()
<< "expected multiples to be non-negative, got "
<< "multiples[" << i << "] = " << m;
}
if (!ShapedType::isDynamic(input_dim) &&
!ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
return op.emitOpError()
<< "requires input.shape[" << i << "] (" << input_dim << ")"
<< " * " << m << " to be equal to "
<< "output.shape[" << i << "] (" << output_dim << ")";
}
}
}
}
return success();
}
OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
DenseIntElementsAttr multiples_attr;
if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
    // Return the input directly when multiples are all ones, regardless of
    // what the input is.
if (multiples_attr.isSplat() &&
multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
return input();
}
}
return {};
}
//===----------------------------------------------------------------------===//
// TopKV2Op
//===----------------------------------------------------------------------===//
LogicalResult TopKV2Op::verify() {
TopKV2Op op = *this;
if (!HasRankAtLeast(op.input(), 1))
return op.emitOpError(
"requires input operand to have at least 1 dimension");
if (!IsOfRankOrUnranked(op.k(), 0))
return op.emitOpError("requires k operand to be 0D tensor");
return success();
}
//===----------------------------------------------------------------------===//
// ToBoolOp
//===----------------------------------------------------------------------===//
namespace {
// If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded
// into an identity or an equality comparison.
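// For example, a non-scalar input with a zero dimension (e.g. tensor<3x0xf32>)
// folds to a constant false, while a scalar numeric input is rewritten to
// tf.NotEqual(%input, 0).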
class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
using OpRewritePattern<ToBoolOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ToBoolOp op,
PatternRewriter &rewriter) const override {
auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
    // If the input is an unranked tensor, we cannot rewrite.
if (!type) return failure();
// Expected return type of the ToBool operation. The return type of ToBool
// operation is always 0D tensor of bool type.
auto result_type = op.getResult().getType().cast<RankedTensorType>();
// If input is already a tensor<i1>, it can be folded into an identity.
if (type == result_type) {
rewriter.replaceOp(op, op.getOperand());
return success();
}
if (type.getRank() == 0) {
      // If the input is a scalar tensor, the ToBool can be expanded to
      // element != 0 (for numerical values) or element != "" (for strings).
Type element_type = type.getElementType();
Attribute zero_attr;
if (element_type.isIntOrFloat())
zero_attr = rewriter.getZeroAttr(type);
else if (element_type.isa<TF::StringType>())
zero_attr = DenseStringElementsAttr::get(type, {""});
if (!zero_attr) return failure();
auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
op, result_type, op.getOperand(), zero_const, false);
} else {
// If the input is a non-scalar ranked tensor, ToBool can be expanded
// to numElements != 0. numElements will be 0 iff one of the dimensions is
// zero.
bool any_zero =
llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
rewriter.replaceOpWithNewOp<TF::ConstOp>(
op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
}
return success();
}
};
} // namespace
void ToBoolOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ToBoolOfRankedTensor>(context);
}
LogicalResult ToBoolOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(
RankedTensorType::get({}, IntegerType::get(context, 1)));
return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
LogicalResult TransposeOp::verify() {
TransposeOp op = *this;
auto perm_type = op.perm().getType().dyn_cast<RankedTensorType>();
auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
auto y_type = op.y().getType().dyn_cast<RankedTensorType>();
if (perm_type && perm_type.getRank() != 1) {
return op.emitOpError()
<< "expected perm to be a 1-D Tensor, got perm of rank "
<< perm_type.getRank();
}
if (x_type && y_type && x_type.getRank() != y_type.getRank()) {
return op.emitOpError() << "x should be of the same rank with y, got "
<< "x of rank " << x_type.getRank()
<< ", and y of rank " << y_type.getRank();
}
if (!x_type || !y_type || !perm_type || !perm_type.hasStaticShape()) {
return success();
}
if (x_type.getRank() != perm_type.getNumElements()) {
return op.emitOpError() << "expected perm to be a 1-D Tensor of size "
<< "equal to the rank of x, got perm of size "
<< perm_type.getNumElements() << ", and x of rank "
<< x_type.getRank();
}
DenseIntElementsAttr attr_perm;
if (matchPattern(op.perm(), m_Constant(&attr_perm))) {
// y.shape[i] should be equal to x.shape[perm[i]]
// for i = [0, 1, ..., rank(x) - 1]
for (auto e : llvm::enumerate(attr_perm)) {
const int64_t y_idx = e.index();
const int64_t y_dim = y_type.getDimSize(y_idx);
const int64_t x_idx = e.value().getSExtValue();
const int64_t x_dim = x_type.getDimSize(x_idx);
if (!ShapedType::isDynamic(y_dim) && !ShapedType::isDynamic(x_dim) &&
y_dim != x_dim) {
        return op.emitOpError()
               << "requires y.shape[" << y_idx << "] (" << y_dim << ") "
               << "to be equal to x.shape[perm[" << y_idx << "]] "
               << "(" << x_dim << ")";
}
}
}
return success();
}
// TODO(jpienaar): perm could be optional too.
void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
Value perm) {
auto x_type = x.getType().cast<TensorType>();
// If value is unranked, then so is results.
if (!x_type.hasRank())
return TransposeOp::build(builder, result,
UnrankedTensorType::get(x_type.getElementType()),
x, perm);
// TODO(jpienaar): Handle unknown perm case.
// TODO(jpienaar): Extract utility function.
auto etype = x_type.cast<ShapedType>().getElementType();
DenseIntElementsAttr attr_shape;
if (matchPattern(perm, m_Constant(&attr_shape))) {
llvm::SmallVector<int64_t, 4> const_shape;
if (attr_shape.isSplat()) {
const_shape.assign(
attr_shape.getNumElements(),
x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
} else {
const_shape.reserve(attr_shape.getNumElements());
for (const auto &dim : attr_shape)
const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
}
return TransposeOp::build(
builder, result, RankedTensorType::get(const_shape, etype), x, perm);
}
return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x,
perm);
}
namespace {
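// Folds a transpose whose permutation is the identity (e.g. perm = [0, 1, 2])
// back to its input. If the operand and result types differ, the fold only
// applies when all users are in the TF dialect.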
OpFoldResult FoldIdentityTranspose(TransposeOp op) {
DenseIntElementsAttr perm;
if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
const auto elements = perm.getValues<APInt>();
for (auto it : llvm::enumerate(elements)) {
if (it.index() != it.value()) return {};
}
// TODO(jpienaar): Remove if/when we handle this more generally.
if (op.getType() != op.x().getType()) {
    // If the types don't match, then only fold if all the users are in the TF
    // dialect.
for (auto user : op.getOperation()->getUsers())
if (user->getDialect() != op->getDialect()) return {};
}
return op.x();
}
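// Folds a transpose of a transpose when the two constant permutations cancel
// each other, e.g. tf.Transpose(tf.Transpose(%x, [1, 0]), [1, 0]) folds to %x.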
OpFoldResult FoldCancellableTranspose(TransposeOp op) {
// Operand is a TransposeOp.
auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
if (!transpose) return {};
// Permutations defined by constant operations.
DenseIntElementsAttr perm0;
DenseIntElementsAttr perm1;
if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
!matchPattern(transpose.perm(), m_Constant(&perm1)))
return {};
// With permutation indices that cancel each other
if (!AreCancellablePermutations(perm0, perm1)) return {};
return transpose.x();
}
} // namespace
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
if (auto folded = FoldIdentityTranspose(*this)) return folded;
if (auto folded = FoldCancellableTranspose(*this)) return folded;
return {};
}
//===----------------------------------------------------------------------===//
// TruncateDivOp
//===----------------------------------------------------------------------===//
void TruncateDivOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<TruncateDivWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
// NonMaxSuppressionV3Op
//===----------------------------------------------------------------------===//
namespace {
// Canonicalize NonMaxSuppressionV3Op to NonMaxSuppressionV4Op.
class NMSV3ToNMSV4Op : public OpRewritePattern<NonMaxSuppressionV3Op> {
using OpRewritePattern<NonMaxSuppressionV3Op>::OpRewritePattern;
LogicalResult matchAndRewrite(NonMaxSuppressionV3Op nms_op,
PatternRewriter &rewriter) const override {
if (nms_op.getNumOperands() != 5) {
return failure();
}
SmallVector<Type, 2> new_result_types;
new_result_types.push_back(nms_op.getType());
auto input_ty = nms_op.getType().template cast<ShapedType>();
// corresponds to the second result type of nmsv4
RankedTensorType valid_output_type =
RankedTensorType::get({}, input_ty.getElementType());
new_result_types.push_back(valid_output_type);
auto nmsv4 = rewriter.create<TF::NonMaxSuppressionV4Op>(
nms_op.getLoc(), new_result_types, nms_op.boxes(), nms_op.scores(),
nms_op.max_output_size(), nms_op.iou_threshold(),
nms_op.score_threshold());
    // Cannot replace the NMSv3 op with NMSv4 directly since their outputs
    // differ (v4 produces two output values whereas v3 produces only one).
nms_op.replaceAllUsesWith(nmsv4.getResult(0));
return success();
}
};
} // namespace.
void NonMaxSuppressionV3Op::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<NMSV3ToNMSV4Op>(context);
}
//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//
namespace {
class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
PatternRewriter &rewriter) const override {
auto new_result_types =
llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
// reserve_space_3
new_result_types.push_back(
UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
OperationState new_state(tf_fused_batch_norm_op.getLoc(),
TF::FusedBatchNormV3Op::getOperationName(),
tf_fused_batch_norm_op.getOperands(),
new_result_types,
tf_fused_batch_norm_op->getAttrs());
Operation *tf_fused_batch_norm_op_v3 = rewriter.create(new_state);
rewriter.replaceOp(tf_fused_batch_norm_op,
tf_fused_batch_norm_op_v3->getResults().drop_back());
return success();
}
};
} // namespace.
void FusedBatchNormOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConvertFusedBatchNorm>(context);
}
//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//
LogicalResult UnpackOp::verify() {
UnpackOp op = *this;
auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!value_type) return success();
int64_t value_rank = value_type.getRank();
int64_t axis = op.axis();
if (axis < -value_rank || axis >= value_rank)
return op.emitOpError("axis attribute must be in the range of [-")
<< value_rank << ", " << value_rank << ')';
axis = GetDimForAxis(axis, value_rank);
int64_t dim_size = value_type.getDimSize(axis);
if (ShapedType::isDynamic(dim_size)) return success();
if (dim_size != op.getNumResults())
return op.emitOpError("result count must be equal to ") << dim_size;
return success();
}
namespace {
// Hoist coefficient-wise unary operation out of the Unpack op:
//
// %unpacked:N = "tf.Unpack"(%0)
// %neg0 = "tf.Neg"(%unpacked#0)
// %neg1 = "tf.Neg"(%unpacked#1)
// ...
// %negN-1 = "tf.Neg"(%unpacked:N-1)
//
// Rewrite it to:
//
// %neg = "tf.Neg"(%0)
// %unpacked:N = "tf.Unpack"(%neg)
class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern<UnpackOp> {
public:
explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context)
: OpRewritePattern<UnpackOp>(context) {}
LogicalResult matchAndRewrite(UnpackOp op,
PatternRewriter &rewriter) const override;
};
LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite(
UnpackOp op, PatternRewriter &rewriter) const {
auto loc = op.getLoc();
// First unpack user must be coeff-wise unary operation.
Operation *first_user = *op->getUsers().begin();
if (!first_user->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();
// All unpack users must be defined by the op of same kind.
bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) {
return user->getName() == first_user->getName();
});
if (!users_same_op) return failure();
// Pass unpack operand to unary operation.
OperationState new_unary_op_state(loc, first_user->getName().getStringRef(),
op.getOperand(), op.getOperand().getType(),
ArrayRef<NamedAttribute>());
Operation *new_unary_op = rewriter.create(new_unary_op_state);
// Unpack results after applying unary operation.
auto unpack_unary_op = rewriter.create<UnpackOp>(
loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis());
// Bypass all users of the original unpack operation and use `unpack_unary_op`
// results instead.
for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) {
OpResult old_result = std::get<0>(pair); // result of original Unpack
OpResult new_result = std::get<1>(pair); // result of transformed Unpack
for (Operation *user : llvm::make_early_inc_range(old_result.getUsers()))
rewriter.replaceOp(user, ValueRange(new_result));
}
// Erase original unpack operation.
rewriter.eraseOp(op.getOperation());
return success();
}
} // namespace
void UnpackOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<HoistCwiseUnaryOutOfUnpack>(context);
}
//===----------------------------------------------------------------------===//
// Unsorted segment reduction ops
//===----------------------------------------------------------------------===//
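// Verifies that `num_segments` is a scalar, that the segment ids shape is a
// prefix of the data shape, and that a constant `num_segments` is
// non-negative.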
template <class Op>
static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
if (!HasRankAtMost(op.num_segments(), 0))
return op.emitOpError("number of segments should be a 0-D tensor");
auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
auto segment_ids_type =
op.segment_ids().getType().template dyn_cast<RankedTensorType>();
if (data_type && segment_ids_type) {
if (data_type.getRank() < segment_ids_type.getRank())
return op.emitOpError(
"requires segment ids rank to be less than or equal to data's rank");
int index = 0;
for (auto shape_pair :
llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) {
int64_t segment_id_dim = std::get<0>(shape_pair);
int64_t data_dim = std::get<1>(shape_pair);
if (!ShapedType::isDynamic(segment_id_dim) &&
!ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim)
return op.emitOpError(
"requires segment ids shape to be a prefix of data shape, "
"but dimension #")
<< index << " differs: " << segment_id_dim << " vs. "
<< data_dim;
++index;
}
}
DenseIntElementsAttr num_segments_attr;
if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) {
int64_t num_segments = (*num_segments_attr.begin()).getSExtValue();
if (num_segments < 0)
return op.emitOpError("num of segments cannot be negative");
}
return success();
}
LogicalResult UnsortedSegmentMaxOp::verify() {
return VerifyUnsortedSegmentReduction(*this);
}
LogicalResult UnsortedSegmentMinOp::verify() {
return VerifyUnsortedSegmentReduction(*this);
}
LogicalResult UnsortedSegmentProdOp::verify() {
return VerifyUnsortedSegmentReduction(*this);
}
LogicalResult UnsortedSegmentSumOp::verify() {
return VerifyUnsortedSegmentReduction(*this);
}
//===----------------------------------------------------------------------===//
// VarHandleOp
//===----------------------------------------------------------------------===//
LogicalResult VarHandleOp::verify() {
  // VarHandleOp requires the resource handle to supply a single subtype from
  // which to derive the dtype and shape attributes.
if (resource_type().getSubtypes().size() != 1) {
return emitOpError(
"must have exactly one subtype in the result resource type");
}
return success();
}
llvm::SmallVector<ResourceHandleValueAndId, 4>
VarHandleOp::GetResourceHandleValueAndIdList(
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
int64_t &next_id) {
llvm::StringRef device = GetDeviceOrEmpty(getOperation());
return {GetResourceHandleValueAndIdBase(container(), shared_name(), device,
resource(), resource_handle_id_map,
next_id)};
}
//===----------------------------------------------------------------------===//
// VarIsInitializedOp
//===----------------------------------------------------------------------===//
namespace {
/// Erase VarIsInitializedOp operations with no uses. This op has a side effect
/// on resources (read-only), but it can still be deleted if it has zero uses.
struct EraseDeadVarIsInitializedOp
: public OpRewritePattern<VarIsInitializedOp> {
using OpRewritePattern<VarIsInitializedOp>::OpRewritePattern;
LogicalResult matchAndRewrite(VarIsInitializedOp op,
PatternRewriter &rewriter) const override {
if (!op.use_empty()) return failure();
rewriter.eraseOp(op);
return success();
}
};
} // end anonymous namespace.
void VarIsInitializedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<EraseDeadVarIsInitializedOp>(context);
}
//===----------------------------------------------------------------------===//
// VariableOp
//===----------------------------------------------------------------------===//
void VariableOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<VariableToVariableV2>(context);
}
//===----------------------------------------------------------------------===//
// VariableShapeOp
//===----------------------------------------------------------------------===//
LogicalResult VariableShapeOp::verify() {
VariableShapeOp op = *this;
auto input_type = op.input().getType().cast<TensorType>();
if (input_type.hasStaticShape() && input_type.getNumElements() != 1)
return op.emitOpError("requires input to have one resource");
auto resource_type = input_type.getElementType().cast<TF::ResourceType>();
auto subtypes = resource_type.getSubtypes();
switch (subtypes.size()) {
case 1:
return VerifyShapeOperandAndResult(
op, resource_type.getSubtypes().front(), op.getType());
case 0:
return VerifyShapeOperandAndResult(op, Type(), op.getType());
default:
return op.emitOpError(
"requires resource input type to have at most 1 subtype");
}
}
OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
int width =
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
auto resource_type =
getElementTypeOrSelf(getOperand().getType()).cast<TF::ResourceType>();
if (resource_type.getSubtypes().empty()) return {};
return ConvertShapeToAttr(resource_type.getSubtypes()[0], width);
}
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
TypeRange body_input,
TypeRange body_result,
bool shape_invariant) {
const TypeRangeWithDesc input_type = {op->getOperandTypes(), "input"};
const TypeRangeWithDesc result_type = {op->getResultTypes(), "result"};
constexpr int kNumRegionTypeLists = 3;
const std::array<TypeRangeWithDesc, kNumRegionTypeLists> region_types = {{
{body_result, "body result"},
{cond_input, "condition input"},
{body_input, "body input"},
}};
  // A pair of type lists should be cast compatible with each other if one is
  // converted to the other for a function call or assignment, or if there is a
  // common source of inputs for both. Therefore, the While op requires the
  // following pairs of type lists to be cast compatible for the tensor_cast
  // operation:
  //
  // * Operands and cond inputs to call the cond function before the
  //   first iteration.
  // * Operands and body inputs to call the body function for the first
  //   iteration if the cond function returns True or an equivalent result.
  // * Operands and results to assign cond function arguments to op results if
  //   the cond function returns False or an equivalent result. If the op is
  //   shape invariant, this does not hold as shapes can differ.
  // * All three pairs among cond inputs, body inputs and results, as the op
  //   operands are a common source for all three.
  // * Body results and cond inputs to call the cond function for subsequent
  //   iterations. Similarly, body results should be compatible with body
  //   inputs and op results.
  //
  // Note that the operands and body results need not be compatible, as they
  // are never converted from one to the other, nor is there a common source of
  // tensors. The compatibility requirement is not transitive.
if (!shape_invariant &&
failed(VerifyTypeRangesAreCompatible(op, input_type, result_type)))
return failure();
  // Skip the first pair as the While op operands and body function results do
  // not need to be compatible with each other.
for (int i = 1; i < kNumRegionTypeLists; ++i)
if (failed(VerifyTypeRangesAreCompatible(op, input_type, region_types[i])))
return failure();
for (int i = 0; i < kNumRegionTypeLists; ++i)
if (failed(VerifyTypeRangesAreCompatible(op, result_type, region_types[i])))
return failure();
for (int i = 0; i < kNumRegionTypeLists; ++i)
for (int j = i + 1; j < kNumRegionTypeLists; ++j)
if (failed(VerifyTypeRangesAreCompatible(op, region_types[i],
region_types[j])))
return failure();
return success();
}
LogicalResult WhileOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
// TODO(jpienaar): Remove.
if (failed(WhileOpAdaptor(*this).verify(getLoc()))) return failure();
auto cond_fn =
symbol_table.lookupNearestSymbolFrom<func::FuncOp>(*this, condAttr());
auto body_fn =
symbol_table.lookupNearestSymbolFrom<func::FuncOp>(*this, bodyAttr());
if (!cond_fn) {
return emitOpError("cond refers to an undefined function : ") << cond();
}
if (!body_fn) {
return emitOpError("body refers to an undefined function : ") << body();
}
auto cond_fn_type = cond_fn.getFunctionType();
auto body_fn_type = body_fn.getFunctionType();
// Verify that the cond function has exactly one result.
if (cond_fn_type.getNumResults() != 1)
return emitOpError("requires cond function to have exactly one result");
return VerifyWhileTypes(*this, /*cond_input=*/cond_fn_type.getInputs(),
/*body_input=*/body_fn_type.getInputs(),
/*body_result=*/body_fn_type.getResults(),
shape_invariant());
}
//===----------------------------------------------------------------------===//
// WhileRegionOp
//===----------------------------------------------------------------------===//
LogicalResult WhileRegionOp::verify() {
WhileRegionOp op = *this;
// Verify that the condition generates a single tensor<i1> result.
Operation *cond_yield = op.cond().front().getTerminator();
if (cond_yield->getNumOperands() != 1)
return op.emitOpError()
<< "condition should have a single tensor<i1> result";
auto cond_type =
cond_yield->getOperand(0).getType().dyn_cast<RankedTensorType>();
if (!cond_type || !cond_type.getShape().equals({}) ||
!cond_type.getElementType().isInteger(/*width=*/1))
return op.emitOpError()
<< "condition should have a single tensor<i1> result";
Operation *body_yield = op.body().front().getTerminator();
if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(),
/*body_input=*/op.body().getArgumentTypes(),
/*body_result=*/body_yield->getOperandTypes(),
op.shape_invariant())))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// WhileRegionOp LoopLikeOpInterface
//===----------------------------------------------------------------------===//
Region &WhileRegionOp::getLoopBody() { return body(); }
//===----------------------------------------------------------------------===//
// WhileRegionOp canonicalization
//===----------------------------------------------------------------------===//
namespace {
// Eliminate values that pass through the WhileRegionOp body.
struct WhileRegionEliminatePassThrough
: public OpRewritePattern<WhileRegionOp> {
using OpRewritePattern<WhileRegionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(WhileRegionOp while_op,
PatternRewriter &rewriter) const override {
    // Remove any extern values that are explicitly captured and returned. Also
    // replace values that simply pass through the body with extern values. The
    // block arguments of the body and the while op match, so the corresponding
    // cond argument can easily be found.
int old_num_operands = while_op.getNumOperands();
int new_num_operands = old_num_operands;
auto &body_block = while_op.body().front();
auto &cond_block = while_op.cond().front();
auto &yield = *body_block.getTerminator();
// Bit mask indicating which operands will be removed.
llvm::BitVector removed_operand(old_num_operands);
for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
auto body_arg = body_block.getArgument(op_idx);
auto yield_operand = LookThroughIdentity(yield.getOperand(op_idx));
auto while_operand = while_op.getOperand(op_idx);
if (body_arg == yield_operand || while_operand == yield_operand) {
        // Replace the uses of the passthrough value with the while operand in
        // the body and condition regions, as well as the while output (if the
        // types match).
// TODO(jurahul): Use PatternRewriter API for IR modification.
if (body_arg.getType() == while_operand.getType())
body_arg.replaceAllUsesWith(while_operand);
auto cond_arg = cond_block.getArgument(op_idx);
if (cond_arg.getType() == while_operand.getType())
cond_arg.replaceAllUsesWith(while_operand);
auto result = while_op.getResult(op_idx);
if (result.getType() == while_operand.getType())
result.replaceAllUsesWith(while_operand);
}
// Now check if the operand is unused in both regions as well as the
// result. If so, mark it for removal.
if (body_block.getArgument(op_idx).use_empty() &&
cond_block.getArgument(op_idx).use_empty() &&
while_op.getResult(op_idx).use_empty()) {
removed_operand.set(op_idx);
new_num_operands--;
}
}
if (new_num_operands == old_num_operands) return failure();
// Compress the operands, region arguments, and outputs.
SmallVector<Value, 4> new_while_operands;
SmallVector<Type, 4> new_result_types;
new_while_operands.reserve(new_num_operands);
new_result_types.reserve(new_num_operands);
// Build new operands and result type.
for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
if (removed_operand.test(op_idx)) continue;
new_while_operands.push_back(while_op.getOperand(op_idx));
new_result_types.push_back(while_op.getResult(op_idx).getType());
}
// Create the new while operation.
auto new_while_op = rewriter.create<WhileRegionOp>(
while_op.getLoc(), new_result_types, new_while_operands,
while_op->getAttrs());
// Move region bodies to the new while.
rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
new_while_op.cond().end());
rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(),
new_while_op.body().end());
auto &new_cond_block = new_while_op.cond().front();
auto &new_body_block = new_while_op.body().front();
auto &new_yield = *new_body_block.getTerminator();
// Patch up the region bodies and yield.
new_cond_block.eraseArguments(removed_operand);
new_body_block.eraseArguments(removed_operand);
new_yield.eraseOperands(removed_operand);
// Build a vector of new results. Also patch up the region bodies and
// yield.
SmallVector<Value, 4> new_results(old_num_operands);
int next_idx = 0;
for (int op_idx : llvm::seq<int>(0, old_num_operands))
if (!removed_operand.test(op_idx))
new_results[op_idx] = new_while_op.getResult(next_idx++);
rewriter.replaceOp(while_op, new_results);
return success();
}
};
} // anonymous namespace
void WhileRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<WhileRegionEliminatePassThrough>(context);
}
//===----------------------------------------------------------------------===//
// XdivyOp
//===----------------------------------------------------------------------===//
void XdivyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<XdivyWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
// XlaBroadcastHelperOp
//===----------------------------------------------------------------------===//
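// Shape inference for XlaBroadcastHelper. When both operands are ranked and
// broadcast_dims is constant, the lower-ranked operand is inferred to have
// the shape obtained by placing its dimensions at the positions given by
// broadcast_dims within an all-1s shape of the higher rank; the higher-ranked
// operand keeps its shape. For example (illustrative): lhs = tensor<4xf32>,
// rhs = tensor<2x3x4xf32>, broadcast_dims = [2] infers the shapes [1, 1, 4]
// and [2, 3, 4].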
LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents(
MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
auto loc = location ? *location : mlir::UnknownLoc::get(context);
XlaBroadcastHelperOpAdaptor op(operands.getValues(), attributes);
if (failed(op.verify(loc))) {
return failure();
}
Value lhs = op.lhs();
Value rhs = op.rhs();
auto set_unranked_results = [&]() {
inferredReturnShapes.emplace_back(getElementTypeOrSelf(lhs));
inferredReturnShapes.emplace_back(getElementTypeOrSelf(rhs));
return success();
};
RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhs_ty || !rhs_ty) return set_unranked_results();
int64_t lhs_rank = lhs_ty.getRank();
int64_t rhs_rank = rhs_ty.getRank();
DenseIntElementsAttr dims;
if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) {
return set_unranked_results();
}
if (dims.size() == 0) {
if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) {
return emitOptionalError(
location,
"if broadcast_dims is empty, both arguments must have equal rank or "
"at least one argument must be a scalar");
}
inferredReturnShapes.emplace_back(lhs_ty.cast<ShapedType>());
inferredReturnShapes.emplace_back(rhs_ty.cast<ShapedType>());
return success();
}
const bool broadcast_lhs = lhs_rank < rhs_rank;
RankedTensorType min_rank_ty = broadcast_lhs ? lhs_ty : rhs_ty;
RankedTensorType max_rank_ty = broadcast_lhs ? rhs_ty : lhs_ty;
if (dims.size() != min_rank_ty.getRank()) {
return emitOptionalError(
location,
"broadcast_dims must have size equal to the smaller argument rank");
}
int64_t output_rank = max_rank_ty.getRank();
llvm::SmallVector<int64_t, 4> broadcast_shape(output_rank, 1LL);
llvm::SmallVector<bool, 4> is_broadcasted(output_rank, false);
for (auto item : llvm::enumerate(dims)) {
int64_t index = item.index();
int64_t dim = item.value().getSExtValue();
if (dim < 0 || dim >= output_rank) {
return emitOptionalError(location, "out of range broadcast dim");
}
if (is_broadcasted[dim]) {
return emitOptionalError(location, "broadcast_dims has duplicates");
}
broadcast_shape[dim] = min_rank_ty.getDimSize(index);
is_broadcasted[dim] = true;
}
if (broadcast_lhs) {
inferredReturnShapes.emplace_back(broadcast_shape, lhs_ty.getElementType());
inferredReturnShapes.emplace_back(rhs_ty.cast<ShapedType>());
} else {
inferredReturnShapes.emplace_back(lhs_ty.cast<ShapedType>());
inferredReturnShapes.emplace_back(broadcast_shape, rhs_ty.getElementType());
}
return success();
}
//===----------------------------------------------------------------------===//
// XlaSetDynamicDimensionSizeOp
//===----------------------------------------------------------------------===//
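// Shape inference for XlaSetDynamicDimensionSize: the result keeps the input
// shape except that the dimension selected by a constant dim_index becomes
// dynamic; if dim_index is not a constant, every dimension becomes dynamic.
// Unranked inputs yield an unranked result.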
LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents(
MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
auto loc = location ? *location : mlir::UnknownLoc::get(context);
XlaSetDynamicDimensionSizeOpAdaptor op(operands.getValues(), attributes);
if (failed(op.verify(loc))) return failure();
TensorType operand_ty = op.input().getType().cast<TensorType>();
Type element_ty = operand_ty.getElementType();
TensorType result_ty;
if (operand_ty.hasRank()) {
auto shape = llvm::to_vector<4>(operand_ty.getShape());
DenseIntElementsAttr dim_index_attr;
if (matchPattern(op.dim_index(), m_Constant(&dim_index_attr))) {
int64_t dim_index = dim_index_attr.getValues<APInt>()[0].getSExtValue();
int64_t rank = operand_ty.getRank();
if (dim_index < 0 || dim_index >= rank) {
return emitOptionalError(location, "dim_index (", dim_index,
") is out of range [0, ", rank, ")");
}
shape[dim_index] = ShapedType::kDynamicSize;
} else {
shape.assign(shape.size(), ShapedType::kDynamicSize);
}
result_ty = RankedTensorType::get(shape, element_ty);
} else {
result_ty = UnrankedTensorType::get(element_ty);
}
inferredReturnShapes.emplace_back(result_ty.cast<ShapedType>());
return success();
}
//===----------------------------------------------------------------------===//
// XlaReduceOp
//===----------------------------------------------------------------------===//
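// Rewrites XlaReduce to the more general XlaVariadicReduceV2 by wrapping the
// single input and init value into one-element operand lists; the result
// type, dimensions_to_reduce, and reducer are carried over unchanged.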
class XlaReduceToXlaVariadicReduceV2
: public OpRewritePattern<TF::XlaReduceOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TF::XlaReduceOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> inputs{op.input()};
SmallVector<Value> init_values{op.init_value()};
SmallVector<Type> result_types{op.getResult().getType()};
rewriter.replaceOpWithNewOp<TF::XlaVariadicReduceV2Op>(
op, result_types, inputs, init_values, op.dimensions_to_reduce(),
op.reducer());
return ::mlir::success();
}
};
void XlaReduceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<XlaReduceToXlaVariadicReduceV2>(context);
}
//===----------------------------------------------------------------------===//
// XlaReduceWindowOp
//===----------------------------------------------------------------------===//
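// Verifies the XlaReduceWindow op: window_dimensions, window_strides,
// base_dilations, and window_dilations must be rank-1 with as many elements
// as the input rank, and padding must be an Nx2 matrix (each checked only
// when the corresponding operand is constant and the relevant rank is
// known); the reduction computation must take exactly two parameters.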
LogicalResult XlaReduceWindowOp::verify() {
XlaReduceWindowOp op = *this;
const auto &input_ty = op.input().getType().cast<ShapedType>();
auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult {
ElementsAttr attr;
if (matchPattern(val, m_Constant(&attr))) {
if (attr.getType().getRank() != 1) {
return op.emitOpError() << "expects the rank of " << attr_name
<< "to be 1, got " << attr.getType().getRank();
}
if (input_ty.hasRank()) {
int64_t input_rank = input_ty.getRank();
int64_t size = attr.size();
if (input_rank != size) {
return op.emitOpError() << "expects the size of " << attr_name
<< " to be equal to the input "
"rank ("
<< size << " vs. " << input_rank << ")";
}
}
}
return success();
};
if (check(op.window_dimensions(), "window_dimensions").failed())
return failure();
if (check(op.window_strides(), "window_strides").failed()) return failure();
if (check(op.base_dilations(), "base_dilations").failed()) return failure();
if (check(op.window_dilations(), "window_dilations").failed())
return failure();
ElementsAttr padding;
if (matchPattern(op.padding(), m_Constant(&padding))) {
const ShapedType &padding_ty = padding.getType();
if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2) {
return op.emitOpError()
<< "expects padding to be a matrix with minor dimension 2, got "
<< padding.getType().getShape();
}
}
auto module = op->getParentOfType<mlir::ModuleOp>();
auto func = dyn_cast_or_null<mlir::func::FuncOp>(
SymbolTable::lookupSymbolIn(module, op.computation()));
if (!func) {
return op.emitOpError() << "has no reduction function specified";
}
auto func_type = func.getFunctionType();
if (func_type.getNumInputs() != 2) {
return op.emitOpError()
<< "expects reduction function to take 2 parameters, but "
"has "
<< func_type.getNumInputs() << " parameter(s)";
}
return success();
}
//===----------------------------------------------------------------------===//
// XlaSelectAndScatterOp
//===----------------------------------------------------------------------===//
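// Verifies the XlaSelectAndScatter op: window_dimensions and window_strides
// must have as many elements as the operand rank and padding must be an Nx2
// matrix (each checked only when constant); the select function must take
// two parameters and return a single boolean, and the scatter function must
// take two parameters.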
LogicalResult XlaSelectAndScatterOp::verify() {
XlaSelectAndScatterOp op = *this;
auto input_ty = op.operand().getType().cast<ShapedType>();
auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult {
ElementsAttr attr;
if (input_ty.hasRank() && matchPattern(val, m_Constant(&attr))) {
int64_t input_rank = input_ty.getRank();
int64_t size = attr.size();
if (input_rank != size) {
return op.emitOpError() << "expects the size of " << attr_name
<< "to be equal to the input "
"rank ("
<< size << " vs. " << input_rank << ")";
}
}
return success();
};
if (check(op.window_dimensions(), "window_dimensions").failed())
return failure();
if (check(op.window_strides(), "window_strides").failed()) return failure();
ElementsAttr padding;
if (matchPattern(op.padding(), m_Constant(&padding))) {
const ShapedType &padding_ty = padding.getType();
if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2) {
return op.emitOpError()
<< "expects padding to be a matrix with minor dimension 2, got "
<< padding.getType().getShape();
}
}
auto module = op->getParentOfType<mlir::ModuleOp>();
auto select_func = dyn_cast_or_null<mlir::func::FuncOp>(
SymbolTable::lookupSymbolIn(module, op.select()));
if (!select_func) {
return op.emitOpError() << "has no select function specified";
}
auto select_func_type = select_func.getFunctionType();
if (select_func_type.getNumInputs() != 2) {
return op.emitOpError()
<< "expects select function to take 2 parameters, but has "
<< select_func_type.getNumInputs() << " parameter(s)";
}
if (select_func_type.getNumResults() != 1 ||
!getElementTypeOrSelf(select_func_type.getResult(0)).isInteger(1)) {
return op.emitOpError() << "expects select function to return a single "
"boolean result but got "
<< select_func_type.getResult(0);
}
auto scatter_func = dyn_cast_or_null<mlir::func::FuncOp>(
SymbolTable::lookupSymbolIn(module, op.scatter()));
if (!scatter_func) {
return op.emitOpError() << "has no scatter function specified";
}
auto scatter_func_type = scatter_func.getFunctionType();
if (scatter_func_type.getNumInputs() != 2) {
return op.emitOpError()
<< "expects scatter function to take 2 parameters, but has "
<< scatter_func_type.getNumInputs() << " parameter(s)";
}
return success();
}
//===----------------------------------------------------------------------===//
// XlaVariadicReduceOp
//===----------------------------------------------------------------------===//
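// Verifies only that all operands share the same element type; the remaining
// checks are delegated to XlaVariadicReduceV2, into which this op is
// canonicalized (see XlaVariadicReduceToV2 below).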
LogicalResult XlaVariadicReduceOp::verify() {
XlaVariadicReduceOp op = *this;
// We rely on V2 for the majority of the checks.
const auto &input_ty = op.input().getType();
if (input_ty.empty()) return op.emitOpError() << "No input";
const auto &dtype = input_ty[0].cast<TensorType>().getElementType();
for (const auto &ty : input_ty) {
if (ty.cast<TensorType>().getElementType() != dtype)
return op.emitOpError()
<< "This version is limited to operands of the same dtype";
}
return success();
}
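// Rewrites the legacy XlaVariadicReduce to XlaVariadicReduceV2 with the same
// operands, dimensions_to_reduce, and reducer.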
class XlaVariadicReduceToV2 : public OpRewritePattern<TF::XlaVariadicReduceOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TF::XlaVariadicReduceOp op,
PatternRewriter &rewriter) const override {
mlir::TF::XlaVariadicReduceV2Op xla_variadic_reduce_v2_op =
rewriter.create<::mlir::TF::XlaVariadicReduceV2Op>(
op.getLoc(), op.getResults().getTypes(), op.input(),
op.init_value(), op.dimensions_to_reduce(), op.reducer());
rewriter.replaceOp(op, xla_variadic_reduce_v2_op.getResults());
return ::mlir::success();
}
};
void XlaVariadicReduceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<XlaVariadicReduceToV2>(context);
}
//===----------------------------------------------------------------------===//
// XlaVariadicReduceV2Op
//===----------------------------------------------------------------------===//
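// Verifies that inputs is non-empty and has as many elements as init_values,
// that statically shaped inputs all match the shape of inputs[0], that
// dimensions_to_reduce does not exceed the input rank, that every init value
// is a scalar, and that the reducer is a single-block function.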
LogicalResult XlaVariadicReduceV2Op::verify() {
XlaVariadicReduceV2Op op = *this;
const auto &inputs_ty = op.inputs().getType();
int n_inputs = inputs_ty.size();
if (n_inputs < 1) return op.emitOpError() << "No inputs";
const auto &init_values_ty = op.init_values().getType();
int n_init_values = init_values_ty.size();
if (n_init_values != n_inputs) {
return op.emitOpError() << "Number of inputs (" << n_inputs
<< ") is different than number of init_values ("
<< n_init_values << ")";
}
auto input_ty_0 = inputs_ty[0].cast<ShapedType>();
if (input_ty_0.hasStaticShape()) {
for (int i = 1; i < n_inputs; ++i) {
auto input_ty_i = inputs_ty[i].cast<ShapedType>();
if (input_ty_i.hasStaticShape() &&
input_ty_i.getShape() != input_ty_0.getShape()) {
return op.emitOpError()
<< "inputs[" << i << "] has shape [" << input_ty_i.getShape()
<< "] different than the shape of inputs[0]: "
<< input_ty_0.getShape();
}
}
if (op.dimensions_to_reduce().size() > input_ty_0.getRank()) {
return op.emitOpError()
<< "Invalid dimensions_to_reduce argument to XlaVariadicReduceV2";
}
}
for (int i = 0; i < n_inputs; ++i) {
auto init_value_ty_i = init_values_ty[i].cast<ShapedType>();
if (init_value_ty_i.hasRank() && init_value_ty_i.getRank() != 0) {
return op.emitOpError()
<< "init_values[" << i << "] must be a scalar but got ["
<< init_value_ty_i.getShape() << "]";
}
}
auto module = op->getParentOfType<mlir::ModuleOp>();
auto function = dyn_cast_or_null<mlir::func::FuncOp>(
SymbolTable::lookupSymbolIn(module, op.reducer()));
if (!function) return op.emitOpError() << "No reducer";
if (!function.getBody().hasOneBlock())
return op.emitOpError() << "reducer has more than one block";
return success();
}
//===----------------------------------------------------------------------===//
// XlaVariadicSortOp
//===----------------------------------------------------------------------===//
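// Verifies that statically shaped inputs all match the shape of inputs[0],
// that dimension is a scalar (when constant), and that the comparator is a
// single-block function.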
LogicalResult XlaVariadicSortOp::verify() {
XlaVariadicSortOp op = *this;
const auto &inputs_ty = op.inputs().getType();
int n_inputs = inputs_ty.size();
if (n_inputs < 1) return op.emitOpError() << "No inputs";
auto input_ty_0 = inputs_ty[0].cast<ShapedType>();
if (input_ty_0.hasStaticShape()) {
for (int i = 1; i < n_inputs; ++i) {
auto input_ty_i = inputs_ty[i].cast<ShapedType>();
if (input_ty_i.hasStaticShape() &&
input_ty_i.getShape() != input_ty_0.getShape()) {
return op.emitOpError()
<< "input[" << i << "] has shape [" << input_ty_i.getShape()
<< "] different than the shape of input[0]: "
<< input_ty_0.getShape();
}
}
}
ElementsAttr dimension;
if (matchPattern(op.dimension(), m_Constant(&dimension))) {
if (dimension.getType().getRank() != 0 ||
dimension.getType().getNumElements() != 1)
return op.emitOpError() << "dimension must be a scalar";
}
auto module = op->getParentOfType<mlir::ModuleOp>();
auto function = dyn_cast_or_null<mlir::func::FuncOp>(
SymbolTable::lookupSymbolIn(module, op.comparator()));
if (!function) return op.emitOpError() << "No comparator";
if (!function.getBody().hasOneBlock())
return op.emitOpError() << "comparator has more than one block";
return success();
}
//===----------------------------------------------------------------------===//
// SetStaticDimensionBoundsOp
//===----------------------------------------------------------------------===//
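// Verifies that the input has rank at most 2 and that static_shape is a
// rank-1 tensor whose number of elements, when statically known, equals the
// input rank.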
LogicalResult SetStaticDimensionBoundsOp::verify() {
SetStaticDimensionBoundsOp op = *this;
mlir::ShapedType input_type = op.input().getType().cast<mlir::ShapedType>();
mlir::ShapedType static_shape_type =
op.static_shape().getType().cast<mlir::ShapedType>();
int input_type_rank = input_type.hasRank() ? input_type.getRank() : -1;
if (input_type_rank > 2) {
return op.emitOpError() << "was used with an input tensor with rank > 2, "
"only tensors of rank 1,2 are supported";
}
if (static_shape_type.hasRank() && static_shape_type.getRank() != 1) {
return op.emitOpError("static shape must be of rank 1 (vector)");
}
if (input_type_rank != -1 && static_shape_type.hasStaticShape()) {
if (static_shape_type.getShape()[0] != input_type_rank) {
return op.emitOpError(
"static shape must have num_elements == rank of input "
"tensor");
}
}
return success();
}
} // namespace TF
} // namespace mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc"