blob: 737665d51dc0180d8aec95dc6c3f1fd8569fa2a6 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
namespace mlir {
namespace TF {
namespace {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
} // namespace
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<NegNested>(context);
}
//===----------------------------------------------------------------------===//
// NotEqualOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(NotEqualOp op) {
// If we allow inputs to have incompatible type, then nothing to do.
if (!op.incompatible_shape_error()) return success();
// Otherwise, check inputs are broadcastable.
return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
op.getOperation());
}
void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x,
Value y, BoolAttr incompatible_shape_error) {
auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
incompatible_shape_error);
return build(builder, result, result_type, x, y, incompatible_shape_error);
}
//===----------------------------------------------------------------------===//
// OneHotOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(OneHotOp op) {
int64_t axis = op.axis().getSExtValue();
auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
if (indices_ty &&
!(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
return op.emitOpError()
<< "expected axis (" << axis << ") to be -1 or between [0, "
<< indices_ty.getShape().size() << "]";
}
if (axis < -1) {
return op.emitOpError() << "expected axis (" << axis
<< ") to be -1 or between [0, rank(indices()))";
}
if (!IsOfRankOrUnranked(op.depth(), 0)) {
return op.emitOpError() << "requires depth to be a scalar";
}
if (!IsOfRankOrUnranked(op.on_value(), 0)) {
return op.emitOpError() << "requires on_value to be a scalar";
}
if (!IsOfRankOrUnranked(op.off_value(), 0)) {
return op.emitOpError() << "requires off_value to be a scalar";
}
DenseIntElementsAttr depth_attr;
if (matchPattern(op.depth(), m_Constant(&depth_attr))) {
if (depth_attr.getType().getRank() != 0)
return op.emitOpError() << "requires depth to be a scalar";
int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
if (depth < 0) {
return op.emitOpError() << "depth must be non-negative, got: " << depth;
}
}
return success();
}
static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
Value off_value, IntegerAttr axis) {
int64_t axis_val = axis.getInt();
Type element_ty = on_value.getType().cast<TensorType>().getElementType();
auto unranked_ty = UnrankedTensorType::get(element_ty);
if (axis_val < -1) return unranked_ty;
auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
if (!indices_ty) return unranked_ty;
auto shape = llvm::to_vector<2>(indices_ty.getShape());
if (axis_val == -1) axis_val = shape.size();
int64_t depth_val = ShapedType::kDynamicSize;
DenseIntElementsAttr depth_attr;
if (matchPattern(depth, m_Constant(&depth_attr)) &&
depth_attr.getNumElements() == 1)
depth_val = (*depth_attr.begin()).getSExtValue();
shape.insert(shape.begin() + axis_val, depth_val);
return RankedTensorType::get(shape, element_ty);
}
void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices,
Value depth, Value on_value, Value off_value,
IntegerAttr axis) {
build(builder, result,
InferOneHotOpType(indices, depth, on_value, off_value, axis), indices,
depth, on_value, off_value, axis);
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(PackOp op) {
// TODO(hinsu): Convert variadic length attributes to derived attributes.
Operation::operand_range values = op.values();
if (failed(VerifyTypesCompatibility(values,
/*mask_one_dim=*/false,
op.getOperation()))) {
return failure();
}
int64_t inputs_rank = -1;
for (Value value : values) {
if (auto ty = value.getType().dyn_cast<RankedTensorType>()) {
// Exit early as input types are verified to be compatible so all ranked
// tensors have the same rank.
inputs_rank = ty.getRank();
break;
}
}
if (inputs_rank == -1) return success();
// The values can be packed along any of the dimensions between 0 and
// inputs rank, inclusive. Also, as the negative axis values wrap around so
// the axis value range is [-(R+1), R+1).
int64_t range_begin = -inputs_rank - 1; // Inclusive
int64_t range_end = inputs_rank + 1; // Exclusive
int64_t axis = op.axis().getSExtValue();
if (axis < range_begin || axis >= range_end) {
return op.emitError() << "attribute 'axis' should be within range ["
<< range_begin << ", " << range_end
<< "); actual value: " << axis;
}
return success();
}
OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
// Fold pack operation if it computes the input tensor shape:
//
// %shape = tf.Shape(%arg) // [? x ...]
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value
// %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
//
// Where `...` are some statically known dimensions. In this case %pack can be
// replaced with a %shape. This is a common pattern in models with a dynamic
// batch size.
// Pack operation should pack at least two values.
if (values().size() < 2) return {};
// Dimensions packed along axis = 0 (pack scalars into vector).
if (axis().getSExtValue() != 0) return {};
// First packed value is defined by a strided slice operation.
auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
if (!slice_op) return {};
// Input to the slice op is defined by shape operation.
auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
if (!shape_op) return {};
// Input tensor, which shape is reconstructed by the pack operation.
Value tensor = shape_op.input();
// All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
// scalar value from input vector).
if (slice_op.begin_mask().getSExtValue() != 0 ||
slice_op.ellipsis_mask().getSExtValue() != 0 ||
slice_op.end_mask().getSExtValue() != 0 ||
slice_op.new_axis_mask().getSExtValue() != 0 ||
slice_op.shrink_axis_mask().getSExtValue() != 1)
return {};
// Returns a value if the `value` is defined by a ConstOp with a single
// integer element in it and has an expected rank.
auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
if (!const_op) return None;
auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
if (!value_attr || value_attr.getNumElements() != 1) return None;
auto value_ty = value_attr.getType();
if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
auto splat = value_attr.getSplatValue<IntegerAttr>();
return splat.getValue().getSExtValue();
};
// All other packed values are scalar constants.
SmallVector<int64_t, 4> packed_dims;
packed_dims.reserve(values().size() - 1);
for (Value operand : llvm::drop_begin(values(), 1)) {
if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
packed_dims.push_back(*dim);
} else {
return {};
}
}
// Slice exactly the first shape dimension:
// begin = [0] end = [1], strides = [1]
auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
*begin != 0 || *end != 1 || *strides != 1)
return {};
// First tensor dimension is dynamic.
auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
!arg_ty.isDynamicDim(0))
return {};
// Argument tensor rank is equal to the number of packed dimensions.
if (arg_ty.getRank() != values().size()) return {};
// All other dimensions are statically known and equal to packed dims.
auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
return {};
// Replace %pack with %shape.
return slice_op.input();
}
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//
LogicalResult PadOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
// Paddings must be defined by a constant operation.
auto paddings_op = dyn_cast_or_null<TF::ConstOp>(paddings().getDefiningOp());
if (!paddings_op) return failure();
auto paddings_value = paddings_op.value().dyn_cast<DenseElementsAttr>();
if (!paddings_value ||
paddings_value.getNumElements() != permutation.size() * 2)
return failure();
SmallVector<int32_t, 8> shuffled_paddings(paddings_value.getNumElements());
for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) {
size_t outer_idx = index_pair.index() / 2;
size_t inner_idx = index_pair.index() % 2;
shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] =
index_pair.value().getSExtValue();
}
// Add constant operation with a new paddings.
OpBuilder builder(getOperation());
auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(),
builder.getIntegerType(32));
auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings);
auto shuffled_paddings_op = builder.create<TF::ConstOp>(getLoc(), values);
// Use new paddings.
setOperand(1, shuffled_paddings_op);
// Change the result type.
getResult().setType(ShuffleRankedTensorType(getResult().getType(),
ReversePermutation(permutation)));
return success();
}
//===----------------------------------------------------------------------===//
// ParseExampleV2Op
//===----------------------------------------------------------------------===//
static LogicalResult Verify(ParseExampleV2Op op) {
// NOTE(mrry): This validates properties of an op that would previously be
// validated by the TensorFlow OpDef type checker. In addition to these
// checks, the shape inference function for ParseExampleV2 validates the
// consistency of the argument and result types.
// Validate dense variadic input and output lengths.
// NOTE(mrry): The Tdense attr is derived from dense_defaults, so we
// do not need to validate dense_defaults.
auto dense_types_count =
std::distance(op.Tdense().begin(), op.Tdense().end());
auto dense_values_count =
std::distance(op.dense_values().begin(), op.dense_values().end());
if (dense_values_count != dense_types_count) {
return op.emitError() << "output 'dense_values' should have same length "
<< "as attribute 'Tdense'";
}
// Validate sparse variadic output lengths.
// NOTE(mrry): The sparse_types attr is derived from sparse_values, so we
// do not need to validate sparse_values.
auto sparse_types_count =
std::distance(op.sparse_types().begin(), op.sparse_types().end());
if (op.num_sparse() != sparse_types_count) {
return op.emitError() << "attribute 'num_sparse' should be the same as "
<< "the length of attribute 'sparse_types'";
}
if (op.sparse_indices().size() != sparse_types_count) {
return op.emitError() << "output 'sparse_indices' should have same length "
<< "as attribute 'sparse_types'";
}
if (op.sparse_shapes().size() != sparse_types_count) {
return op.emitError() << "output 'sparse_shapes' should have same length "
<< "as attribute 'sparse_types'";
}
// Validate ragged variadic output lengths.
auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(),
op.ragged_value_types().end());
auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(),
op.ragged_split_types().end());
if (ragged_value_types_count != ragged_split_types_count) {
return op.emitError() << "attribute 'ragged_value_types' should have same "
<< "length as attribute 'ragged_split_types'";
}
return success();
}
//===----------------------------------------------------------------------===//
// PartitionedCallOp
//===----------------------------------------------------------------------===//
template <class OpClass>
static LogicalResult VerifyPartitionedCall(OpClass op) {
auto module = op.template getParentOfType<ModuleOp>();
SymbolRefAttr func = op.getAttr("f").template cast<SymbolRefAttr>();
auto function =
dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));
if (!function) {
return op.emitError("'f' attribute refers to an undefined function: ")
<< func;
}
FunctionType function_ty = function.getType();
int func_arg_count = function_ty.getNumInputs();
int arg_count = op.args().size();
if (arg_count != func_arg_count) {
return op.emitError() << "argument count mismatch: 'args' has " << arg_count
<< " arguments, but '" << func << "' expects "
<< func_arg_count;
}
return success();
}
//===----------------------------------------------------------------------===//
// PowOp
//===----------------------------------------------------------------------===//
OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
auto constant_y = operands[1].dyn_cast_or_null<DenseFPElementsAttr>();
if (constant_y && constant_y.isSplat()) {
APFloat y_value = constant_y.getSplatValue<APFloat>();
auto output_type = getType().cast<ShapedType>();
if (y_value.isZero() && output_type.hasStaticShape()) {
return DenseElementsAttr::get(
output_type,
FloatAttr::get(output_type.getElementType(), /*value=*/1.0));
}
if (y_value.isExactlyValue(1.0)) {
return x();
}
}
return {};
}
//===----------------------------------------------------------------------===//
// QrOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// * Input type, if ranked, must have at least 2 dimensions and at most
// INT32_MAX dimensions.
//
static LogicalResult Verify(QrOp op) {
auto ttype = op.input().getType().cast<TensorType>();
if (!ttype.hasRank()) return success();
if (!HasRankAtLeast(op.input(), 2))
return op.emitOpError(
"requires ranked input tensor to be of rank 2 or more");
if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
return op.emitOpError(
"requires ranked input tensor to be of rank INT32_MAX or less");
return success();
}
//===----------------------------------------------------------------------===//
// ReadVariableOp
//===----------------------------------------------------------------------===//
void ReadVariableOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReadVariableOfCast>(context);
}
//===----------------------------------------------------------------------===//
// ReciprocalOp
//===----------------------------------------------------------------------===//
void ReciprocalOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReciprocalNested>(context);
}
//===----------------------------------------------------------------------===//
// RandomUniformOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(RandomUniformOp op) {
if (!IsOfRankOrUnranked(op.shape(), 1))
return op.emitOpError("shape must be 1D tensor");
return success();
}
//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//
void RangeOp::build(OpBuilder &builder, OperationState &result, Value start,
Value limit, Value delta) {
assert(start.getType() == limit.getType());
assert(start.getType() == delta.getType());
DenseIntElementsAttr start_val;
DenseIntElementsAttr limit_val;
DenseIntElementsAttr delta_val;
if (matchPattern(start, m_Constant(&start_val)) &&
matchPattern(limit, m_Constant(&limit_val)) &&
matchPattern(delta, m_Constant(&delta_val))) {
auto size = llvm::APIntOps::RoundingSDiv(
*limit_val.begin() - *start_val.begin(), *delta_val.begin(),
llvm::APInt::Rounding::DOWN);
return RangeOp::build(
builder, result,
RankedTensorType::get(
size.getSExtValue(),
start.getType().cast<TensorType>().getElementType()),
start, limit, delta);
}
return RangeOp::build(
builder, result,
RankedTensorType::get(
{-1}, start.getType().cast<TensorType>().getElementType()),
start, limit, delta);
}
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
void RankOp::build(OpBuilder &builder, OperationState &result, Value input) {
return RankOp::build(builder, result,
RankedTensorType::get({}, builder.getIntegerType(32)),
input);
}
// This will create a constant value for RankOp of a ranked tensor.
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
auto type = input().getType();
auto ranked_type = type.dyn_cast<RankedTensorType>();
if (!ranked_type) return {};
auto output_type = getType().cast<ShapedType>();
int32_t rank = ranked_type.getRank();
return DenseIntElementsAttr::get(output_type, rank);
}
//===----------------------------------------------------------------------===//
// RealDivOp
//===----------------------------------------------------------------------===//
void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
}
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
// TODO(b/128020684): Verify the output type.
static LogicalResult Verify(ReshapeOp op) {
auto shape_type = op.shape().getType().cast<TensorType>();
if (!shape_type.hasRank()) return success();
if (shape_type.getRank() != 1)
return op.emitOpError("shape must be 1D tensor");
auto rank_by_shape = shape_type.getShape()[0];
auto type_of_tensor = op.tensor().getType().cast<TensorType>();
// No compile time verification for unknown sized shape.
if (rank_by_shape == -1 || !type_of_tensor.hasStaticShape()) return success();
int64_t num_by_tensor = type_of_tensor.getNumElements();
auto out_ty = op.getType().dyn_cast<RankedTensorType>();
if (out_ty && out_ty.hasStaticShape()) {
int64_t num_output_elements = out_ty.getNumElements();
if (num_by_tensor != num_output_elements)
return op.emitOpError()
<< "number of output elements (" << num_output_elements
<< ") does not match expected number of elements ("
<< num_by_tensor << ")";
}
// Check values if constant shape. No compiling time verification for
// non-constant shape.
auto *shape_op = op.shape().getDefiningOp();
if (!shape_op) return success();
Attribute shape_cst;
if (!matchPattern(shape_op, m_Constant(&shape_cst))) return success();
auto shape_cst_attr = shape_cst.dyn_cast<ElementsAttr>();
if (!shape_cst_attr) return op.emitOpError("shape must be a valid tensor");
if (auto opaque_attr = shape_cst_attr.dyn_cast<OpaqueElementsAttr>()) {
opaque_attr.decode(shape_cst_attr);
}
// We know the shape is a 1-D Tensor, then let us get the number of
// elements it implies.
unsigned num_by_shape = 1;
unsigned unknown_dim_count = 0;
for (int i = 0, e = rank_by_shape; i != e; ++i) {
auto num = shape_cst_attr.getValue<IntegerAttr>(i).getInt();
// The dimension size value can be -1, and that the real size needs to
// be computed so that the total size remains constant. At most one
// component of shape can be -1.
if (num == -1) {
if (++unknown_dim_count > 1) {
return op.emitOpError("more than one component of shape are -1");
}
} else {
num_by_shape *= num;
}
}
// If there is one component of shape is -1, the dimension should be
// computed so that the total size remains constant.
if (unknown_dim_count == 1) {
if (num_by_tensor % num_by_shape != 0)
return op.emitOpError(
"one component of shape is -1 but couldn't infer the dimension");
return success();
}
// If the elements by the tensor and implies by the shape don't match,
// fail this static check.
if (num_by_tensor != num_by_shape) {
return op.emitOpError(
"mismatch in tensor elements and shape implied elements");
}
return success();
}
void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
Value shape) {
auto ttype = tensor.getType().cast<ShapedType>();
auto etype = ttype.getElementType();
auto unranked = [&builder, etype, &result, shape, tensor]() {
return ReshapeOp::build(builder, result, UnrankedTensorType::get(etype),
tensor, shape);
};
// If tensor is unranked then we have no info about output of shape.
if (!ttype.hasRank()) return unranked();
DenseIntElementsAttr attr_shape;
if (matchPattern(shape, m_Constant(&attr_shape))) {
llvm::SmallVector<int64_t, 4> const_shape;
const_shape.reserve(attr_shape.getNumElements());
// Detect if reshape output shape is folded.
bool flatten = false;
int unknown_index = -1;
// The product of constant shape argument excluding unknown dimension.
int64_t product_cshape = 1;
for (auto e : llvm::enumerate(attr_shape)) {
int64_t val = e.value().getSExtValue();
if (IsUnknownDimOrRank(val)) {
if (flatten) {
mlir::emitError(result.location)
<< "only one unknown dimension allowed";
return;
}
flatten = true;
unknown_index = e.index();
} else {
product_cshape *= val;
}
const_shape.push_back(val);
}
// Compute the value of the unknown dimension.
if (flatten) {
// Compute number of elements in tensor shape.
auto tshape = ttype.getShape();
int64_t product_tshape = std::accumulate(tshape.begin(), tshape.end(), 1,
std::multiplies<int64_t>());
// Set the unknown dimension such that total number of elements remain
// constant.
// Note: The case where the ratio is not integral, and so the total size
// of reshape not constant, is checked in verify function.
const_shape[unknown_index] = product_tshape / product_cshape;
}
return ReshapeOp::build(builder, result,
RankedTensorType::get(const_shape, etype), tensor,
shape);
}
return unranked();
}
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RedundantReshape, ReshapeToSelfShape>(context);
}
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
Value tensor = this->tensor();
// Fold reshape if operand and result types are the same and all dimensions
// are statically known (no-op reshape).
auto result_ty = getType().dyn_cast<ShapedType>();
if (result_ty && result_ty.hasStaticShape() &&
result_ty == tensor.getType()) {
return tensor;
}
return {};
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SelectToSelectV2>(context);
}
// Verifies a few extra requirements on SelectOp:
// (1) `then` and `else` must have same shape
// (2) At least one of the following must be true:
// (a) `cond` has the same rank as `then` and `else`
// (b) `cond` is a scalar
// (c) `cond` is a vector AND `then` and `else` are non-scalar with their
// first dimension equal to `cond`.
static LogicalResult Verify(SelectOp op) {
auto then_tensor = op.t().getType().cast<TensorType>();
auto else_tensor = op.e().getType().cast<TensorType>();
// Check (1).
if (!AreCastCompatible({then_tensor, else_tensor}))
return op.emitOpError() << "requires t and e have compatible shapes";
// Get data rank (if exists).
int data_rank;
// If data is unranked or data_rank is 0, this will remain -2. Otherwise
// refers to first dimension of then and/or else.
int data_first_dim = -2;
bool then_has_rank = then_tensor.hasRank();
bool else_has_rank = else_tensor.hasRank();
if (then_has_rank && else_has_rank) {
data_rank = then_tensor.getRank();
if (then_tensor.getRank() > 0)
data_first_dim = then_tensor.getShape().front();
if (else_tensor.getRank() > 0)
data_first_dim = std::max(
static_cast<int>(else_tensor.getShape().front()), data_first_dim);
} else if (then_has_rank) {
data_rank = then_tensor.getRank();
if (then_tensor.getRank() > 0)
data_first_dim = then_tensor.getShape().front();
} else if (else_has_rank) {
data_rank = else_tensor.getRank();
if (else_tensor.getRank() > 0)
data_first_dim = else_tensor.getShape().front();
} else {
// Neither has a rank.
return success();
}
auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
if (!cond_tensor) return success();
auto cond_rank = cond_tensor.getRank();
// Check (2a) and (2b).
if (cond_rank == 0 || cond_rank == data_rank) return success();
// Check (2c).
if (cond_rank == 1) {
auto cond_shape = cond_tensor.getShape().front();
if (data_rank == 0) {
return op.emitOpError()
<< "requires that t and e are nonscalar when pred is a vector";
}
// We know `data` tensor has a rank of at least 1.
if (data_first_dim != -1 && cond_shape != -1 &&
data_first_dim != cond_shape) {
return op.emitOpError() << "requires that, when pred is a vector, the "
"shape matches the first dimension of t and e";
}
return success();
}
// None of (2a,b,c) were true; fail.
return op.emitOpError() << "requires that pred is a scalar OR has the same "
"rank as t and e OR is a vector";
}
//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//
static Type InferSelectV2OpType(Value condition, Value e, Value t) {
Type element_ty = e.getType().cast<TensorType>().getElementType();
auto unranked_ty = UnrankedTensorType::get(element_ty);
Type broadcasted_ty =
OpTrait::util::getBroadcastedType(e.getType(), t.getType());
if (!broadcasted_ty) return unranked_ty;
auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;
// Explicitly get broadcasted output type as element types of condition may
// not be same as the broadcated type's element type.
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(),
broadcasted_ranked_ty.getShape(),
result_shape))
return unranked_ty;
return RankedTensorType::get(result_shape, element_ty);
}
void SelectV2Op::build(OpBuilder &builder, OperationState &result,
Value condition, Value e, Value t) {
build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t);
}
//===----------------------------------------------------------------------===//
// ShapeOp
//===----------------------------------------------------------------------===//
namespace {
// Validates Shape/ShapeN/VariableShape operand and associated result types.
LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
Type result_type,
int variadic_idx = -1) {
std::string variadic_idx_str =
variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str();
auto result_ranked_type = result_type.dyn_cast<RankedTensorType>();
if (!result_ranked_type) return success();
if (result_ranked_type.getShape().size() != 1)
return op->emitOpError("requires 1D type for result") << variadic_idx_str;
auto operand_ranked_type = operand_type.dyn_cast_or_null<RankedTensorType>();
if (operand_ranked_type) {
// The operand is a ranked tensor.
if (result_ranked_type.hasStaticShape() &&
!operand_ranked_type.getShape().empty() &&
result_ranked_type.getDimSize(0) !=
operand_ranked_type.getShape().size())
return op->emitOpError("requires dimension size of result")
<< variadic_idx_str << " to match rank of operand"
<< variadic_idx_str;
} else if (result_ranked_type.hasStaticShape()) {
// The operand is an unranked tensor, print a warning if the result
// is static.
// Note: We do not handle this situation as an error, this would be too
// restrictive due to incompleteness of shape inference at this point.
op->emitWarning("has static shape result")
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
}
Type element_type = result_ranked_type.getElementType();
if (!element_type.isSignlessInteger(32) &&
!element_type.isSignlessInteger(64))
return op->emitOpError("requires int32 or int64 return type for result")
<< variadic_idx_str;
return success();
}
} // anonymous namespace
static LogicalResult Verify(ShapeOp op) {
return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
}
// Converts shape of the given type to attribute if it is of ranked tensor type.
// Returned attribute has integer elements of the given width.
static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
if (!ranked_ty || !ranked_ty.hasStaticShape()) return {};
auto shape = ranked_ty.getShape();
int rank = shape.size();
SmallVector<APInt, 4> dimensions;
dimensions.reserve(rank);
for (int i = 0; i < rank; ++i)
dimensions.push_back(APInt(out_width, shape[i]));
auto result_type = RankedTensorType::get(
{rank}, IntegerType::get(out_width, input_ty.getContext()));
return DenseElementsAttr::get(result_type, dimensions);
}
OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
int width =
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
return ConvertShapeToAttr(getOperand().getType(), width);
}
void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input,
BoolAttr use32Bit) {
auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
auto out_type = use32Bit.getValue() ? builder.getIntegerType(32)
: builder.getIntegerType(64);
return ShapeOp::build(builder, result,
RankedTensorType::get({rank}, out_type), input);
}
//===----------------------------------------------------------------------===//
// ShapeNOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(ShapeNOp op) {
const size_t num_tensors = op.N();
if (op.getNumOperands() != num_tensors)
return op.emitOpError() << "requires " << num_tensors << " operand(s), got "
<< op.getNumOperands() << " operand(s)";
if (op.getNumResults() != num_tensors)
return op.emitOpError() << "requires " << num_tensors << " result(s), got "
<< op.getNumResults() << " result(s)";
for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
auto verification = VerifyShapeOperandAndResult(
op, op.getOperand(i).getType(), op.getResult(i).getType(), i);
if (failed(verification)) return verification;
}
return success();
}
namespace {
// Canonicalization pattern for ShapeNOp that don't have all
// static input shapes. Replacing output values corresponding to static input
// types may enable optimizations in users of the values.
class ShapeNPartialStaticInputShape : public OpRewritePattern<ShapeNOp> {
using OpRewritePattern<ShapeNOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ShapeNOp op,
PatternRewriter &rewriter) const override {
if (op.getNumOperands() == 0) {
rewriter.eraseOp(op);
return success();
}
int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth();
SmallVector<Value, 4> results(op.getNumOperands());
SmallVector<int64_t, 4> dynamic_indices;
SmallVector<Value, 4> dynamic_inputs;
SmallVector<Type, 4> result_types;
for (auto e : llvm::enumerate(op.getOperands())) {
if (Attribute result = ConvertShapeToAttr(e.value().getType(), width)) {
results[e.index()] = rewriter.create<TF::ConstOp>(op.getLoc(), result);
} else {
dynamic_indices.push_back(e.index());
dynamic_inputs.push_back(e.value());
result_types.push_back(op.getType(e.index()));
}
}
if (dynamic_inputs.size() == op.getNumOperands()) {
// Cannot canonicalize ShapeN if all inputs are dynamic.
return failure();
}
// Create a ShapeNOp for all dynamic inputs.
if (!dynamic_inputs.empty()) {
auto dynamic_shape_n = rewriter.create<TF::ShapeNOp>(
op.getLoc(), result_types, dynamic_inputs);
for (auto index_result :
llvm::zip(dynamic_indices, dynamic_shape_n.getResults())) {
results[std::get<0>(index_result)] = std::get<1>(index_result);
}
}
rewriter.replaceOp(op, results);
return success();
}
};
// Canonicalize ShapeNOp to ShapeOp if there is only one operand.
class ShapeNToShape : public OpRewritePattern<ShapeNOp> {
using OpRewritePattern<ShapeNOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ShapeNOp op,
PatternRewriter &rewriter) const override {
if (op.getNumOperands() != 1) {
return failure();
}
auto shape = rewriter.create<TF::ShapeOp>(op.getLoc(), op.getType(0),
op.getOperand(0));
rewriter.replaceOp(op, {shape});
return success();
}
};
} // namespace
void ShapeNOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
}
//===----------------------------------------------------------------------===//
// SizeOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// * Input type, if is a ranked tensor, has at most INT32_MAX dimensions.
//
static LogicalResult Verify(SizeOp op) {
if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
return op.emitOpError(
"requires ranked input tensor to be of rank INT32_MAX or less");
return success();
}
//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
// Verifies that:
//
// - operands begin and size are 1D with the same number of elements.
// - if the input is a ranked tensor, the rank of the input equals the number
// of elements in operands begin and size.
// - if begin are constants, that
// 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i]
// - if begins aren't constant but the input is a ranked tensor, that
// size[i] <= input_ty.getShape()[i]
//
static LogicalResult Verify(SliceOp op) {
RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
if (begin_ty && begin_ty.getRank() != 1) {
return op.emitOpError() << "requires begin operand to be 1D tensor";
}
RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size());
if (size_ty && size_ty.getRank() != 1) {
return op.emitOpError() << "requires size operand to be 1D tensor";
}
if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() ||
!size_ty.hasStaticShape())
return success();
if (begin_ty.getNumElements() != size_ty.getNumElements()) {
return op.emitOpError() << "requires begin and size operands to have the"
" same number of elements";
}
auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
return op.emitOpError() << "requires number of elements in begin and size"
"are equal to input rank";
}
DenseIntElementsAttr begin_indices;
if (matchPattern(op.begin(), m_Constant(&begin_indices))) {
DenseIntElementsAttr slice_sizes;
bool constant_slice_sizes =
matchPattern(op.size(), m_Constant(&slice_sizes));
int dim = 0;
for (const APInt &raw_begin_index : begin_indices.getValues<APInt>()) {
int64_t begin_index = raw_begin_index.getSExtValue();
int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
int64_t slice_size = constant_slice_sizes
? slice_sizes.getValue<APInt>(dim).getSExtValue()
: 0;
if (slice_size == -1 && input_size != -1) {
slice_size = input_size - begin_index;
}
if (begin_index < 0 ||
(input_size != -1 && begin_index + slice_size > input_size)) {
return op.emitOpError()
<< "requires 0 <= begin[i] <= begin[i] + size[i] <= Di";
}
++dim;
}
} else if (input_ty) {
// If the inputs are ranked, we can do a few more sanity checks.
DenseIntElementsAttr slice_sizes;
if (matchPattern(op.size(), m_Constant(&slice_sizes))) {
auto input_shape = input_ty.getShape();
for (int64_t i = 0; i < input_ty.getRank(); ++i) {
int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
int64_t input_size = input_shape[i];
if (slice_size != -1 && input_size != -1 && slice_size > input_size) {
return op.emitOpError() << "requires size[i] <= Di, even if begin[i] "
"is unknown at compile time";
}
}
}
}
return success();
}
//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SoftmaxOp op) {
if (!HasRankAtLeast(op.logits(), 1)) {
return op.emitOpError("requires operand to have rank at least 1");
}
return success();
}
//===----------------------------------------------------------------------===//
// SoftmaxCrossEntropyWithLogitsOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// * Input types are broadcast compatible and the broadcasted type has rank two.
//
static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
auto broadcasted_ty = OpTrait::util::getBroadcastedType(
op.features().getType(), op.labels().getType())
.dyn_cast_or_null<ShapedType>();
if (!broadcasted_ty ||
(broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
return op.emitOpError(
"requires features and labels to be broadcast compatible to rank two");
return success();
}
//===----------------------------------------------------------------------===//
// SparseSoftmaxCrossEntropyWithLogitsOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) {
if (!IsOfRankOrUnranked(op.features(), 2)) {
return op.emitOpError("requires features operand of rank two");
}
if (!IsOfRankOrUnranked(op.labels(), 1)) {
return op.emitOpError("requires labels operand of rank one");
}
auto features_ty = op.features().getType().dyn_cast<RankedTensorType>();
auto labels_ty = op.labels().getType().dyn_cast<RankedTensorType>();
if (features_ty && labels_ty) {
int64_t features_batches = features_ty.getDimSize(0);
int64_t labels_batches = labels_ty.getDimSize(0);
if (!ShapedType::isDynamic(features_batches) &&
!ShapedType::isDynamic(labels_batches) &&
features_batches != labels_batches)
return op.emitOpError(
"requires features and labels with matching first dimension");
}
return success();
}
//===----------------------------------------------------------------------===//
// SplitOp
//===----------------------------------------------------------------------===//
// Verifies the input and split dimension operands for tf.Split/tf.SplitV.
// Writes the split dimension's index (adjusted with input rank) via `dim_index`
// if it's a constant.
template <class Op>
LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
*dim_index = llvm::None;
Value split_dim = op.split_dim();
if (auto split_dim_type = split_dim.getType().dyn_cast<RankedTensorType>())
if (split_dim_type.getRank() != 0)
return op.emitOpError(
"split dimension should be an integer scalar tensor");
// We can perform further verification if the input tensor to be split has
// known rank and the split dimension tensor is a constant.
auto input_type = op.value().getType().template dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t input_rank = input_type.getRank();
if (input_rank == 0)
return op.emitOpError("cannot split scalar input tensor");
DenseIntElementsAttr split_dim_attr;
if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();
int64_t index = (*split_dim_attr.begin()).getSExtValue();
if (index + input_rank < 0 || index >= input_rank) {
return op.emitOpError("split dimension must be in range [-")
<< input_rank << ", " << input_rank << ")";
}
if (index < 0) index += input_rank;
*dim_index = index;
return success();
}
static LogicalResult Verify(SplitOp op) {
Optional<int64_t> dim_index;
if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
if (!dim_index) return success();
int64_t input_dim_size =
op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
if (input_dim_size == ShapedType::kDynamicSize) return success();
if (input_dim_size % op.getNumResults() != 0)
return op.emitOpError("dimension #")
<< *dim_index << " not divisible by the number of result tensors";
return success();
}
//===----------------------------------------------------------------------===//
// SplitVOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SplitVOp op) {
auto split_sizes_type =
op.size_splits().getType().dyn_cast<RankedTensorType>();
if (!split_sizes_type) return success();
if (split_sizes_type.getRank() != 1 ||
split_sizes_type.getDimSize(0) != op.getNumResults())
return op.emitOpError("split sizes should be a 1D tensor of ")
<< op.getNumResults() << " elements";
Optional<int64_t> dim_index = 0;
if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
if (!dim_index) return success();
int64_t input_dim_size =
op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
if (input_dim_size == ShapedType::kDynamicSize) return success();
// If split sizes come from a constant, they must sum to the dimension size
// along split_dim, and we can have no more than one dynamic dimension.
DenseIntElementsAttr split_sizes_attr;
if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
return success();
int64_t total_dim_size = 0; // Total dimension size assigned to splits
llvm::Optional<int> dynamic_dim_index;
SmallVector<int64_t, 4> split_sizes;
split_sizes.reserve(
split_sizes_attr.getType().cast<ShapedType>().getNumElements());
for (auto dim : llvm::enumerate(split_sizes_attr)) {
int64_t dim_val = dim.value().getSExtValue();
split_sizes.push_back(dim_val);
if (dim_val == ShapedType::kDynamicSize) {
// We cannot have more than one dynamic dimension.
if (dynamic_dim_index)
return op.emitOpError(
"cannot have more than one dynamic dimension in split sizes");
dynamic_dim_index = dim.index();
} else {
total_dim_size += dim_val;
}
}
if (!dynamic_dim_index && total_dim_size != input_dim_size)
return op.emitOpError(
"split sizes must sum up to the dimension size along split "
"dimension, found ")
<< total_dim_size << " vs " << input_dim_size;
if (dynamic_dim_index && total_dim_size > input_dim_size)
return op.emitOpError(
"split sizes must sum up to be less than or equal to the "
"dimension size along split dimension, found ")
<< total_dim_size << " vs " << input_dim_size;
return success();
}
//===----------------------------------------------------------------------===//
// SquareOp
//===----------------------------------------------------------------------===//
void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SquareOfSub>(context);
}
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SubOfNeg>(context);
}
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
return IdentityArithmeticOpFolder<SubOp>(*this, operands);
}
//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//
void SumOp::build(OpBuilder &builder, OperationState &result, Value input,
Value reduction_indices, BoolAttr keep_dims) {
Type out_ty =
InferReductionOpType(input, reduction_indices, keep_dims, &builder);
build(builder, result, out_ty, input, reduction_indices, keep_dims);
}
//===----------------------------------------------------------------------===//
// StridedSliceOp
//===----------------------------------------------------------------------===//
// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to
// tf.SliceOp if both of the following are true:
// - All strides have a known value equal to 1
// - No masks are set (or masks can be applied by transforming the inputs to
// Slice)
// Verifies that,
//
// - begin, end and strides operands are 1D and they have the same number of
// elements. Here, the number of elements should be less than 32 to support
// 32-bit mask attributes.
// - None of the strides values are zero.
// - Ellipsis mask can have at most one bit set.
template <class OpTy>
static LogicalResult VerifyStridedSliceBase(OpTy op) {
// Expected size for operands begin, end and strides vector operands.
int64_t expected_size = -1;
for (Value val : {op.begin(), op.end(), op.strides()}) {
auto operand_ty = val.getType().dyn_cast<ShapedType>();
if (!operand_ty || !operand_ty.hasStaticShape()) {
// TensorFlow constant ops may have non-static shape because the shape is
// not propagated during constant folding. If the defining op for this
// operand is a constant op, use the constant op's attribute to get the
// actual shape.
DenseIntElementsAttr attr;
if (!matchPattern(val, m_Constant(&attr))) continue;
operand_ty = attr.getType();
}
if (operand_ty.getRank() != 1)
return op.emitOpError()
<< "requires begin, end and strides to be 1D tensors";
int64_t length = operand_ty.getDimSize(0);
if (length == -1) continue;
if (expected_size == -1) {
// This op uses 32-bit masks.
if (length >= 32)
return op.emitOpError(
"requires begin, end and strides operands with less than 32 "
"elements");
expected_size = length;
} else if (length != expected_size) {
return op.emitOpError() << "requires begin, end and strides to have the "
"same number of elements";
}
}
// If strides are constants, verify that none of the element is zero.
DenseIntElementsAttr strides;
if (matchPattern(op.strides(), m_Constant(&strides))) {
if (llvm::is_contained(strides.getValues<APInt>(), 0))
return op.emitOpError("requires non-zero strides");
}
// Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there
// exists only no more than one ellipsis.
uint32_t ellipsis_mask = op.ellipsis_mask().getZExtValue();
if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask))
return op.emitOpError("cannot have multiple ellipses");
return success();
}
// Clamps the given `val`: returns `low` if `val` is less than `low`; returns
// `high` if `high` is less than `val`; otherwise returns `val`.
template <class T>
constexpr const T &Clamp(const T &val, const T &low, const T &high) {
assert(!(high < low));
return (val < low) ? low : (high < val) ? high : val;
}
// Checks if the `index` bit of `val` is set.
template <class T>
constexpr bool IsSet(const T &val, unsigned index) {
return (val & (1 << index)) != 0;
}
// Sets the `index` bit of `val`.
template <class T>
constexpr void Set(T &val, unsigned index) {
val |= (1 << index);
}
// Unset the `index` bit of `val`.
template <class T>
constexpr void Unset(T &val, unsigned index) {
val &= ~(1 << index);
}
// Copy the `src_index` bit of `src` to `dst_index` bit of `dst`.
template <class T>
constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
unsigned dst_index) {
if (IsSet(src, src_index))
Set(dst, dst_index);
else
Unset(dst, dst_index);
}
// The sparse spec of strided slice does not correspond to the number of
// dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2,
// 4, 8) would have dims = 2.
struct SparseSliceSpec {
int64_t dims;
int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
const ArrayRef<int64_t> &begin;
const ArrayRef<int64_t> &end;
const ArrayRef<int64_t> &strides;
};
// The dense spec of strided slice is the canonicalized version of sparse spec.
// The number of dimensions of dense spec correspond to the number of dimensions
// in operand tensor.
struct DenseSliceSpec {
int64_t dims;
int32_t begin_mask, end_mask, shrink_axis_mask;
SmallVectorImpl<int64_t> &begin;
SmallVectorImpl<int64_t> &end;
SmallVectorImpl<int64_t> &strides;
};
// Make a sparse spec into a dense index spec.
// The sparse spec does not correspond to the number of dimensions
// Make a dense spec that corresponds to the number of dimensions
//
// For example suppose foo[...,3:, 2] on foo.shape=(2,2,3,4) then
// we need to produce the missing begin_mask, end_mask for the first two
// dimensions i.e. foo[:, :, 3:, 2].
static void BuildDenseSliceSpec(const SparseSliceSpec &sparse,
DenseSliceSpec *dense) {
// Build expanded dense begin, end, strides, begin_mask, end_mask, and
// shrink_axis_mask.
dense->begin.resize(dense->dims);
dense->end.resize(dense->dims);
dense->strides.resize(dense->dims);
dense->begin_mask = 0;
dense->end_mask = 0;
dense->shrink_axis_mask = 0;
// Count number of new_axis after ellipsis. This helps in calculating the
// number of dimensions ellipsis represents in the sparse spec.
bool ellipsis_seen = false;
int num_new_axis_after_ellipsis = 0;
for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index))
num_new_axis_after_ellipsis++;
if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true;
}
int dense_index = 0;
for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
if (IsSet(sparse.new_axis_mask, sparse_index)) continue;
if (IsSet(sparse.ellipsis_mask, sparse_index)) {
auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) +
1 + num_new_axis_after_ellipsis,
dense->dims);
// Expand ellipsis into the appropriate dense indices. From current index
// until next_index, all dimensions would have begin and end masks set and
// stride 1, i.e., get all elements in those dimensions.
for (; dense_index < next_index; ++dense_index) {
dense->begin[dense_index] = dense->end[dense_index] = 0;
dense->strides[dense_index] = 1;
Set(dense->begin_mask, dense_index);
Set(dense->end_mask, dense_index);
}
continue;
}
assert(dense_index < dense->dims);
// Copy over the sparse indices to dense indices if ellipsis_mask and
// new_axis_mask are not set.
dense->begin[dense_index] = sparse.begin[sparse_index];
dense->end[dense_index] = sparse.end[sparse_index];
dense->strides[dense_index] = sparse.strides[sparse_index];
CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index);
CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index);
CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask,
dense_index);
dense_index++;
}
}
// For the given `input_shape`, calculates the sliced shape using the given
// `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and
// `shrink_axis_mask` masks. Updates the result back to `input_shape`. If
// `shrink_axis_mask` is not zero, this function will not drop the corresponding
// dimensions in `input_shape`; it will turn them into 1s. At the same time,
// canonicalizes `begin`, `end`, and `strides. The calculation follows
// tf.StridedSlice op semantics.
static void CalculateSlicedShapeFromDenseIndices(
MutableArrayRef<int64_t> input_shape, int32_t begin_mask, int32_t end_mask,
int32_t shrink_axis_mask, MutableArrayRef<int64_t> begin,
MutableArrayRef<int64_t> end, MutableArrayRef<int64_t> stride) {
assert(input_shape.size() <= 32); // Only 32-bit masks are supported.
// Make sure ranges' ranks are consistent with the input.
assert(input_shape.size() == begin.size());
assert(input_shape.size() == end.size());
assert(input_shape.size() == stride.size());
for (int i = 0, e = input_shape.size(); i < e; ++i) {
if (ShapedType::isDynamic(input_shape[i])) continue;
int64_t dim_i = input_shape[i];
int64_t begin_i = begin[i];
int64_t end_i = end[i];
int64_t stride_i = stride[i];
// [0]: mask for begin, [1]: mask for end
int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)};
// [0]: bound for begin, [1]: bound for end
int64_t bounds[] = {stride_i > 0 ? 0 : -1,
stride_i > 0 ? dim_i : dim_i - 1};
// Canonicalizes the given range `point` (begin/end) according to the
// current dimension. `c` means case: 0 for begin, 1 for end.
auto canonicalize = [&](int64_t point, int c) {
if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1];
// Add dim as offset to negative range point.
point = point < 0 ? dim_i + point : point;
return Clamp(point, bounds[0], bounds[1]);
};
begin_i = canonicalize(begin_i, 0);
end_i = canonicalize(end_i, 1);
int64_t interval_len = end_i - begin_i;
int64_t size_i = 0;
// If internal length is zero or has different sign from stride, it's a
// degenerated case: we are slicing nothing. Otherwise, calculate the sliced
// size.
if (interval_len != 0 && (interval_len < 0) == (stride_i < 0))
size_i = (interval_len / stride_i) + (interval_len % stride_i != 0);
begin[i] = begin_i;
if (IsSet(shrink_axis_mask, i)) {
// Shrink this dimension. It means we only take the element at begin_i.
input_shape[i] = 1;
end[i] = begin_i + 1;
stride[i] = 1;
} else {
input_shape[i] = size_i;
end[i] = end_i;
stride[i] = stride_i;
}
}
}
// For the given `input_shape`, calculates the sliced shape using the given
// `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`,
// `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks.
// Updates the result back to `input_shape`.
static void CalculateSlicedShapeFromSparseIndices(
MutableArrayRef<int64_t> input_shape, ArrayRef<int64_t> sparse_begin,
ArrayRef<int64_t> sparse_end, ArrayRef<int64_t> sparse_strides,
int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
int32_t new_axis_mask, int32_t shrink_axis_mask,
SmallVectorImpl<int64_t> *begin, SmallVectorImpl<int64_t> *end,
SmallVectorImpl<int64_t> *stride) {
int64_t num_sparse_indices = sparse_begin.size();
SparseSliceSpec sparse = {num_sparse_indices, begin_mask, end_mask,
ellipsis_mask, new_axis_mask, shrink_axis_mask,
sparse_begin, sparse_end, sparse_strides};
// If no ellipsis_mask exists then an implicit ellipsis_mask at the end is
// inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
// a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...].
if (sparse.ellipsis_mask == 0) {
Set(sparse.ellipsis_mask, sparse.dims);
sparse.dims++;
}
int64_t dims = input_shape.size();
DenseSliceSpec dense = {dims,
/*begin_mask = */ 0,
/*end_mask = */ 0,
/*shrink_axis_mask = */ 0,
*begin,
*end,
*stride};
BuildDenseSliceSpec(sparse, &dense);
CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask,
dense.end_mask, dense.shrink_axis_mask,
*begin, *end, *stride);
}
bool StridedSliceOp::GetSlicedBoundRanges(
SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
SmallVectorImpl<int64_t> *slice_stride) {
// TODO(hinsu): Support lowering for ops with dynamic begin and end values
// when it is possible to derive indices based on mask attributes.
DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
!matchPattern(end(), m_Constant(&sparse_end_attr)) ||
!matchPattern(strides(), m_Constant(&sparse_strides_attr)))
return false;
auto input_ty = this->input().getType().dyn_cast<RankedTensorType>();
if (!input_ty || !input_ty.hasStaticShape()) return false;
auto input_shape = llvm::to_vector<4>(input_ty.getShape());
SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
for (const APInt &index : sparse_begin_attr)
sparse_begin.push_back(index.getSExtValue());
for (const APInt &index : sparse_end_attr)
sparse_end.push_back(index.getSExtValue());
for (const APInt &stride : sparse_strides_attr)
sparse_strides.push_back(stride.getSExtValue());
CalculateSlicedShapeFromSparseIndices(
input_shape, sparse_begin, sparse_end, sparse_strides,
begin_mask().getZExtValue(), end_mask().getZExtValue(),
ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(),
shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride);
return true;
}
//===----------------------------------------------------------------------===//
// StridedSliceGradOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(StridedSliceGradOp op) {
auto shape_type = op.shape().getType().dyn_cast<RankedTensorType>();
if (shape_type && shape_type.getRank() != 1)
return op.emitOpError("'shape' operand must be 1D tensor, but got ")
<< shape_type.getRank() << "D tensor";
if (failed(VerifyStridedSliceBase(op))) return failure();
// TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with
// the sliced type from StridedSlice.
return success();
}
bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
SmallVectorImpl<int64_t> *input_shape,
SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
SmallVectorImpl<int64_t> *slice_stride) {
DenseIntElementsAttr shape_attr;
DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
if (!matchPattern(shape(), m_Constant(&shape_attr)) ||
!matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
!matchPattern(end(), m_Constant(&sparse_end_attr)) ||
!matchPattern(strides(), m_Constant(&sparse_strides_attr)))
return false;
int rank = std::distance(shape_attr.begin(), shape_attr.end());
input_shape->clear();
input_shape->reserve(rank);
for (const APInt &dim : shape_attr)
input_shape->push_back(dim.getSExtValue());
SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
for (const APInt &index : sparse_begin_attr)
sparse_begin.push_back(index.getSExtValue());
for (const APInt &index : sparse_end_attr)
sparse_end.push_back(index.getSExtValue());
for (const APInt &stride : sparse_strides_attr)
sparse_strides.push_back(stride.getSExtValue());
CalculateSlicedShapeFromSparseIndices(
*input_shape, sparse_begin, sparse_end, sparse_strides,
begin_mask().getZExtValue(), end_mask().getZExtValue(),
ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(),
shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride);
return true;
}
//===----------------------------------------------------------------------===//
// TensorListReserveOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TensorListReserveOp op) {
if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
!IsOfRankOrUnranked(op.element_shape(), 1)) {
return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
}
if (!IsOfRankOrUnranked(op.num_elements(), 0)) {
return op.emitOpError("requires num_elements operand to be 0D tensor");
}
return success();
}
//===----------------------------------------------------------------------===//
// TensorListElementShapeOp
//===----------------------------------------------------------------------===//
OpFoldResult TensorListElementShapeOp::fold(ArrayRef<Attribute> operands) {
int width =
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
auto variant_type =
getElementTypeOrSelf(getOperand().getType()).cast<TF::VariantType>();
if (variant_type.getSubtypes().empty()) return {};
return ConvertShapeToAttr(variant_type.getSubtypes()[0], width);
}
//===----------------------------------------------------------------------===//
// TensorListStackOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TensorListStackOp op) {
if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
!IsOfRankOrUnranked(op.element_shape(), 1)) {
return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
}
return success();
}
//===----------------------------------------------------------------------===//
// TensorScatterUpdateOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TensorScatterUpdateOp op) {
if (!HasRankAtLeast(op.tensor(), 1))
return op.emitOpError(
"requires tensor operand to have at least 1 dimension");
if (!HasRankAtLeast(op.indices(), 1))
return op.emitOpError(
"requires indices operand to have at least 1 dimension");
if (!HasRankAtLeast(op.updates(), 1))
return op.emitOpError(
"requires updates operand to have at least 1 dimension");
auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
if (!tensor_ty || !indices_ty) return success();
int64_t num_index_dims = indices_ty.getShape().back();
if (ShapedType::isDynamic(num_index_dims)) return success();
if (num_index_dims > tensor_ty.getRank())
return op.emitOpError(
"requires tensor operand with rank greater than or equal to the "
"indices operand's last dimensions");
return success();
}
//===----------------------------------------------------------------------===//
// TopKV2Op
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TopKV2Op op) {
if (!HasRankAtLeast(op.input(), 1))
return op.emitOpError(
"requires input operand to have at least 1 dimension");
if (!IsOfRankOrUnranked(op.k(), 0))
return op.emitOpError("requires k operand to be 0D tensor");
return success();
}
//===----------------------------------------------------------------------===//
// ToBoolOp
//===----------------------------------------------------------------------===//
namespace {
// If the input to ToBoolOp is a `tensor<i1>`, then the ToBoolOp is an identity
// function and can be removed.
class ToBoolOfZeroDBoolTensor : public OpRewritePattern<ToBoolOp> {
using OpRewritePattern<ToBoolOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ToBoolOp op,
PatternRewriter &rewriter) const override {
if (auto type = op.getOperand().getType().dyn_cast<RankedTensorType>()) {
if (type.getRank() == 0 && type.getElementType().isInteger(1)) {
rewriter.replaceOp(op, op.getOperand());
return success();
}
}
return failure();
}
};
} // namespace
void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ToBoolOfZeroDBoolTensor>(context);
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TransposeOp op) {
auto perm_type = op.perm().getType().dyn_cast<RankedTensorType>();
auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
auto y_type = op.y().getType().dyn_cast<RankedTensorType>();
if (perm_type && perm_type.getRank() != 1) {
return op.emitOpError()
<< "expected perm to be a 1-D Tensor, got perm of rank "
<< perm_type.getRank();
}
if (x_type && y_type && x_type.getRank() != y_type.getRank()) {
return op.emitOpError() << "x should be of the same rank with y, got "
<< "x of rank " << x_type.getRank()
<< ", and y of rank " << y_type.getRank();
}
if (!x_type || !y_type || !perm_type || !perm_type.hasStaticShape()) {
return success();
}
if (x_type.getRank() != perm_type.getNumElements()) {
return op.emitOpError() << "expected perm to be a 1-D Tensor of size "
<< "equal to the rank of x, got perm of size "
<< perm_type.getNumElements() << ", and x of rank "
<< x_type.getRank();
}
DenseIntElementsAttr attr_perm;
if (matchPattern(op.perm(), m_Constant(&attr_perm))) {
// y.shape[i] should be equal to x.shape[perm[i]]
// for i = [0, 1, ..., rank(x) - 1]
for (auto e : llvm::enumerate(attr_perm)) {
const int64_t y_idx = e.index();
const int64_t y_dim = y_type.getDimSize(y_idx);
const int64_t x_idx = e.value().getSExtValue();
const int64_t x_dim = x_type.getDimSize(x_idx);
if (y_dim != ShapedType::kDynamicSize &&
x_dim != ShapedType::kDynamicSize && y_dim != x_dim) {
return op.emitOpError()
<< "requires y.shape[" << y_idx << "] (" << y_dim << ") "
<< "to be equal to x.shape[perm[" << x_idx << "]] "
<< "(" << x_dim << ")";
}
}
}
return success();
}
// TODO(jpienaar): perm could be optional too.
void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
Value perm) {
auto x_type = x.getType().cast<TensorType>();
// If value is unranked, then so is results.
if (!x_type.hasRank())
return TransposeOp::build(builder, result,
UnrankedTensorType::get(x_type.getElementType()),
x, perm);
// TODO(jpienaar): Handle unknown perm case.
// TODO(jpienaar): Extract utility function.
auto etype = x_type.cast<ShapedType>().getElementType();
DenseIntElementsAttr attr_shape;
if (matchPattern(perm, m_Constant(&attr_shape))) {
llvm::SmallVector<int64_t, 4> const_shape;
if (attr_shape.isSplat()) {
const_shape.assign(
attr_shape.getNumElements(),
x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
} else {
const_shape.reserve(attr_shape.getNumElements());
for (const auto &dim : attr_shape)
const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
}
return TransposeOp::build(
builder, result, RankedTensorType::get(const_shape, etype), x, perm);
}
return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x,
perm);
}
namespace {
OpFoldResult FoldIdentityTranspose(TransposeOp op) {
auto const_perm = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
if (!const_perm) return {};
auto const_value = const_perm.value();
const auto elements = const_value.getValues<APInt>();
for (auto it : llvm::enumerate(elements)) {
if (it.index() != it.value()) return {};
}
// TODO(jpienaar): Remove if/when we handle this more generally.
if (op.getType() != op.x().getType()) {
// If the types don't match then only fold if all the operands are in the TF
// dialect.
for (auto user : op.getOperation()->getUsers())
if (user->getDialect() != op.getDialect()) return {};
}
return op.x();
}
OpFoldResult FoldCancellableTranspose(TransposeOp op) {
// Operand is a TransposeOp.
auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
if (!transpose) return {};
// Permutations defined by constant operations.
auto perm0 = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
auto perm1 = dyn_cast_or_null<TF::ConstOp>(transpose.perm().getDefiningOp());
if (!perm0 || !perm1) return {};
// With permutation indices that cancel each other
auto perm0_value = perm0.value().cast<DenseIntElementsAttr>();
auto perm1_value = perm1.value().cast<DenseIntElementsAttr>();
if (!AreCancellablePermutations(perm0_value, perm1_value)) return {};
return transpose.x();
}
} // namespace
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
if (auto folded = FoldIdentityTranspose(*this)) return folded;
if (auto folded = FoldCancellableTranspose(*this)) return folded;
return {};
}
//===----------------------------------------------------------------------===//
// TruncateDivOp
//===----------------------------------------------------------------------===//
void TruncateDivOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<TruncateDivWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(UnpackOp op) {
auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!value_type) return success();
int64_t value_rank = value_type.getRank();
int64_t axis = op.axis().getSExtValue();
if (axis < -value_rank || axis >= value_rank)
return op.emitOpError("axis attribute must be in the range of [-")
<< value_rank << ", " << value_rank << ')';
axis = GetDimForAxis(axis, value_rank);
int64_t dim_size = value_type.getDimSize(axis);
if (ShapedType::isDynamic(dim_size)) return success();
if (dim_size != op.getNumResults())
return op.emitOpError("result count must be equal to ") << dim_size;
return success();
}
//===----------------------------------------------------------------------===//
// Unsorted segment reduction ops
//===----------------------------------------------------------------------===//
template <class Op>
static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
if (!HasRankAtMost(op.num_segments(), 0))
return op.emitOpError("number of segments should be a 0-D tensor");
auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
auto segment_ids_type =
op.segment_ids().getType().template dyn_cast<RankedTensorType>();
if (data_type && segment_ids_type) {
if (data_type.getRank() < segment_ids_type.getRank())
return op.emitOpError(
"requires segment ids rank to be less than or equal to data's rank");
int index = 0;
for (auto shape_pair :
llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) {
int64_t segment_id_dim = std::get<0>(shape_pair);
int64_t data_dim = std::get<1>(shape_pair);
if (!ShapedType::isDynamic(segment_id_dim) &&
!ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim)
return op.emitOpError(
"requires segment ids shape to be a prefix of data shape, "
"but dimension #")
<< index << " differs: " << segment_id_dim << " vs. "
<< data_dim;
++index;
}
}
DenseIntElementsAttr num_segments_attr;
if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) {
int64_t num_segments = (*num_segments_attr.begin()).getSExtValue();
if (num_segments < 0)
return op.emitOpError("num of segments cannot be negative");
}
return success();
}
//===----------------------------------------------------------------------===//
// VarIsInitializedOp
//===----------------------------------------------------------------------===//
namespace {
/// Erase VarIsInitializedOp operations with no uses. This op has side effect on
/// resources (read-only), but can still be deleted if it has zero uses.
struct EraseDeadVarIsInitializedOp
: public OpRewritePattern<VarIsInitializedOp> {
using OpRewritePattern<VarIsInitializedOp>::OpRewritePattern;
LogicalResult matchAndRewrite(VarIsInitializedOp op,
PatternRewriter &rewriter) const override {
if (!op.use_empty()) return failure();
rewriter.eraseOp(op);
return success();
}
};
} // end anonymous namespace.
void VarIsInitializedOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<EraseDeadVarIsInitializedOp>(context);
}
//===----------------------------------------------------------------------===//
// VariableShapeOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(VariableShapeOp op) {
auto input_type = op.input().getType().cast<TensorType>();
if (input_type.hasStaticShape() && input_type.getNumElements() != 1)
return op.emitOpError("requires input to have one resource");
auto resource_type = input_type.getElementType().cast<TF::ResourceType>();
auto subtypes = resource_type.getSubtypes();
switch (subtypes.size()) {
case 1:
return VerifyShapeOperandAndResult(
op, resource_type.getSubtypes().front(), op.getType());
case 0:
return VerifyShapeOperandAndResult(op, Type(), op.getType());
default:
return op.emitOpError(
"requires resource input type to have at most 1 subtype");
}
}
OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
int width =
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
auto resource_type =
getElementTypeOrSelf(getOperand().getType()).cast<TF::ResourceType>();
if (resource_type.getSubtypes().empty()) return {};
return ConvertShapeToAttr(resource_type.getSubtypes()[0], width);
}
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(WhileOp op) {
auto cond_fn = op.cond_func();
auto body_fn = op.body_func();
if (!cond_fn) {
return op.emitOpError("cond refers to an undefined function : ")
<< op.cond();
}
if (!body_fn) {
return op.emitOpError("body refers to an undefined function : ")
<< op.body();
}
auto cond_fn_type = cond_fn.getType();
auto body_fn_type = body_fn.getType();
// Verify that the cond function has exactly one result.
if (cond_fn_type.getNumResults() != 1)
return op.emitOpError("requires cond function to have exactly one result");
SmallVector<Type, 4> operands(op.getOperandTypes());
// Collect all the type lists for the op so that different pairs of type lists
// can be compared for the compatibility.
constexpr int kNumTypeLists = 5;
const std::array<std::pair<std::string, ArrayRef<Type>>, kNumTypeLists>
type_lists = {{
{"operand", operands},
{"body function result", body_fn_type.getResults()},
{"result", op.getResultTypes()},
{"cond function input", cond_fn_type.getInputs()},
{"body function input", body_fn_type.getInputs()},
}};
// A pair of type lists should be cast compatible with each other if one is
// converted to the another for a function call or assignment or there is a
// common source of inputs for both. Therefore, the While op requires the
// following pairs of type lists to be cast compatible for the tensor_cast
// operation:
//
// * Operands and cond inputs to call the cond function before the
// first iteration.
// * Operands and body inputs to call the body function for the first
// iteration if the cond functions returns True or equivalent result.
// * Operands and results to assign cond function arguments to op results if
// the cond function returns False or equivalent result.
// * All three pairs using cond inputs, body inputs and results as operand is
// a common source for all three.
// * Body result and cond inputs to call the cond function for the subsequent
// iterations. Similarly, Body result should be compatible with body inputs
// and op results.
//
// Note that the operands and body results need not be compatible as they are
// never converted from one to the another nor there is a common source
// tensors. Compatibility requirement is not transitive.
for (int i = 0; i < kNumTypeLists; ++i) {
// Skip the first pair as the While op operands and body function results
// does not need to be compatible with each other.
for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
auto &a = type_lists[i];
auto &b = type_lists[j];
int a_size = a.second.size();
if (a_size != b.second.size())
return op.emitOpError(
llvm::formatv("requires the number of {0}s to be equal to the "
"number of {1}s. Found {2} and {3}, respectively",
a.first, b.first, a_size, b.second.size()));
for (int idx = 0; idx < a_size; ++idx) {
auto a_type = a.second[idx];
auto b_type = b.second[idx];
if (!AreCastCompatible({a_type, b_type}))
return op.emitError(llvm::formatv(
"{0} type {1} is incompatible with {2} type {3} at index {4}",
a.first, a_type, b.first, b_type, idx));
}
}
}
return success();
}
//===----------------------------------------------------------------------===//
// WhileOp canonicalization.
//===----------------------------------------------------------------------===//
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<DropAttributes<WhileOp>>(context);
}
//===----------------------------------------------------------------------===//
// WhileRegionOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(WhileRegionOp op) {
// Verify that the condition generates a single tensor<i1> result.
YieldOp yield = cast<YieldOp>(op.cond().front().getTerminator());
if (yield.getNumOperands() != 1)
return op.emitOpError()
<< "condition should have a single tensor<i1> result";
auto cond_type = yield.getOperand(0).getType().dyn_cast<RankedTensorType>();
if (!cond_type || !cond_type.getShape().equals({}) ||
!cond_type.getElementType().isInteger(/*width=*/1))
return op.emitOpError()
<< "condition should have a single tensor<i1> result";
// The body result types should match while op result types.
if (failed(VerifyRegionResults(op, op.body(), "body"))) return failure();
// Both condition and body should have same number and type of operands as
// the WhileRegion inputs.
const int num_inputs = op.getNumOperands();
auto block_inputs_match_op_inputs = [&](Region &region,
StringRef name) -> LogicalResult {
Block &block = region.front();
if (block.getNumArguments() != num_inputs)
return op.emitOpError()
<< name << " should have same number of inputs (" << num_inputs
<< ") as " << WhileRegionOp::getOperationName() << " but has "
<< block.getNumArguments() << " inputs";
for (auto types_idx : llvm::enumerate(
llvm::zip(op.getOperandTypes(), block.getArgumentTypes()))) {
auto op_input_type = std::get<0>(types_idx.value());
auto block_input_type = std::get<1>(types_idx.value());
if (!AreCastCompatible({block_input_type, op_input_type}))
return op.emitOpError(llvm::formatv(
"{0} input type {1} is incompatible with {2} "
"input type {3} at index {4}",
name, block_input_type, WhileRegionOp::getOperationName(),
op_input_type, types_idx.index()));
}
return success();
};
if (failed(block_inputs_match_op_inputs(op.cond(), "condition")) ||
failed(block_inputs_match_op_inputs(op.body(), "body")))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// WhileRegionOp LoopLikeOpInterface
//===----------------------------------------------------------------------===//
Region &WhileRegionOp::getLoopBody() { return body(); }
bool WhileRegionOp::isDefinedOutsideOfLoop(Value value) {
// If the Op defining the value exists and the defining op is outside the
// scope of this WhileRegion, then we can infer that its defined outside.
// The defining Op is outside the scope of this WhileRegion if this
// WhileRegionOp is not an ancestor of the defining op in the parent chain.
Operation *def_op = value.getDefiningOp();
return def_op && !getOperation()->isAncestor(def_op);
}
LogicalResult WhileRegionOp::moveOutOfLoop(
llvm::ArrayRef<mlir::Operation *> ops) {
// Move the hoisted value to just before the while.
Operation *while_op = this->getOperation();
for (auto op : ops) op->moveBefore(while_op);
return success();
}
//===----------------------------------------------------------------------===//
// WhileRegionOp canonicalization
//===----------------------------------------------------------------------===//
namespace {
// Eliminate values that pass through the WhileRegionOp body.
struct WhileRegionEliminatePassThrough
: public OpRewritePattern<WhileRegionOp> {
using OpRewritePattern<WhileRegionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(WhileRegionOp while_op,
PatternRewriter &rewriter) const override {
// Replace values that simply passthrough the body with extern values. The
// block arguments of body and while match and so the corresponding cond
// argument can be easily found.
int old_num_operands = while_op.getNumOperands();
int new_num_operands = old_num_operands;
auto &body_block = while_op.body().front();
auto &cond_block = while_op.cond().front();
auto &yield = *body_block.getTerminator();
// Bit mask indicating which operands will be removed.
SmallVector<bool, 16> removed_operand(old_num_operands, false);
for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
auto body_arg = body_block.getArgument(op_idx);
if (body_arg == yield.getOperand(op_idx)) {
// Replace the use of the passthrough value with the while operand
// in the body and condition regions, as well as the while output (if
// type match)
// TODO(jurahul): Use PatternRewriter API for IR modification.
auto value = while_op.getOperand(op_idx);
if (body_arg.getType() == value.getType())
body_arg.replaceAllUsesWith(value);
auto cond_arg = cond_block.getArgument(op_idx);
if (cond_arg.getType() == value.getType())
cond_arg.replaceAllUsesWith(value);
auto result = while_op.getResult(op_idx);
if (result.getType() == value.getType())
result.replaceAllUsesWith(value);
}
// Now check if the operand is unused in both regions as well as the
// result. If so, mark it for removal.
if (body_block.getArgument(op_idx).use_empty() &&
cond_block.getArgument(op_idx).use_empty() &&
while_op.getResult(op_idx).use_empty()) {
removed_operand[op_idx] = true;
new_num_operands--;
}
}
if (new_num_operands == old_num_operands) return failure();
// Compress the operands, region arguments, and outputs.
SmallVector<Value, 4> new_while_operands;
SmallVector<Type, 4> new_result_types;
new_while_operands.reserve(new_num_operands);
new_result_types.reserve(new_num_operands);
// Build new operands and result type.
int next_idx = 0;
for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
if (removed_operand[op_idx]) continue;
new_while_operands.push_back(while_op.getOperand(op_idx));
new_result_types.push_back(while_op.getResult(op_idx).getType());
next_idx++;
}
// Create the new while operation.
auto new_while_op =
rewriter.create<WhileRegionOp>(while_op.getLoc(), new_result_types,
new_while_operands, while_op.getAttrs());
// Move region bodies to the new while.
rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
new_while_op.cond().end());
rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(),
new_while_op.body().end());
auto &new_cond_block = new_while_op.cond().front();
auto &new_body_block = new_while_op.body().front();
auto &new_yield = *new_body_block.getTerminator();
// Build a vector of new results. Also patch up the region bodies and yield.
SmallVector<Value, 4> new_results;
next_idx = 0;
for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
if (removed_operand[op_idx]) {
new_cond_block.eraseArgument(next_idx);
new_body_block.eraseArgument(next_idx);
new_yield.eraseOperand(next_idx);
new_results.push_back(nullptr);
} else {
new_results.push_back(new_while_op.getResult(next_idx++));
}
}
rewriter.replaceOp(while_op, new_results);
return success();
}
};
} // anonymous namespace
void WhileRegionOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<WhileRegionEliminatePassThrough>(context);
}
//===----------------------------------------------------------------------===//
// XdivyOp
//===----------------------------------------------------------------------===//
void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<XdivyWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc"
} // namespace TF
} // namespace mlir