blob: d3254bc2ddf7a117571418242dbf8ba1dd6f3906 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include <cstdint>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
namespace TFL {
//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//
// Constructs the dialect under the "tfl" namespace in `context` and registers
// the operations named by the GET_OP_LIST section of tfl_ops.cc.inc.
TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
: Dialect(/*name=*/"tfl", context) {
// The include below expands in place to the template argument list of
// addOperations<>, so it must stay between the angle brackets.
// NOTE(review): the .inc file is presumably TableGen-generated -- confirm.
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
>();
}
//===----------------------------------------------------------------------===//
// Common support logic
//===----------------------------------------------------------------------===//
namespace {
// Reports whether the dimension list `a` is a suffix of the dimension list
// `b`, i.e. whether `b` ends with exactly the entries of `a`.
// For instance {2}, {1, 2} and {3, 1, 2} are suffixes of {5, 4, 3, 1, 2},
// whereas {1}, {5, 4} and {1, 3, 2} are not.
inline bool IsTrailingDimensions(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  // A longer list can never be a suffix of a shorter one. Checking this first
  // also keeps the reverse walk over `b` below within bounds.
  const bool fits = a.size() <= b.size();
  return fits && std::equal(a.rbegin(), a.rend(), b.rbegin());
}
// Reports whether `t` is a shaped type whose element type is f32. A null type
// or a type that is not shaped yields false.
inline bool IsF32ShapedType(Type t) {
  auto shaped = t.dyn_cast_or_null<ShapedType>();
  return shaped && shaped.getElementType().isF32();
}
// Constant-folds `calculate` over two scalar attributes.
// Both `operand1` and `operand2` must be `AttrElementT` attributes whose type
// equals `result_type`; the folded value is wrapped in a new `AttrElementT` of
// that same type.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpScalarScalar(Type result_type, Attribute operand1,
                                        Attribute operand2,
                                        const CalculationT &calculate) {
  auto left = operand1.cast<AttrElementT>();
  auto right = operand2.cast<AttrElementT>();
  assert(left.getType() == result_type && right.getType() == result_type &&
         "values of incompatible types should be caught by op verification");
  // TODO: Need to handle overflow/underflow cases.
  const ElementValueT folded = calculate(left.getValue(), right.getValue());
  return AttrElementT::get(result_type, folded);
}
// TODO: We have multiple functions to handle different attribute kinds in the
// following. Consider adding methods to ElementsAttr to unify these functions.
// Constant-folds `calculate` over the values of two splat attributes.
// `operand1` and `operand2` are the splat *values* (both `AttrElementT`
// attributes), and `result_type` must be a shaped type. The single folded
// element is expanded into a dense attribute of `result_type`. Returns a null
// attribute if the scalar fold produced no result.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpSplatSplat(Type result_type, Attribute operand1,
                                      Attribute operand2,
                                      const CalculationT &calculate) {
  auto shaped_result = result_type.cast<ShapedType>();
  // Fold once at element granularity, then splat the outcome.
  auto folded_element = ConstFoldBinaryOpScalarScalar<AttrElementT>(
      shaped_result.getElementType(), operand1, operand2, calculate);
  if (folded_element) {
    return DenseElementsAttr::get(shaped_result, folded_element);
  }
  return {};
}
/// Constant-folds `calculate` between every element of a dense attribute and
/// the single value of a splat attribute.
/// `operand1` must be a DenseElementsAttr and `operand2` a SplatElementsAttr.
/// Broadcasting is not supported yet: unless both operand types equal
/// `result_type`, a null attribute is returned.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpDenseSplat(Type result_type, Attribute operand1,
                                      Attribute operand2,
                                      const CalculationT &calculate) {
  auto dense_lhs = operand1.cast<DenseElementsAttr>();

  // TODO(b/139192933): Support broadcast behavior
  const bool types_match = dense_lhs.getType() == result_type &&
                           operand2.getType() == result_type;
  if (!types_match) return {};

  auto splat_value = operand2.cast<SplatElementsAttr>().getSplatValue();
  auto shaped_result = result_type.cast<ShapedType>();

  SmallVector<ElementValueT, 16> folded;
  folded.reserve(dense_lhs.rawSize());

  // Combine the splat value with each element of the dense attribute, keeping
  // the dense element as the left-hand argument.
  auto rhs_value = splat_value.cast<AttrElementT>().getValue();
  for (auto element : dense_lhs.getValues<ElementValueT>())
    folded.push_back(calculate(element, rhs_value));

  return DenseElementsAttr::get(shaped_result, folded);
}
/// Constant-folds `calculate` element-wise over two dense attributes.
/// Both `operand1` and `operand2` must be DenseElementsAttr and `result_type`
/// must be a shaped type. When the operand types differ, only the degenerate
/// broadcast where one operand's shape is a perfect suffix of the other's is
/// handled; any other combination returns a null attribute.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpDenseDense(Type result_type, Attribute operand1,
                                      Attribute operand2,
                                      const CalculationT &calculate) {
  auto lhs = operand1.cast<DenseElementsAttr>();
  auto rhs = operand2.cast<DenseElementsAttr>();

  if (lhs.getType() != rhs.getType()) {
    // We only support the case that one of the operand's dimensions are
    // a perfect suffix of the other.
    // TODO: support the general broadcast behavior.
    auto lhs_shape = lhs.getType().getShape();
    auto rhs_shape = rhs.getType().getShape();
    if (!IsTrailingDimensions(lhs_shape, rhs_shape) &&
        !IsTrailingDimensions(rhs_shape, lhs_shape))
      return {};
  }

  auto lhs_num_elements = lhs.getType().getNumElements();
  auto rhs_num_elements = rhs.getType().getNumElements();

  auto type = result_type.cast<ShapedType>();
  auto num_elements = type.getNumElements();

  // We assume the arguments have broadcast-compatible types. Make sure again.
  assert(std::max(lhs_num_elements, rhs_num_elements) == num_elements);
  assert(num_elements % std::min(lhs_num_elements, rhs_num_elements) == 0);

  SmallVector<ElementValueT, 16> lhs_old_values(lhs.getValues<ElementValueT>());
  SmallVector<ElementValueT, 16> rhs_old_values(rhs.getValues<ElementValueT>());
  SmallVector<ElementValueT, 16> new_values;
  new_values.reserve(num_elements);

  // Combine each pair of corresponding values in the dense elements
  // attributes.
  //
  // The smaller operand is tiled an integral number of times to match the
  // larger one. Which operand is smaller does not matter: indexing both with
  // the result index modulo their own element count walks the smaller operand
  // in circles and the larger one exactly once.
  // TODO: support the general broadcast behavior.
  //
  // The indices are 64-bit because the element counts are; a plain `int`
  // counter would overflow or truncate for very large tensors.
  for (int64_t i = 0; i < num_elements; ++i) {
    const int64_t lhs_index = i % lhs_num_elements;
    const int64_t rhs_index = i % rhs_num_elements;
    new_values.push_back(
        calculate(lhs_old_values[lhs_index], rhs_old_values[rhs_index]));
  }

  return DenseElementsAttr::get(type, new_values);
}
/// Dispatches constant folding of `calculate` to the helper matching the
/// attribute kinds of `operand1` and `operand2`:
///   * scalar op scalar (both `AttrElementT`),
///   * splat op splat,
///   * splat op dense (only when `is_commutative`, by swapping the operands),
///   * dense op splat,
///   * dense op dense.
/// Any other combination, including a null operand, yields a null attribute.
/// The operands are expected to have already been verified as broadcastable.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1,
                            Attribute operand2, const CalculationT &calculate,
                            bool is_commutative) {
  // Left operand is a scalar: only scalar op scalar can be folded.
  if (operand1.dyn_cast_or_null<AttrElementT>()) {
    if (!operand2.dyn_cast_or_null<AttrElementT>()) return {};
    return ConstFoldBinaryOpScalarScalar<AttrElementT>(result_type, operand1,
                                                       operand2, calculate);
  }

  // Left operand is a splat. This must be tested before the dense case below.
  if (auto splat_lhs = operand1.dyn_cast_or_null<SplatElementsAttr>()) {
    // Splat op splat case
    if (auto splat_rhs = operand2.dyn_cast_or_null<SplatElementsAttr>())
      return ConstFoldBinaryOpSplatSplat<AttrElementT>(
          result_type, splat_lhs.getSplatValue(), splat_rhs.getSplatValue(),
          calculate);

    // Splat op dense case: only foldable for commutative ops, where the two
    // constants can be swapped to reuse the dense op splat helper.
    const bool rhs_is_dense =
        static_cast<bool>(operand2.dyn_cast_or_null<DenseElementsAttr>());
    if (rhs_is_dense && is_commutative)
      return ConstFoldBinaryOpDenseSplat<AttrElementT>(result_type, operand2,
                                                       operand1, calculate);
    return {};
  }

  // Left operand is a (non-splat) dense attribute.
  if (operand1.dyn_cast_or_null<DenseElementsAttr>()) {
    // Dense op splat case
    if (operand2.dyn_cast_or_null<SplatElementsAttr>())
      return ConstFoldBinaryOpDenseSplat<AttrElementT>(result_type, operand1,
                                                       operand2, calculate);
    // Dense op dense case
    if (operand2.dyn_cast_or_null<DenseElementsAttr>())
      return ConstFoldBinaryOpDenseDense<AttrElementT>(result_type, operand1,
                                                       operand2, calculate);
  }

  // TODO: support other attribute kinds
  return {};
}
/// Constant-folds a binary op over the first two attributes in `operands`.
/// The element type of `result_type` selects the calculation: float element
/// types use `float_calculate`, integer element types use `int_calculate`.
/// Returns a null attribute when `result_type` is not a shaped type, when its
/// element type is neither float nor integer, or when the operands cannot be
/// folded.
Attribute ConstFoldBinaryOp(
    Type result_type, ArrayRef<Attribute> operands,
    llvm::function_ref<APFloat(APFloat, APFloat)> float_calculate,
    llvm::function_ref<APInt(APInt, APInt)> int_calculate,
    bool is_commutative) {
  // Note: All types are wrapped in tensor types in TFlite. E.g., f32 is
  // represented as tensor<f32>. So we are only handling tensor types here.
  auto shaped_result = result_type.dyn_cast<ShapedType>();
  if (!shaped_result) return {};

  Type element_type = shaped_result.getElementType();

  if (element_type.isa<FloatType>()) {
    return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
                                        float_calculate, is_commutative);
  }

  if (element_type.isa<IntegerType>()) {
    return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
                                          int_calculate, is_commutative);
  }

  return {};
}
/// Constant-folds `calculate` over every element of the attribute `operand`
/// and returns the result if possible.
/// Only a DenseElementsAttr operand is folded; anything else (including a
/// null attribute) yields a null attribute.
/// The function currently asserts that `result_type` is an f32 tensor type.
/// TODO: Extend this function to handle integral tensor for ops like
/// "tfl.logical_not".
Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
                           llvm::function_ref<APFloat(APFloat)> calculate) {
  assert(IsF32ShapedType(result_type));
  auto shaped_result = result_type.cast<ShapedType>();

  auto dense_input = operand.dyn_cast_or_null<DenseElementsAttr>();
  if (!dense_input) return {};

  SmallVector<APFloat, 16> folded;
  // Capacity is taken from the result type's element count.
  const int expected_count = shaped_result.getNumElements();
  folded.reserve(expected_count);
  for (APFloat element : dense_input.getValues<APFloat>()) {
    folded.push_back(calculate(element));
  }
  return DenseElementsAttr::get(shaped_result, folded);
}
// Populates `result` for a comparison binary op on `lhs` and `rhs`.
// The two values are added as operands and a single i1 tensor result type is
// appended: it takes the broadcasted shape of the operand types when that
// broadcast is a shaped type, and is the shape-less i1 tensor type otherwise.
// If the operand types are not broadcastable an error is emitted at the
// location of `result`; the op state is still filled in, using the shape-less
// i1 tensor type.
void buildComparisonBinOp(Builder *builder, OperationState *result, Value *lhs,
                          Value *rhs) {
  auto result_type =
      OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
  if (!result_type)
    emitError(result->location)
        << "non-broadcastable operands: " << lhs->getType() << " and "
        << rhs->getType();
  result->addOperands({lhs, rhs});
  // Comparison binary ops always return i1 tensor.
  // `result_type` is null when broadcasting failed above, so the null-tolerant
  // cast is required here; a plain dyn_cast must not be applied to a null
  // type.
  if (auto shaped_type = result_type.dyn_cast_or_null<ShapedType>()) {
    auto resultShape = shaped_type.getShape();
    result->types.push_back(
        builder->getTensorType(resultShape, builder->getI1Type()));
  } else {
    result->types.push_back(builder->getTensorType(builder->getI1Type()));
  }
}
// Populates `result` for a broadcastable binary op that carries a fused
// activation: adds `lhs` and `rhs` as operands, attaches
// `fused_activation_function` under the attribute name of the same spelling,
// and appends the broadcasted type of the two operand types as the result
// type. An error is emitted at the location of `result` when the operand types
// are not broadcastable.
void buildFusedBroadcastableBinOp(Builder *builder, OperationState *result,
                                  Value *lhs, Value *rhs,
                                  StringAttr fused_activation_function) {
  Type lhs_type = lhs->getType();
  Type rhs_type = rhs->getType();

  Type broadcast_type = OpTrait::util::getBroadcastedType(lhs_type, rhs_type);
  if (!broadcast_type) {
    emitError(result->location) << "non-broadcastable operands: " << lhs_type
                                << " and " << rhs_type;
  }

  result->addOperands({lhs, rhs});
  result->addAttribute("fused_activation_function", fused_activation_function);
  result->types.push_back(broadcast_type);
}
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
// Constant-folds an addition of two constant operands. Ops whose fused
// activation function is anything other than "NONE" are not folded.
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
  // Skip fused ops for now.
  const bool has_fused_activation = fused_activation_function() != "NONE";
  if (has_fused_activation) return {};

  auto add_floats = [](APFloat a, APFloat b) { return a + b; };
  auto add_ints = [](APInt a, APInt b) { return a + b; };
  return ConstFoldBinaryOp(getType(), operands, add_floats, add_ints,
                           getOperation()->isCommutative());
}
//===----------------------------------------------------------------------===//
// ConcatenationOp
//===----------------------------------------------------------------------===//
// TODO(ashwinm): Implement shape inference for Concatenation
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
// Builds a tfl.gather op, inferring the result type from `params`, `indices`
// and `axis`.
//
// When both inputs are ranked and `axis` is valid, the result shape is
//   params.shape[:axis] + indices.shape + params.shape[axis + 1:]
// with the element type of `params`. A negative `axis` counts from the end of
// the `params` shape.
//
// When either input is unranked the result is an unranked tensor of the
// `params` element type. The same unranked result is used, after emitting an
// error at the location of `result`, when `axis` does not name a dimension of
// `params`; computing a shape in that case would index outside the shape
// vector.
static void BuildGatherOp(Builder *builder, OperationState *result,
                          Value *params, Value *indices, IntegerAttr axis) {
  auto params_type = params->getType().cast<TensorType>();
  auto indices_type = indices->getType().cast<TensorType>();

  // Shared fallback for every case where no static result shape is computed.
  auto build_unranked = [&]() {
    TFL::GatherOp::build(
        builder, result, builder->getTensorType(params_type.getElementType()),
        params, indices, axis);
  };

  // If params/indices is unranked, then output is unranked.
  if (!params_type.hasRank() || !indices_type.hasRank())
    return build_unranked();

  int64_t params_rank = params_type.getRank();
  int64_t indices_rank = indices_type.getRank();

  // Produces an output tensor with shape:
  // params.shape[:axis] + indices.shape + params.shape[axis + 1:]
  std::vector<int64_t> shape(params_type.getShape());
  int64_t axis_i = axis.getInt();

  // For neg axis values, we wrap around params, e.g. axis = -1 => params[:-1]
  if (axis_i < 0) {
    axis_i += params_rank;
  }

  // params must be atleast rank axis + 1. An axis that is still negative after
  // wrapping is equally out of range. Stop here so that none of the iterator
  // arithmetic below runs with an invalid axis.
  if (axis_i < 0 || params_rank < axis_i + 1) {
    emitError(result->location, "params must be atleast rank axis + 1");
    return build_unranked();
  }

  if (indices_rank == 0) {
    // Scalar indices (output is rank(params) - 1).
    // Erase shape[axis]
    shape.erase(shape.begin() + axis_i);
  } else if (indices_rank == 1) {
    // Vector indices (output is rank(params)).
    // Copy indices.shape into params.shape[axis]
    std::copy(std::begin(indices_type.getShape()),
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  } else {
    // Higher rank indices (output is rank(params) + rank(indices) - 1).
    shape.resize(params_rank + indices_rank - 1);
    // Copy params.shape[axis + 1: ] into shape[axis + indices_rank:]
    std::copy(std::begin(params_type.getShape()) + axis_i + 1,
              std::end(params_type.getShape()),
              std::begin(shape) + axis_i + indices_rank);
    // Copy indices.shape into params.shape[axis]
    std::copy(std::begin(indices_type.getShape()),
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  }

  TFL::GatherOp::build(
      builder, result,
      builder->getTensorType(shape, params_type.getElementType()), params,
      indices, axis);
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
// Constant-folds a multiplication of two constant operands. Ops whose fused
// activation function is anything other than "NONE" are not folded.
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  // Skip fused ops for now.
  const bool has_fused_activation = fused_activation_function() != "NONE";
  if (has_fused_activation) return {};

  auto multiply_floats = [](APFloat a, APFloat b) { return a * b; };
  auto multiply_ints = [](APInt a, APInt b) { return a * b; };
  return ConstFoldBinaryOp(getType(), operands, multiply_floats, multiply_ints,
                           getOperation()->isCommutative());
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
// TODO(b/133486129): Implement shape inference for pack
// Verifies that the number of operands of a tfl.pack op agrees with its
// 'values_count' attribute.
static LogicalResult Verify(PackOp op) {
  // TODO(antiagainst): Implement other checks as in
  // tensorflow/lite/kernels/pack.cc

  const bool count_mismatch =
      op.getOperation()->getNumOperands() != op.values_count();
  if (count_mismatch) {
    return op.emitOpError("input count should match 'values_count' attribute");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
namespace {
/// Rewrite pattern that collapses two back-to-back tfl.reshape ops into one.
/// It applies when the operand of a tfl.reshape is itself produced by another
/// tfl.reshape.
// TODO(antiagainst): This pattern probably should be moved to the peephole
// category, after we have the infra for peephole passes.
struct RemoveAdjacentReshape : public RewritePattern {
  RemoveAdjacentReshape(MLIRContext *context)
      : RewritePattern(ReshapeOp::getOperationName(), 1, context) {}

  // Succeeds when the operand of `op` has a defining op and that op is a
  // tfl.reshape.
  PatternMatchResult match(Operation *op) const override {
    auto outer_reshape = cast<ReshapeOp>(op);
    Operation *producer = outer_reshape.getOperand()->getDefiningOp();
    if (isa_and_nonnull<ReshapeOp>(producer)) return matchSuccess();
    return matchFailure();
  }

  // Replaces
  //   %1 = "tfl.reshape"(%0)
  //   %2 = "tfl.reshape"(%1)
  // with
  //   %2 = "tfl.reshape"(%0)
  // The result of the inner reshape is passed as the leading value list of
  // replaceOpWithNewOp.
  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    auto outer_reshape = cast<ReshapeOp>(op);
    auto inner_reshape =
        cast<ReshapeOp>(outer_reshape.getOperand()->getDefiningOp());

    rewriter.replaceOpWithNewOp<ReshapeOp>({inner_reshape.getResult()}, op,
                                           outer_reshape.getType(),
                                           inner_reshape.getOperand());
  }
};
} // end anonymous namespace
// Folds a tfl.reshape in two situations: an identity reshape (result type
// equals operand type) folds to its operand, and a reshape of a dense
// constant folds to that constant reshaped to the result type.
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  // Remove identity reshape.
  if (getOperand()->getType() == getType()) return getOperand();

  // Constant folding
  assert(operands.size() == 1);
  auto dense_input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  if (!dense_input) return nullptr;

  return dense_input.reshape(getType().cast<ShapedType>());
}
// Adds the canonicalization patterns of tfl.reshape to `results`; currently
// that is only RemoveAdjacentReshape, constructed with `context`.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RemoveAdjacentReshape>(context);
}
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
// Constant-folds a subtraction of two constant operands. Ops whose fused
// activation function is anything other than "NONE" are not folded.
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
  // Skip fused ops for now.
  const bool has_fused_activation = fused_activation_function() != "NONE";
  if (has_fused_activation) return {};

  auto subtract_floats = [](APFloat a, APFloat b) { return a - b; };
  auto subtract_ints = [](APInt a, APInt b) { return a - b; };
  return ConstFoldBinaryOp(getType(), operands, subtract_floats, subtract_ints,
                           getOperation()->isCommutative());
}
//===----------------------------------------------------------------------===//
// TopKOp
//===----------------------------------------------------------------------===//
// Builds a TopKV2 op with two results: one with the element type of `input`
// and one with 32-bit integer elements.
//
// For a ranked `input` of rank >= 1 both results have the shape of `input`
// with the last dimension replaced by k, or by -1 (dynamic) when `k` is not a
// constant. For an unranked `input` both results are unranked. A rank-0
// `input` has no last dimension to replace, so it is given unranked results as
// well instead of indexing an empty shape vector.
static void BuildTopKOp(Builder *builder, OperationState *result, Value *input,
                        Value *k) {
  // Output size is only known if k is constant value. A negative dimension is
  // considered dynamic so use -1 here if k is not a constant value.
  int const_k = -1;
  ElementsAttr cst;
  if (matchPattern(k, m_Constant(&cst))) {
    // These casts should all be valid due to how Tensor constants are stored.
    // TODO(jpienaar): This should use a helper function.
    const_k = cst.getValue({}).cast<IntegerAttr>().getValue().getSExtValue();
  }

  auto val_type = input->getType().cast<TensorType>();

  // If value is unranked, then so is results. Scalars take the same path
  // because the shape rewrite below needs at least one dimension.
  if (!val_type.hasRank() || val_type.getRank() == 0)
    return TFL::TopKV2Op::build(
        builder, result, builder->getTensorType(val_type.getElementType()),
        builder->getTensorType(builder->getIntegerType(32)), input, k);

  // Resultant shape is value.shape[:-1] + [k]
  std::vector<int64_t> shape(val_type.getShape());
  shape.back() = const_k;
  TFL::TopKV2Op::build(
      builder, result, builder->getTensorType(shape, val_type.getElementType()),
      builder->getTensorType(shape, builder->getIntegerType(32)), input, k);
}
//===----------------------------------------------------------------------===//
// FakeQuantOp
//===----------------------------------------------------------------------===//
// Returns true if `op` carries an array attribute named "minmax" that holds
// exactly two entries.
static inline bool HasValidMinMaxAttribute(Operation *op) {
  auto minmax_attr = op->getAttrOfType<ArrayAttr>("minmax");
  if (!minmax_attr) return false;
  return minmax_attr.getValue().size() == 2;
}
namespace {
/// Rewrite pattern that removes a tfl.fake_quant op when the op itself and
/// every user of its result carry a valid "minmax" attribute.
struct DropFakeQuant : public RewritePattern {
  explicit DropFakeQuant(MLIRContext *context)
      : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}

  PatternMatchResult match(Operation *op) const override {
    // The op itself must have a valid "minmax" attribute.
    if (!HasValidMinMaxAttribute(op)) return matchFailure();

    // Every user of the result must have one too; a single user without it
    // blocks the removal.
    auto fake_quant = cast<FakeQuantOp>(op);
    for (auto *user : fake_quant.getResult()->getUsers()) {
      if (!HasValidMinMaxAttribute(user)) return matchFailure();
    }

    return matchSuccess();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    // Replace the matched FakeQuantOp by its primary (first) operand.
    rewriter.replaceOp(op, op->getOperand(0));
  }
};
} // end anonymous namespace
// Adds the canonicalization patterns of tfl.fake_quant to `results`; currently
// that is only DropFakeQuant, constructed with `context`.
void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<DropFakeQuant>(context);
}
//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//
// TODO(b/133486129): Implement shape inference for unpack
// Verifies that the number of results of a tfl.unpack op agrees with its
// 'num' attribute.
static LogicalResult Verify(UnpackOp op) {
  // TODO(antiagainst): Implement other checks as in
  // tensorflow/lite/kernels/unpack.cc

  const bool count_mismatch = op.getOperation()->getNumResults() != op.num();
  if (count_mismatch) {
    return op.emitOpError("output count should match 'num' attribute");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//
// TODO(b/133854225): Implement shape inference to Mean
//===----------------------------------------------------------------------===//
// LSTMOp
//===----------------------------------------------------------------------===//
// Verifies that the op reports exactly two stateful operands and that they
// are operand indices 18 and 19, in that order.
static LogicalResult Verify(LSTMOp op) {
  auto stateful = op.GetStatefulOperands();
  if (stateful.size() != 2 || stateful[0] != 18 || stateful[1] != 19) {
    return op.emitError("LSTMOp expected to have two stateful operands");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// UnidirectionalSequenceLSTMOp
//===----------------------------------------------------------------------===//
// Verifies that the op reports exactly two stateful operands and that they
// are operand indices 18 and 19, in that order.
static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) {
  auto stateful = op.GetStatefulOperands();
  if (stateful.size() != 2 || stateful[0] != 18 || stateful[1] != 19) {
    return op.emitError(
        "UnidirectionalSequenceLSTMOp expected to have two stateful operands");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// UnidirectionalSequenceRNNOp
//===----------------------------------------------------------------------===//
// Verifies that the op reports exactly one stateful operand and that it is
// operand index 4.
static LogicalResult Verify(UnidirectionalSequenceRNNOp op) {
  auto stateful = op.GetStatefulOperands();
  if (stateful.size() != 1 || stateful[0] != 4) {
    return op.emitError(
        "UnidirectionalSequenceRNNOp expected to have one stateful operand");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// AbsOp
//===----------------------------------------------------------------------===//
// Constant-folds the element-wise absolute value of a constant operand.
OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
  Type result_type = getType();
  // Only constant fold for tensor of f32 is implemented.
  if (IsF32ShapedType(result_type)) {
    return ConstFoldUnaryOp(
        result_type, operands[0],
        [](APFloat value) -> APFloat { return llvm::abs(value); });
  }
  return nullptr;
}
//===----------------------------------------------------------------------===//
// SinOp
//===----------------------------------------------------------------------===//
// Constant-folds the element-wise sine of a constant operand. Each element is
// converted to a host `float`, passed through std::sin and wrapped back into
// an APFloat.
OpFoldResult SinOp::fold(ArrayRef<Attribute> operands) {
  Type result_type = getType();
  // Only constant fold for tensor of f32 is implemented.
  if (IsF32ShapedType(result_type)) {
    return ConstFoldUnaryOp(result_type, operands[0],
                            [](APFloat value) -> APFloat {
                              const float input = value.convertToFloat();
                              return APFloat(std::sin(input));
                            });
  }
  return nullptr;
}
//===----------------------------------------------------------------------===//
// CosOp
//===----------------------------------------------------------------------===//
// Constant-folds the element-wise cosine of a constant operand. Each element
// is converted to a host `float`, passed through std::cos and wrapped back
// into an APFloat.
OpFoldResult CosOp::fold(ArrayRef<Attribute> operands) {
  Type result_type = getType();
  // Only constant fold for tensor of f32 is implemented.
  if (IsF32ShapedType(result_type)) {
    return ConstFoldUnaryOp(result_type, operands[0],
                            [](APFloat value) -> APFloat {
                              const float input = value.convertToFloat();
                              return APFloat(std::cos(input));
                            });
  }
  return nullptr;
}
//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//
// Constant-folds the element-wise natural logarithm of a constant operand.
// Each element is converted to a host `float`, passed through std::log and
// wrapped back into an APFloat.
OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
  Type result_type = getType();
  // Only constant fold for tensor of f32 is implemented.
  if (IsF32ShapedType(result_type)) {
    return ConstFoldUnaryOp(result_type, operands[0],
                            [](APFloat value) -> APFloat {
                              const float input = value.convertToFloat();
                              return APFloat(std::log(input));
                            });
  }
  return nullptr;
}
//===----------------------------------------------------------------------===//
// SqrtOp
//===----------------------------------------------------------------------===//
// Constant-folds the element-wise square root of a constant operand. Each
// element is converted to a host `float`, passed through std::sqrt and wrapped
// back into an APFloat.
OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
  Type result_type = getType();
  // Only constant fold for tensor of f32 is implemented.
  if (IsF32ShapedType(result_type)) {
    return ConstFoldUnaryOp(result_type, operands[0],
                            [](APFloat value) -> APFloat {
                              const float input = value.convertToFloat();
                              return APFloat(std::sqrt(input));
                            });
  }
  return nullptr;
}
//===----------------------------------------------------------------------===//
// RsqrtOp
//===----------------------------------------------------------------------===//
// Constant-folds the element-wise reciprocal square root of a constant
// operand. Each element is converted to a host `float`, evaluated as
// 1 / sqrt(x) in single precision and wrapped back into an APFloat.
OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
  Type result_type = getType();
  // Only constant fold for tensor of f32 is implemented.
  if (IsF32ShapedType(result_type)) {
    return ConstFoldUnaryOp(result_type, operands[0],
                            [](APFloat value) -> APFloat {
                              const float input = value.convertToFloat();
                              const float reciprocal = 1.f / std::sqrt(input);
                              return APFloat(reciprocal);
                            });
  }
  return nullptr;
}
//===----------------------------------------------------------------------===//
// SquareOp
//===----------------------------------------------------------------------===//
// Constant-folds the element-wise square of a constant operand by multiplying
// each APFloat element with itself.
OpFoldResult SquareOp::fold(ArrayRef<Attribute> operands) {
  Type result_type = getType();
  // Only constant fold for tensor of f32 is implemented.
  if (IsF32ShapedType(result_type)) {
    return ConstFoldUnaryOp(
        result_type, operands[0],
        [](APFloat value) -> APFloat { return value * value; });
  }
  return nullptr;
}
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
// Folds tfl.rank to a constant holding the rank as a 32-bit integer, either
// from a constant operand or from the statically known rank of the input
// type. Returns null when neither source yields a usable rank.
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1);
  auto result_type = getType().cast<ShapedType>();

  // Wraps a rank value into a dense constant of the result type.
  auto make_rank_constant = [&](int64_t rank_value) {
    auto rank = static_cast<int32_t>(rank_value);
    return DenseElementsAttr::get(result_type, {rank});
  };

  // A constant operand always has a type whose rank can be queried.
  if (auto constant_input = operands[0].dyn_cast_or_null<ElementsAttr>()) {
    return make_rank_constant(constant_input.getType().getRank());
  }

  // Also fold if `input` has a known rank.
  auto input_type = input()->getType().cast<ShapedType>();
  // Do not fold if rank is zero because the TFLite converter doesn't
  // distinguish between unranked input and scalar input due to b/138865275.
  // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
  // predicate and fold the op when rank is zero.
  const bool has_foldable_rank =
      input_type.hasRank() && input_type.getRank() != 0;
  if (!has_foldable_rank) return nullptr;

  return make_rank_constant(input_type.getRank());
}
//===----------------------------------------------------------------------===//
// ConstOp
//===----------------------------------------------------------------------===//
// Folds a constant op to the attribute it holds. `operands` must be empty,
// which is asserted.
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
// Return the held attribute value.
return value();
}
//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//
namespace {
// Returns the number of elements of the 1-D range tensor described by
// `start`, `limit` and `delta`.
// Integral types use a ceiling division of |limit - start| by |delta|;
// floating-point types use ceil(|(limit - start) / delta|).
// Template parameter `FloatOrInt` must be a standard C integer or
// floating-point type. `delta` must be non-zero.
template <typename FloatOrInt>
int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
  // Refer to the implementation in
  // tensorflow/lite/kernels/range.cc.
  const FloatOrInt span = limit - start;
  if (std::is_integral<FloatOrInt>::value) {
    const FloatOrInt step = std::abs(delta);
    return (std::abs(span) + step - 1) / step;
  }
  return std::ceil(std::abs(span / delta));
}
// Builds a constant 1-D range tensor with `num_elements` elements of type
// `result_elem_type`: the first element is the value of `start_attr` and each
// following element adds the value of `delta_attr` to its predecessor.
// Template parameter `FloatOrIntAtrr` must be mlir::IntegerAttr or
// mlir::FloatAttr.
template <typename FloatOrIntAtrr>
DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
                                        FloatOrIntAtrr start_attr,
                                        FloatOrIntAtrr delta_attr) {
  using ValueType = typename FloatOrIntAtrr::ValueType;  // APInt or APFloat

  ValueType current = start_attr.getValue();
  ValueType step = delta_attr.getValue();

  SmallVector<ValueType, 16> range_values;
  range_values.reserve(num_elements);
  for (int remaining = num_elements; remaining > 0; --remaining) {
    range_values.push_back(current);
    current = current + step;
  }

  // Result is always a 1-D tensor.
  auto range_type = RankedTensorType::get({num_elements}, result_elem_type);
  return DenseElementsAttr::get(range_type, range_values);
}
} // namespace
// Folds tfl.range into a constant 1-D tensor when `start`, `limit` and
// `delta` (operands 0, 1 and 2) are all constants.
//
// Folding is declined (null is returned) when
//   * any operand is not a constant,
//   * the result element type is neither integer nor float, or
//   * the range is ill-formed: `delta` is zero, or its sign does not lead
//     from `start` towards `limit`. Folding such a range would divide by zero
//     or silently produce a tensor of the wrong contents, so the op is left in
//     place for later stages to diagnose.
OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 3);
  auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
  auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
  auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
  if (!start_tensor || !limit_tensor || !delta_tensor) return nullptr;

  // Operands should all be scalars
  assert(start_tensor.getType().getRank() == 0 &&
         limit_tensor.getType().getRank() == 0 &&
         delta_tensor.getType().getRank() == 0);

  Type elem_type = getType().cast<ShapedType>().getElementType();

  if (elem_type.isa<IntegerType>()) {
    auto start_attr = start_tensor.getValue({}).cast<IntegerAttr>();
    auto limit_attr = limit_tensor.getValue({}).cast<IntegerAttr>();
    auto delta_attr = delta_tensor.getValue({}).cast<IntegerAttr>();

    const int64_t start = start_attr.getInt();
    const int64_t limit = limit_attr.getInt();
    const int64_t delta = delta_attr.getInt();
    // Rejects delta == 0 as well as a delta pointing away from `limit`.
    const bool well_formed = (delta > 0 && start <= limit) ||
                             (delta < 0 && start >= limit);
    if (!well_formed) return nullptr;

    const int num_elements = GetLengthOfRange(start, limit, delta);
    return BuildConstRangeTensor(elem_type, num_elements, start_attr,
                                 delta_attr);
  }

  if (elem_type.isa<FloatType>()) {
    auto start_attr = start_tensor.getValue({}).cast<FloatAttr>();
    auto limit_attr = limit_tensor.getValue({}).cast<FloatAttr>();
    auto delta_attr = delta_tensor.getValue({}).cast<FloatAttr>();

    const double start = start_attr.getValueAsDouble();
    const double limit = limit_attr.getValueAsDouble();
    const double delta = delta_attr.getValueAsDouble();
    // Every comparison is false for NaN, so NaN inputs are rejected here too.
    const bool well_formed = (delta > 0 && start <= limit) ||
                             (delta < 0 && start >= limit);
    if (!well_formed) return nullptr;

    const int num_elements = GetLengthOfRange(start, limit, delta);
    return BuildConstRangeTensor(elem_type, num_elements, start_attr,
                                 delta_attr);
  }

  return nullptr;
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
namespace {
// Appends to `new_values` the elements of the transposed tensor, visiting the
// output tensor in row-major order. Each call handles one output axis
// (`output_axis`) and recurses for the following one; `input_indices` is the
// scratch multi-index into `input_tensor` that is filled in along the way.
// `perm[k]` names the input axis that output axis `k` reads from.
void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
                        ArrayRef<int64_t> output_shape, int num_dimensions,
                        int output_axis, std::vector<uint64_t> *input_indices,
                        std::vector<Attribute> *new_values) {
  // Refer to the implementation of `Transpose` function in
  // tensorflow/lite/kernels/internal/reference/reference_ops.h
  assert(output_axis < num_dimensions);

  const int source_axis = perm[output_axis];
  const int64_t extent = output_shape[output_axis];
  // Values are only emitted once the innermost output axis is reached.
  const bool innermost = (output_axis + 1 == num_dimensions);

  for (int pos = 0; pos < extent; ++pos) {
    // Position along the current output axis maps onto `source_axis` of the
    // input.
    input_indices->at(source_axis) = pos;
    if (!innermost) {
      ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
                         output_axis + 1, input_indices, new_values);
      continue;
    }
    new_values->push_back(input_tensor.getValue(*input_indices));
  }
}
}  // namespace
// Constant-folds the transpose when both the input and the permutation are
// constant. Returns a null result (no fold) when either operand is not a
// constant, when the result element type is not a plain integer or float, or
// when the permutation operand is malformed (wrong rank, wrong number of
// entries, or an entry outside [0, rank)).
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2);
  auto input_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
  auto perm_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
  if (!input_tensor || !perm_tensor) return nullptr;

  // Do not try to fold elements attr of a quant type because
  // DenseElementsAttr does not support it.
  if (!getType().cast<ShapedType>().getElementType().isIntOrFloat())
    return nullptr;

  // `perm` must be 1-D with one entry per input dimension. This is checked at
  // runtime rather than only asserted so that a malformed op is left unfolded
  // instead of being indexed out of bounds when assertions are disabled.
  const int num_dimensions = input_tensor.getType().getRank();
  if (perm_tensor.getType().getRank() != 1 ||
      perm_tensor.getType().getNumElements() != num_dimensions)
    return nullptr;

  ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
  auto output_type = getType().cast<ShapedType>();

  SmallVector<int32_t, 4> perm;
  SmallVector<int64_t, 4> output_shape;
  for (int i = 0; i < num_dimensions; ++i) {
    const int64_t axis = perm_tensor.getValue({static_cast<uint64_t>(i)})
                             .cast<IntegerAttr>()
                             .getInt();
    // Each entry is used to index `input_shape`, so it must name an existing
    // input axis.
    if (axis < 0 || axis >= num_dimensions) return nullptr;
    perm.push_back(static_cast<int32_t>(axis));
    output_shape.push_back(input_shape[axis]);

    // Check that the derived output shape matches the static shape.
    assert(!output_type.hasStaticShape() ||
           output_type.getShape()[i] == output_shape[i]);
  }

  std::vector<Attribute> new_values;
  new_values.reserve(input_tensor.getType().getNumElements());
  std::vector<uint64_t> input_indices(num_dimensions);
  if (num_dimensions == 0) {
    // A scalar has exactly one element and no axis to walk, and
    // ComputePermutation requires at least one dimension, so copy the value
    // directly (`input_indices` is empty here).
    new_values.push_back(input_tensor.getValue(input_indices));
  } else {
    ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
                       /*output_axis=*/0, &input_indices, &new_values);
  }

  auto result_type =
      RankedTensorType::get(output_shape, output_type.getElementType());
  return DenseElementsAttr::get(result_type, new_values);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
// Dialect hook that turns a constant attribute into an op. A ConstOp
// (tfl.pseudo_const) of the requested `type` is created at `loc` for an opaque
// elements attribute, or for any elements attribute whose own type differs
// from `type`. Every other attribute yields nullptr.
Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
                                                      Attribute value,
                                                      Type type, Location loc) {
  const bool is_opaque_elements = value.isa<OpaqueElementsAttr>();
  const bool is_retyped_elements =
      value.isa<ElementsAttr>() && value.getType() != type;
  if (!is_opaque_elements && !is_retyped_elements) return nullptr;

  return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
}
} // namespace TFL
} // namespace mlir