/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass takes operations in the TensorFlowLite dialect and
// optimizes them, producing operations in the same TensorFlowLite dialect.
#include <algorithm>
#include <climits>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <numeric>
#include <utility>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
//===----------------------------------------------------------------------===//
// The actual Optimize Pass.
namespace {
constexpr char kRelu[] = "RELU";
constexpr char kRelu6[] = "RELU6";
constexpr char kRelu1[] = "RELU_N1_TO_1";
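// Returns true if a reduction over `axis` applied to `sq_op` reduces the last
// dimension (the first axis value is rank - 1 or -1) or reduces all dimensions
// (axis == [0, 1, ..., rank - 1]); this is the reduction shape expected when
// matching an L2 normalization pattern.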
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
*axis.getValues<int>().begin() ||
*axis.getValues<int>().begin() == -1) {
return true;
}
if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
return false;
}
auto shape = sq_op.getType().cast<ShapedType>();
SmallVector<int, 4> elems{axis.getValues<int>().begin(),
axis.getValues<int>().end()};
for (int i = 0; i < shape.getRank(); ++i) {
if (i != elems[i]) return false;
}
return true;
}
using ::llvm::cast;
// Optimize TFLite operations in functions.
class OptimizePass : public PassWrapper<OptimizePass, FunctionPass> {
public:
OptimizePass() = default;
OptimizePass(const OptimizePass &) {}
explicit OptimizePass(bool enable_canonicalization) {
enable_canonicalization_ = enable_canonicalization;
}
void runOnFunction() override;
private:
Option<bool> enable_canonicalization_{
*this, "enable-canonicalization",
llvm::cl::desc("Enable canonicalization during optimization pass."),
llvm::cl::init(false)};
};
// Returns whether the given type `a` is broadcast-compatible with `b`.
bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
return OpTrait::util::getBroadcastedType(a, b) != Type();
}
// Returns whether the resultant type of any broadcastable operation with
// operands `a` and `b` matches `expected_output`. Returns false if `a` is not
// broadcast-compatible with `b`.
bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) {
Type output_element_type =
expected_output.cast<ShapedType>().getElementType();
Type broadcasted_type =
OpTrait::util::getBroadcastedType(a, b, output_element_type);
return broadcasted_type != Type() && broadcasted_type == expected_output;
}
// Returns true if the dimensions of `type1` are the same as the trailing
// dimensions of `type2`. This is more restrictive than broadcast compatibility.
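// For example (illustrative), a shape [1, 128] is a tail of [8, 1, 128], but
// [128, 1] is not.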
bool IsTailOfShape(Type type1, Type type2) {
auto tail_type = type1.dyn_cast<ShapedType>();
auto full_type = type2.dyn_cast<ShapedType>();
if (!tail_type || !full_type || !tail_type.hasRank() ||
!full_type.hasRank() || tail_type.getRank() > full_type.getRank())
return false;
auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend();
auto i2 = full_type.getShape().rbegin();
return std::equal(i1, e1, i2);
}
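// Returns true if a constant with `elements_shape` can be fused into a conv or
// depthwise-conv filter with `filter_shape`. For example (illustrative): a
// constant of shape [1, 1, 1, 64] can be fused with a Conv2D whose OHWI filter
// has shape [64, 3, 3, 16] (output channels lead), and with a DepthwiseConv2D
// whose 1HWO filter has shape [1, 3, 3, 64] (output channels trail).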
bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef<int64_t> filter_shape,
const ArrayRef<int64_t> elements_shape,
bool is_depthwise) {
// Make sure the val tensor has a shape where all dimensions are 1 except the
// last one.
// Also, the val tensor must be of rank 0 (scalar), 1, or 4.
const auto elements_rank = elements_shape.size();
for (int i = 0; i < static_cast<int>(elements_shape.size()) - 1; ++i) {
if (elements_shape[i] != 1) return false;
}
if (elements_rank != 1 && elements_rank != 0 && elements_rank != 4) {
return false;
}
auto elements_depth = elements_shape.empty() ? 1 : elements_shape.back();
// If the elements depth equals 1 (i.e., a scalar or a tensor with 1 element),
// we can let the binary op broadcast the elements.
if (elements_depth == 1) {
return true;
}
// In TFLite, Conv2D uses the OHWI format for its filter and DepthwiseConv2D
// uses the 1HWO format.
// For conv:
//   Check that the first filter dimension (output channels) equals the
//   elements depth.
// For depthwise conv:
//   Check that the last filter dimension (output channels) equals the
//   elements depth.
if (filter_shape.empty() ||
(is_depthwise ? filter_shape.back() != elements_depth
: filter_shape[0] != elements_depth))
return false;
return true;
}
bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val,
bool is_depthwise) {
const auto elements = val.dyn_cast<DenseElementsAttr>();
if (!elements) {
return false;
}
const auto elements_shape = elements.getType().getShape();
const auto filter_shape = filter.getType().cast<ShapedType>().getShape();
return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape,
is_depthwise);
}
bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
bool is_depthwise) {
if (const auto elements = val.dyn_cast<DenseElementsAttr>()) {
if (const auto filter_elements = filter.dyn_cast<DenseElementsAttr>()) {
return CanFuseConvOrDepthwiseConvShapes(
filter_elements.getType().getShape(), elements.getType().getShape(),
is_depthwise);
}
}
return false;
}
// Returns true if we can eliminate the GatherNdOp or ScatterNdOp. When the
// values of `indices` are 0 to n-1, the output tensor is identical to
// `params`.
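// For example (illustrative), GatherNd(params : tensor<4x8xf32>,
// indices = [[0], [1], [2], [3]]) simply returns `params`.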
bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
DenseIntElementsAttr indices) {
auto params_type = params.getType().dyn_cast<RankedTensorType>();
auto indices_type = indices.getType().dyn_cast<RankedTensorType>();
// Checks that the shape of `params` is [n, ...] and the shape of `indices` is
// [n, 1]. A 2D `indices` with an inner dimension of 1 indexes along the first
// dimension of `params`; as long as `indices` iterates over that dimension in
// order, the output is identical to the input.
if (!params_type || !indices_type || indices_type.getRank() != 2 ||
indices_type.getDimSize(0) != params_type.getDimSize(0) ||
indices_type.getDimSize(1) != 1)
return false;
// Checks the value in `indices` is from 0 to n-1.
int cur_value = 0;
for (const auto &v : indices.getValues<APInt>()) {
if (v.getSExtValue() != cur_value) return false;
++cur_value;
}
return true;
}
// Returns true if we can eliminate the SliceOp. When the values of `begin` are
// all 0s and `size[i]` is equal to either -1 or `input.shape[i]`
// for each dim i, the output tensor is identical to `input`.
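// For example (illustrative), Slice(input : tensor<2x3xf32>, begin = [0, 0],
// size = [-1, 3]) returns `input` unchanged.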
bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) {
// Checks if `begin` and `size` are i32 or i64.
auto begin_attr = begin.dyn_cast<DenseIntElementsAttr>();
auto size_attr = size.dyn_cast<DenseIntElementsAttr>();
if (!begin_attr || !size_attr) {
return false;
}
auto begin_elem_ty = begin_attr.getType().getElementType();
if (!begin_elem_ty.isInteger(32) && !begin_elem_ty.isInteger(64)) {
return false;
}
auto size_elem_ty = size_attr.getType().getElementType();
if (!size_elem_ty.isInteger(32) && !size_elem_ty.isInteger(64)) {
return false;
}
// Checks if `input` is ranked and its rank is equal to the number of elements
// in `begin` and `size`.
auto input_ty = input.getType().cast<ShapedType>();
if (!input_ty.hasRank()) {
return false;
}
int64_t rank = input_ty.getRank();
if (rank != begin_attr.getNumElements() ||
rank != size_attr.getNumElements()) {
return false;
}
// Checks if `begin` is all 0s, and `size[i]` is equal to either -1 or
// `input.shape[i]`.
for (uint64_t i = 0; i < rank; ++i) {
if (begin_attr.getValue<APInt>({i}).getSExtValue() != 0) return false;
int64_t si = size_attr.getValue<APInt>({i}).getSExtValue();
if (si != -1 && si != input_ty.getDimSize(i)) return false;
}
return true;
}
// Expands Attribute 'a' to 4D with all dimensions 1 except one.
// Which dimension that is depends on whether 'is_depthwise' is true or false.
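// For example (illustrative), a 1-D attribute of shape [8] is reshaped to
// [8, 1, 1, 1] for conv and to [1, 1, 1, 8] for depthwise conv.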
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
auto elements = a.dyn_cast<DenseElementsAttr>();
auto shape = elements.getType().getShape();
if (!shape.empty()) {
// Checks that elements are essentially 1d.
assert(elements.getNumElements() == shape.back());
}
std::vector<int64_t> shape_data = {1, 1, 1, 1};
const int vector_length = elements.getNumElements();
if (is_depthwise)
shape_data[3] = vector_length;
else
shape_data[0] = vector_length;
auto new_shape =
RankedTensorType::get(shape_data, elements.getType().getElementType());
return elements.reshape(new_shape);
}
ElementsAttr ExpandTo4DForConv(Attribute a) {
return ExpandTo4DForConvImpl(a, false);
}
ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
return ExpandTo4DForConvImpl(a, true);
}
TypeAttr RescaleQtype(Type input, Attribute factor) {
return quant::RescaleQuantizedType(input, factor);
}
// Returns the shape of a ranked tensor as a DenseElementsAttr.
// Precondition: output_val's type is a ranked tensor.
DenseElementsAttr GetShape(Value output_val) {
auto output_type = output_val.getType().cast<RankedTensorType>();
auto shape_vector = output_type.getShape();
std::vector<int32_t> shape;
shape.reserve(shape_vector.size());
for (auto shape_object : shape_vector) {
shape.push_back(shape_object);
}
return mlir::DenseElementsAttr::get(
RankedTensorType::get(
{static_cast<int>(shape.size())},
mlir::IntegerType::get(output_val.getContext(), 32)),
llvm::makeArrayRef(shape));
}
static Type GetShapeStrippedType(TypeAttr type_attr) {
auto type = type_attr.getValue();
auto shaped_type = type.dyn_cast<ShapedType>();
if (shaped_type) {
return shaped_type.getElementType();
} else {
return type;
}
}
// Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in
// the specified `shape` and `false` otherwise.
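// For example (illustrative), reducing an input of shape [2, 8, 16] over
// axes = [1, 2] with keep_dims=true yields shape [2, 1, 1], so only that
// `shape` value matches.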
static bool ShapeMatchesReduceWithKeepAxes(Value input,
const mlir::Attribute &axes,
const mlir::Attribute &shape) {
RankedTensorType type = input.getType().dyn_cast_or_null<RankedTensorType>();
if (!type) return false;
DenseIntElementsAttr axes_attr =
axes.dyn_cast_or_null<DenseIntElementsAttr>();
DenseIntElementsAttr shape_attr =
shape.dyn_cast_or_null<DenseIntElementsAttr>();
if (!axes_attr || !shape_attr) return false;
if (shape_attr.getNumElements() != type.getRank()) return false;
llvm::SmallSet<uint64_t, 4> axes_set;
for (auto a : axes_attr.getIntValues()) {
axes_set.insert(a.getZExtValue());
}
auto type_shape = type.getShape();
for (uint64_t i = 0; i < type.getRank(); ++i) {
if (axes_set.contains(i)) {
if (shape_attr.getValue<APInt>({i}) != 1) return false;
} else {
if (shape_attr.getValue<APInt>({i}) != type_shape[i]) return false;
}
}
return true;
}
static bool FloatValueEquals(const Attribute &attr, double value) {
auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
if (!fp_attr) return false;
if (fp_attr.isSplat()) {
return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
}
return llvm::all_of(fp_attr.getFloatValues(), [value](const APFloat &f) {
return f.isExactlyValue(value);
});
}
// Returns true if the value's element type is F32.
bool IsF32Value(Value value) {
return value.getType().cast<ShapedType>().getElementType().isF32();
}
// Returns the number of elements in `attr` if it is a DenseElementsAttr, or 1
// otherwise, as a scalar (0-D) int32 Attribute.
Attribute GetNumElementsOrOne(Attribute attr) {
const auto dense_attr = attr.dyn_cast_or_null<DenseElementsAttr>();
int32_t num_elements = dense_attr ? dense_attr.getNumElements() : 1;
OpBuilder builder(attr.getContext());
return DenseIntElementsAttr::get(
RankedTensorType::get({}, builder.getI32Type()),
{llvm::APInt(32, num_elements, true)});
}
// Returns true if `attr` is a DenseIntElementsAttr whose last element equals 1.
bool IsLastElementEqualsOne(Attribute attr) {
const auto ints = attr.dyn_cast_or_null<DenseIntElementsAttr>();
if (!ints) return false;
if (ints.empty()) return false;
const auto last_element_index = ints.getNumElements() - 1;
const auto iterator = ints.getIntValues().begin();
const APInt last_element = iterator[last_element_index];
return last_element == 1;
}
// Returns true if `attr` is a DenseIntElementsAttr of int32 or int64 values
// that form an incrementing sequence from 0 to N-1.
//
// If such a value is used in an Equal operator, it can be replaced with OneHot.
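// For example (illustrative), dense<[0, 1, 2, 3]> : tensor<4xi32> qualifies,
// while dense<[0, 2, 1, 3]> does not.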
bool IsOneHotIndexAttribute(Attribute attr) {
const auto dense_attr = attr.dyn_cast_or_null<DenseIntElementsAttr>();
if (!dense_attr) {
return false;
}
auto index_type = dense_attr.getType();
const auto index_elem_bits = index_type.getElementTypeBitWidth();
if (index_elem_bits != 32 && index_elem_bits != 64) {
return false;
}
if (index_type.getRank() != 1) {
return false;
}
const auto elems = dense_attr.getIntValues().begin();
for (int i = 0; i < dense_attr.getNumElements(); ++i) {
if (i != elems[i]) {
return false;
}
}
return true;
}
// Converts an Attribute with a single value of float or integral type to an
// Attribute holding a single value of float type. If attr has no elements, the
// result is 0.0f.
Attribute ConvertSingleElementAttrToFloatAttr(Attribute attr) {
const auto dense_fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
if (dense_fp_attr) {
// Already float => return
return dense_fp_attr;
}
OpBuilder builder(attr.getContext());
const auto dense_int_attr = attr.dyn_cast<DenseIntElementsAttr>();
const auto int_values = dense_int_attr.getIntValues();
float float_val = 0.0f;
if (!int_values.empty()) {
const APInt apint_val = *int_values.begin();
if (dense_int_attr.getType().getElementType().isSignedInteger()) {
// Get the sign-extended value (=>int64) if the type is signed.
float_val = apint_val.getSExtValue();
} else {
// Get the zero-extended value (=>uint64) if unsigned or signless.
float_val = apint_val.getZExtValue();
}
}
return DenseFPElementsAttr::get(
RankedTensorType::get({}, builder.getF32Type()),
{llvm::APFloat(float_val)});
}
#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
// Fuses Add with the preceding FullyConnected.
// TODO(b/136285429): Move to tablegen when variadic is supported
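// For example (an illustrative sketch; SSA names are hypothetical):
//   %fc = tfl.fully_connected(%input, %filter, %none)
//       {fused_activation_function = "NONE", weights_format = "DEFAULT"}
//   %add = tfl.add(%fc, %bias_cst) {fused_activation_function = "RELU"}
// can be rewritten as
//   %fc = tfl.fully_connected(%input, %filter, %bias_cst)
//       {fused_activation_function = "RELU", weights_format = "DEFAULT"}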
struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
using OpRewritePattern<TFL::AddOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::AddOp add_op,
PatternRewriter &rewriter) const override {
// Match Add.
DenseElementsAttr added_value;
Value constant_val = add_op.rhs();
if (!matchPattern(constant_val, m_Constant(&added_value))) return failure();
// Match Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
if (!fc_op) return failure();
// Check if the constant RHS is either 0D (scalar) or 1D with a
// `{num_channels}` shape.
auto constant_val_type = constant_val.getType().cast<TensorType>();
// In the TFLite FullyConnected definition, the bias must be a 1D tensor whose
// number of elements is equal to the number of channels.
// If the RHS is not 1D or 0D (which can be broadcast to 1D), reject the
// match.
bool is_scalar_rhs = false;
if (constant_val_type.getRank() == 0) {
is_scalar_rhs = true;
} else if (constant_val_type.getRank() != 1) {
return failure();
}
Value filter = fc_op.filter();
Value bias = fc_op.bias();
ElementsAttr bias_value;
const bool is_none_bias = bias.getType().isa<NoneType>();
if (fc_op.fused_activation_function() != "NONE") return failure();
if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
return failure();
// Rewrite
if (is_none_bias) {
if (is_scalar_rhs) {
// If `constant_val` is a scalar, we need the shape of the filter
// to properly broadcast the scalar to the `{num_channels}` shape.
// Get the number of channels if possible.
auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
// Filter must be a `2D` tensor with `{num_channels, num_features}`
// shape. The following check is rejecting unknown rank (-1).
if (filter_type == nullptr || filter_type.getRank() != 2) {
return failure();
}
int num_channels = filter_type.getShape()[0];
// Create a zero tensor with shape {num_channels}; its type needs to
// be the same as constant_val's.
// This is a way to gracefully handle a scalar tensor. The Add will always
// be constant-folded away regardless of whether `constant_val` is a scalar.
RankedTensorType type = RankedTensorType::get(
{num_channels}, constant_val_type.getElementType());
auto attr = rewriter.getZeroAttr(type);
bias = rewriter.create<ConstantOp>(add_op.getLoc(), type, attr);
auto none_af = rewriter.getStringAttr("NONE");
bias =
rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
.output();
} else {
// If there is no pre-existing bias and `constant_val` is 1D, simply
// use `constant_val` as the bias.
bias = constant_val;
}
} else {
auto none_af = rewriter.getStringAttr("NONE");
bias =
rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
.output();
}
auto fc = rewriter.create<TFL::FullyConnectedOp>(
FusedLoc::get(fc_op.getContext(), {fc_op.getLoc(), add_op.getLoc()}),
add_op.getType(),
/*input=*/fc_op.input(),
/*filter=*/filter,
/*bias=*/bias,
/*fused_activation_function=*/
rewriter.getStringAttr(add_op.fused_activation_function()),
/*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
/*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
rewriter.replaceOp(add_op, fc.output());
return success();
}
};
// Replace ..
// FC(Add(lhs, rhs), filter, bias)
// .. with ..
// FC(lhs, filter, FC(rhs, filter, bias))
// .. if rhs, filter, and bias are all constants.
// The second FC will be constant folded to a single vector.
// TODO(b/136285429): Move to tablegen when variadic is supported
struct FuseAddAndFullyConnected
: public OpRewritePattern<TFL::FullyConnectedOp> {
using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
PatternRewriter &rewriter) const override {
// This only works with default format.
if (fc_op.weights_format() != "DEFAULT") return failure();
// Match Add.
auto add_op = dyn_cast_or_null<TFL::AddOp>(fc_op.input().getDefiningOp());
if (!add_op) return failure();
if (add_op.fused_activation_function() != "NONE") return failure();
// Don't match adds where the added constant is not 1D.
{
auto addend_shape = add_op.rhs().getType().cast<ShapedType>();
if (!addend_shape.hasStaticShape()) return failure();
if (addend_shape.getShape().size() != 1) return failure();
}
// Calculate new bias. Generate a new FC; it will be constant folded.
auto old_bias = fc_op.bias();
if (!old_bias || old_bias.getType().isa<NoneType>()) {
// TODO(b/180752069): Figure out new bias' type when old bias is empty.
return failure();
}
// The FC relies on constant folding, which is implemented only for F32. Check
// that the types are F32.
{
if (!IsF32Value(add_op.rhs()) || !IsF32Value(fc_op.filter()) ||
!IsF32Value(old_bias))
return failure();
}
auto new_bias = rewriter.create<TFL::FullyConnectedOp>(
fc_op.getLoc(), old_bias.getType(),
/*input=*/add_op.rhs(),
/*filter=*/fc_op.filter(),
/*bias=*/old_bias,
/*fused_activation_function=*/rewriter.getStringAttr("NONE"),
/*weights_format=*/rewriter.getStringAttr("DEFAULT"),
/*keep_num_dims=*/rewriter.getBoolAttr(true));
// Create the updated FC.
auto new_fc = rewriter.create<TFL::FullyConnectedOp>(
FusedLoc::get(add_op.getContext(), {add_op.getLoc(), fc_op.getLoc()}),
fc_op.output().getTypes(),
/*input=*/add_op.lhs(),
/*filter=*/fc_op.filter(),
/*bias=*/*new_bias.output().begin(),
/*fused_activation_function=*/
rewriter.getStringAttr(fc_op.fused_activation_function()),
/*weights_format=*/rewriter.getStringAttr("DEFAULT"),
/*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
rewriter.replaceOp(fc_op.getOperation(), new_fc.output());
return success();
}
};
// Replace ..
// FC(Mul(lhs, rhs), filter, bias)
// .. with ..
// FC(lhs, Mul(filter, rhs), bias)
// .. if rhs, filter, and bias are all constants.
// The generated Mul will be constant folded to a single matrix.
struct FuseMulAndFullyConnected
: public OpRewritePattern<TFL::FullyConnectedOp> {
using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
PatternRewriter &rewriter) const override {
// This only works with default format.
if (fc_op.weights_format() != "DEFAULT") return failure();
// Match Mul.
auto mul_op = dyn_cast_or_null<TFL::MulOp>(fc_op.input().getDefiningOp());
if (!mul_op) return failure();
if (mul_op.fused_activation_function() != "NONE") return failure();
// Don't match muls where the multiplier constant is not 1D.
{
auto multiplier_shape = mul_op.rhs().getType().cast<ShapedType>();
if (!multiplier_shape.hasStaticShape()) return failure();
if (multiplier_shape.getShape().size() != 1) return failure();
}
// We rely on constant folding, implemented only for F32. Check types.
if (!IsF32Value(mul_op.rhs()) || !IsF32Value(fc_op.filter())) {
return failure();
}
auto location =
FusedLoc::get(mul_op.getContext(), {mul_op.getLoc(), fc_op.getLoc()});
auto new_filter = rewriter.create<TFL::MulOp>(
location,
/*lhs=*/fc_op.filter(),
/*rhs=*/mul_op.rhs(),
/*fused_activation_function=*/rewriter.getStringAttr("NONE"));
// Create the updated FC.
auto new_fc = rewriter.create<TFL::FullyConnectedOp>(
location, fc_op.output().getTypes(),
/*input=*/mul_op.lhs(),
/*filter=*/new_filter,
/*bias=*/fc_op.bias(),
/*fused_activation_function=*/
rewriter.getStringAttr(fc_op.fused_activation_function()),
/*weights_format=*/rewriter.getStringAttr("DEFAULT"),
/*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
rewriter.replaceOp(fc_op.getOperation(), new_fc.output());
return success();
}
};
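// Fuses a FullyConnected op with a following ReluX op by absorbing the
// activation into the fused_activation_function attribute. For example
// (illustrative):
//   %fc = tfl.fully_connected(...) {fused_activation_function = "NONE"}
//   %relu = "tfl.relu"(%fc)
// can be rewritten as
//   %fc = tfl.fully_connected(...) {fused_activation_function = "RELU"}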
// TODO(b/136285429): Move to tablegen when variadic is supported.
template <typename ReluXOp, char const *Act>
struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
using OpRewritePattern<ReluXOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ReluXOp relu_op,
PatternRewriter &rewriter) const override {
Operation *input = relu_op.getOperand().getDefiningOp();
if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
auto fully_connected_op = cast<FullyConnectedOp>(input);
if (fully_connected_op.fused_activation_function() != "NONE")
return failure();
auto new_activation_func = rewriter.getStringAttr(Act);
auto new_weights_format =
rewriter.getStringAttr(fully_connected_op.weights_format());
auto new_keep_num_dims =
rewriter.getBoolAttr(fully_connected_op.keep_num_dims());
auto fc = rewriter.create<FullyConnectedOp>(
FusedLoc::get(relu_op.getContext(),
{fully_connected_op.getLoc(), relu_op.getLoc()}),
relu_op.getType(), fully_connected_op.input(),
fully_connected_op.filter(), fully_connected_op.bias(),
new_activation_func, new_weights_format, new_keep_num_dims);
rewriter.replaceOp(relu_op, fc.output());
return success();
}
};
// Fuses Mul with the preceding FullyConnected.
// TODO(b/136285429): Move to tablegen when variadic is supported
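// For example (an illustrative sketch; SSA names are hypothetical):
//   %fc = tfl.fully_connected(%input, %filter, %bias)
//       {fused_activation_function = "NONE", weights_format = "DEFAULT"}
//   %mul = tfl.mul(%fc, %cst) {fused_activation_function = "RELU"}
// can be rewritten as
//   %fc = tfl.fully_connected(%input, %scaled_filter, %scaled_bias)
//       {fused_activation_function = "RELU", weights_format = "DEFAULT"}
// where %scaled_filter and %scaled_bias are the original filter and bias
// multiplied by %cst and constant-folded.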
struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::MulOp mul_op,
PatternRewriter &rewriter) const override {
// If we are broadcasting on the lhs then don't fold the multiply as it
// would increase the amount of compute done by the fully connected op.
if (mul_op.lhs().getType() != mul_op.getType()) return failure();
// Mul.
DenseElementsAttr cst;
Value constant_val = mul_op.rhs();
if (!matchPattern(constant_val, m_Constant(&cst))) return failure();
// Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
if (!fc_op) return failure();
Value filter = fc_op.filter();
Value bias = fc_op.bias();
ElementsAttr cst_tmp;
if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&cst_tmp)))
return failure();
if (fc_op.fused_activation_function() != "NONE") return failure();
// Only fuse multiplier if all dimensions other than the depth dimension
// are equal to 1 since otherwise
// `matmul(x, filter) * cst != matmul(x, filter * cst)`
// even if `filter` and `cst` are broadcastable.
auto shape = cst.getType().getShape();
if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();
int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
// Expand and transpose the multiplier since weights are using the
// OHWI data format in TFLite.
int64_t normalized_shape[2] = {element_size, 1};
auto new_cst = cst.reshape(RankedTensorType::get(
normalized_shape, cst.getType().getElementType()));
Type new_type = new_cst.getType();
if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
return failure();
}
auto new_op =
rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
Value new_const_val = new_op.getResult();
// Rewrite. Since the folder of TFL::MulOp cannot broadcast the operands,
// TF::MulOp is used to fold the constant.
// TODO(b/139192933): switch to the TFL constant folding
auto new_filter =
rewriter.create<TF::MulOp>(mul_op.getLoc(), filter, new_const_val).z();
// If bias isn't None, it needs to be multiplied as well.
if (!bias.getType().isa<NoneType>()) {
bias =
rewriter.create<TF::MulOp>(mul_op.getLoc(), bias, constant_val).z();
}
auto fc = rewriter.create<TFL::FullyConnectedOp>(
FusedLoc::get(fc_op.getContext(), {fc_op.getLoc(), mul_op.getLoc()}),
mul_op.getType(),
/*input=*/fc_op.input(),
/*filter=*/new_filter,
/*bias=*/bias,
/*fused_activation_function=*/
rewriter.getStringAttr(mul_op.fused_activation_function()),
/*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
/*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
rewriter.replaceOp(mul_op, fc.output());
return success();
}
};
// Fuses Mul with the preceding affine ops. This is a C++ implementation of the
// following TableGen pattern, which cannot derive the result type of
// the TFL_DequantizeOp.
// def : Pat<(TFL_MulOp (TFL_Conv2DOp:$conv_output $input,
// (TFL_DequantizeOp (TFL_QuantizeOp
// (ConstantOp F32ElementsAttr:$filter), $qtype)),
// (ConstantOp F32ElementsAttr:$bias),
// $h_factor, $w_factor, TFL_AF_None,
// $padding, $stride_h, $stride_w),
// (ConstantOp F32ElementsAttr:$value), $act_fn),
// (TFL_Conv2DOp $input,
// (TFL_DequantizeOp (TFL_QuantizeOp
// (TFL_MulOp (ConstantOp $filter),
// (ConstantOp (ExpandTo4DForConv $value)),
// TFL_AF_None),
// (RescaleQtype $qtype, $value))),
// (TFL_MulOp (ConstantOp $bias), (ConstantOp $value),
// TFL_AF_None),
// $h_factor, $w_factor, $act_fn,
// $padding, $stride_h, $stride_w),
// [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
// (HasOneUse $conv_output),
// (IsPerAxisQuantization $qtype), // per-axis quantization
// ]>;
template <typename AffineOpType>
struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::MulOp mul_op,
PatternRewriter &rewriter) const override {
// Mul. A 1-D rhs is required for batch normalization.
DenseElementsAttr gamma_cst;
Value gamma = mul_op.rhs();
if (!matchPattern(gamma, m_Constant(&gamma_cst))) return failure();
if (gamma_cst.getType().getRank() != 1) return failure();
// Affine op
Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
if (!fc_op) return failure();
Value filter = fc_op.filter();
Value bias = fc_op.bias();
// QDQs
auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
if (!dq_op) return failure();
auto q_op =
dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
if (!q_op) return failure();
filter = q_op.input();
// weight constant
ElementsAttr cst_tmp;
if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&cst_tmp)))
return failure();
if (fc_op.fused_activation_function() != "NONE") return failure();
// Broadcast the constant operand of Mul if it isn't compatible with the
// filter input. We only support broadcasting the operand along the depth
// dimension, when the operand's depth is 1.
rewriter.setInsertionPoint(q_op);
Location loc = fc_op.getLoc();
Value broadcasted_gamma;
if (isa<TFL::Conv2DOp>(mul_op_lhs)) {
auto mul_rhs = ExpandTo4DForConv(gamma_cst);
broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
} else if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
} else {
return failure();
}
// Rewrite the filter constant. Since the folder of TFL::MulOp cannot
// broadcast the operands, TF::MulOp is used to fold the constant.
auto new_filter =
rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
// Update the scale in the quantize op.
auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
if (!new_qtype) return failure();
rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
new_filter, new_qtype);
// If bias isn't None, it needs to be multiplied as well.
if (!bias.getType().isa<NoneType>()) {
rewriter.setInsertionPoint(fc_op);
auto new_bias = rewriter.create<TF::MulOp>(loc, bias, gamma);
fc_op.getOperation()->replaceUsesOfWith(bias, new_bias);
}
// Remove the trailing mul op.
mul_op.replaceAllUsesWith(fc_op.getResult());
return success();
}
};
using FuseConv2DAndMulWithQDQs = FuseAffinOpAndMulWithQDQs<TFL::Conv2DOp>;
using FuseDepthwiseConv2DAndMulWithQDQs =
FuseAffinOpAndMulWithQDQs<TFL::DepthwiseConv2DOp>;
// Fuses a binary op with the following affine operation.
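// For example (an illustrative sketch), with a scalar constant c:
//   conv(x + c, w, b)  =>  conv(x, w, b + w * c)   for add/sub
//   conv(x * c, w, b)  =>  conv(x, w * c, b)       for mul/div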
template <typename AffineOpType>
struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
using OpRewritePattern<AffineOpType>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineOpType fc_op,
PatternRewriter &rewriter) const override {
// Binary op.
Operation *binary_op = fc_op.input().getDefiningOp();
if (!binary_op || binary_op->getNumOperands() != 2) return failure();
// We only handle the case where the RHS is a scalar.
// TODO(fengliuai): Currently the canonicalizer pass cannot guarantee that
// the constant operands are on the RHS; consider the LHS constant
// operand if necessary.
DenseFPElementsAttr cst;
if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst)))
return failure();
if (cst.getNumElements() != 1) return failure();
APFloat cst_value = *cst.float_value_begin();
// Affine op.
Value filter = fc_op.filter();
Value bias = fc_op.bias();
DenseFPElementsAttr filter_cst, bias_cst;
if (!matchPattern(filter, m_Constant(&filter_cst))) {
// The filter may be quantized; in that case, look through the
// quantize/dequantize ops to find the real constant.
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
if (!dq) return failure();
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
return failure();
}
filter = q.input();
}
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&bias_cst)))
return failure();
auto binary_op_activation_func =
binary_op->template getAttrOfType<StringAttr>(
"fused_activation_function");
if (!binary_op_activation_func ||
binary_op_activation_func.getValue() != "NONE")
return failure();
ShapedType filter_type = filter_cst.getType();
if (llvm::isa<AddOp, SubOp>(binary_op)) {
auto padding = fc_op->template getAttrOfType<StringAttr>("padding");
if (padding && padding.getValue() != "VALID") return failure();
// The fusion of add/sub is actually applying the following
// transformation:
// w * (x + c) + b => w * x + (w * c + b)
// so we have to update the bias.
if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
auto bias_and_slice =
GetBiasDimAndSliceSize(filter_type.getShape(), fc_op);
int64_t bias_size = bias_and_slice.first;
int64_t slice_size = bias_and_slice.second;
ShapedType new_bias_type =
RankedTensorType::get({bias_size}, filter_type.getElementType());
// The new bias should be a 1-D tensor with length equal to the bias
// dimension of the weight.
SmallVector<APFloat, 4> new_bias_values;
if (bias.getType().isa<NoneType>()) { // none bias, a list of zeros
new_bias_values.resize(bias_size,
APFloat::getZero(cst_value.getSemantics()));
} else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it
new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
} else if (bias_cst.getNumElements() == bias_size) { // 1-d bias, copy it
new_bias_values.insert(new_bias_values.begin(),
bias_cst.float_value_begin(),
bias_cst.float_value_end());
} else {
return failure();
}
int64_t flatten_index = 0;
for (auto fp_it = filter_cst.float_value_begin(),
fp_end = filter_cst.float_value_end();
fp_it != fp_end; ++fp_it) {
int bias_index = (flatten_index++ / slice_size) % bias_size;
new_bias_values[bias_index] =
new_bias_values[bias_index] + *fp_it * cst_value;
}
auto new_bias = DenseFPElementsAttr::get(new_bias_type, new_bias_values);
auto new_bias_op =
rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
fc_op.setOperand(0, binary_op->getOperand(0));
fc_op.setOperand(2, new_bias_op);
} else if (llvm::isa<MulOp, DivOp>(binary_op)) {
// The fusion of mul/div is actually applying the following
// transformation:
// w * (x ' c) + b => (w ' c) x + b
// so we have to update the weight.
bool is_mul = llvm::isa<MulOp>(binary_op);
auto new_filter =
filter_cst.mapValues(filter_type.getElementType(), [&](APFloat it) {
return (is_mul ? it * cst_value : it / cst_value).bitcastToAPInt();
});
// We recreate the constant op in case it is shared by the other ops. This
// might increase the model size.
auto new_filter_op = rewriter.create<ConstOp>(
fc_op.getLoc(), filter.getType(), new_filter);
fc_op.setOperand(0, binary_op->getOperand(0));
if (fc_op.filter() != filter) {
// This filter goes through quantize and dequantize ops, so we just
// need to update the weight feeding the quantize op.
filter.replaceAllUsesWith(new_filter_op);
} else {
// This filter doesn't go through quantize and dequantize ops, so
// we update the weight of the affine op directly.
fc_op.setOperand(1, new_filter_op);
}
} else {
return failure();
}
return success();
}
private:
// Returns the length of the channel dimension and the slice size per position
// in the channel dimension. tfl.conv2d and tfl.fully_connected have a leading
// channel dimension, but tfl.depthwise_conv2d has a trailing channel
// dimension. This function is a utility that derives this information from
// the op's properties.
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
ArrayRef<int64_t> filter_shape, AffineOpType op) {
// The channel dimension index is specified as an op property.
auto channel_index_iter = filter_shape.begin();
std::advance(channel_index_iter, op.GetChannelDimIndex());
// The slice size is the size of the data in the dimensions following the
// channel dimension.
int64_t slice_size =
std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
std::multiplies<int64_t>());
return {*channel_index_iter, slice_size};
}
};
// If the operand to a broadcastable op is a splat constant, try to replace it
// with a 0-d constant, e.g. before this optimization,
// %cst = constant dense<1.0> : tensor<16x16x4xf32>
// %0 = "tfl.conv_2d"...
// %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
// After this optimization:
// %cst = constant dense<1.0> : tensor<f32>
// %0 = "tfl.conv_2d"...
// %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
// This pattern can enable more fusing opportunities when the binary op
// follows conv ops.
template <typename BinaryOpType>
struct ScalarizeSplatConstantForBroadcastableOps
: public OpRewritePattern<BinaryOpType> {
using OpRewritePattern<BinaryOpType>::OpRewritePattern;
LogicalResult matchAndRewrite(BinaryOpType binary_op,
PatternRewriter &rewriter) const override {
DenseElementsAttr splat_elements_attr;
if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
return failure();
}
constexpr int kSplatOperandIndex = 1;
auto result_type =
binary_op.getResult().getType().template cast<ShapedType>();
mlir::Value non_splat_operand =
binary_op.getOperand(1 - kSplatOperandIndex);
auto non_splat_operand_type =
non_splat_operand.getType().cast<ShapedType>();
// If the other operand's shape does not equal the result shape, then we
// cannot scalarize the splat constant because the result shape relies on
// the splat constant op's shape for broadcasting.
if (!non_splat_operand_type.hasStaticShape() ||
non_splat_operand_type.getShape() != result_type.getShape() ||
non_splat_operand_type.getRank() > 4) {
return failure();
}
// If the non-splat operand is not produced by a fusable affine op, there is
// no need to apply this transformation.
if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
return failure();
}
// Creates a new scalar constant op using the splat value.
mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
auto scalar_elements_attr = DenseElementsAttr::get(
RankedTensorType::get({},
splat_elements_attr.getType().getElementType()),
splat_elements_attr.getSplatValue());
auto scalar_constant_op = rewriter.create<ConstantOp>(
splat_operand.getLoc(), scalar_elements_attr.getType(),
scalar_elements_attr);
binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
return success();
}
private:
// Returns true if this value is a splat constant op which can be scalarized.
// Also returns the elements attr if this value is indeed a splat constant.
bool IsScalarizableSplatConstant(mlir::Value value,
DenseElementsAttr *elements_attr) const {
if (!matchPattern(value, m_Constant(elements_attr))) {
return false;
}
auto element_type = value.getType().cast<ShapedType>().getElementType();
// Ignore per-axis quantized constants because after converting to a scalar,
// we would lose the per-axis quantization parameters.
if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
return false;
}
if (IsScalar(value)) {
return false;
}
return elements_attr->isSplat();
}
// Returns true if the value has a scalar shaped type.
bool IsScalar(mlir::Value value) const {
auto type = value.getType().dyn_cast<ShapedType>();
if (!type) {
return false;
}
if (!type.hasStaticShape()) {
return false;
}
return type.getNumElements() == 1;
}
// Returns true if we can fuse an affine op with the consuming binary op.
bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
TFL::FullyConnectedOp>(affine_op)) {
return false;
}
DenseElementsAttr value;
// Check that the bias is constant if it is not none.
Value bias = affine_op->getOperand(2);
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&value))) {
return false;
}
// If the binary op is mul/div, also check that filter is constant.
if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
!matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
return false;
}
// We can only fuse F32/BF16.
auto is_fusable_type = [](Type t) {
Type element_type = t;
if (auto shaped_type = t.dyn_cast<ShapedType>()) {
element_type = shaped_type.getElementType();
}
return element_type.isBF16() || element_type.isF32();
};
for (Type t : binary_op->getOperandTypes()) {
if (!is_fusable_type(t)) {
return false;
}
}
return true;
}
};
using ScalarizeSplatConstantForSub =
ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
using ScalarizeSplatConstantForAdd =
ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
using ScalarizeSplatConstantForMul =
ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
using ScalarizeSplatConstantForDiv =
ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;
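// Converts a Transpose op into a Reshape op when the permutation only moves
// dimensions of size 1, so the relative order of the non-unit dimensions (and
// hence the underlying data layout) is unchanged. For example (illustrative):
//   "tfl.transpose"(%arg, %perm) with %arg : tensor<1x8x1x16xf32> and
//   %perm = dense<[2, 0, 1, 3]>
// produces a tensor<1x1x8x16xf32> result and can be replaced by a reshape.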
struct ConvertTrivialTransposeOpToReshapeOp
: public OpRewritePattern<TFL::TransposeOp> {
using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op,
PatternRewriter &rewriter) const override {
auto input_type = transpose_op.input().getType().cast<ShapedType>();
auto output_type = transpose_op.output().getType().cast<ShapedType>();
// It's only possible to know whether the transformation is safe if the input
// and output shapes are fully known and the permutation is a constant.
if (!input_type.hasStaticShape() || !output_type.hasStaticShape())
return failure();
Value perm = transpose_op.perm();
DenseElementsAttr perm_values_attr;
if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure();
auto input_shape = input_type.getShape();
SmallVector<int64_t, 8> perm_values;
for (const auto &dim : perm_values_attr.getIntValues())
perm_values.push_back(dim.getSExtValue());
// This should never happen unless the input graph is malformed.
if (input_shape.size() != perm_values.size()) {
transpose_op.emitError(
"TransposeOP has inconsistent input and perm values.");
}
SmallVector<int, 8> old_major_index_ordering;
SmallVector<int, 8> new_major_index_ordering;
for (int i = 0, end = input_shape.size(); i < end; i++) {
if (input_shape[i] != 1) {
old_major_index_ordering.push_back(i);
}
if (input_shape[perm_values[i]] != 1) {
new_major_index_ordering.push_back(perm_values[i]);
}
}
if (old_major_index_ordering != new_major_index_ordering) {
return failure();
}
// Rewrite.
Location loc = transpose_op.getLoc();
SmallVector<int32_t, 8> output_shape_values;
for (auto dim : output_type.getShape()) {
output_shape_values.push_back(dim);
}
auto type = mlir::RankedTensorType::get(output_shape_values.size(),
rewriter.getIntegerType(32));
auto new_shape_attr =
mlir::DenseIntElementsAttr::get(type, output_shape_values);
auto new_shape = rewriter.create<TF::ConstOp>(loc, new_shape_attr);
rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(
transpose_op, transpose_op.output().getType(), transpose_op.input(),
new_shape);
return success();
}
};
// Remove Reshape before FullyConnected when `keep_num_dims=false` and Reshape
// does not alter the last dimension as FullyConnected will collapse all other
// dimensions into a single dimension. For example,
//
// %shape = constant dense<[1, 128, 64]> : tensor<3xi32>
// %reshape = tfl.reshape(%input, %shape) // %input: tensor<128x64xf32>
// %fc = tfl.fully_connected(%reshape, %filter, %bias)
// {keep_num_dims = false, weights_format = "DEFAULT"}
//
// can be canonicalized to
//
// %fc = tfl.fully_connected(%input, %filter, %bias)
// {keep_num_dims = false, weights_format = "DEFAULT"}
struct RemoveReshapeBeforeFullyConnected
: public OpRewritePattern<TFL::FullyConnectedOp> {
using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::FullyConnectedOp fully_connected_op,
PatternRewriter &) const override {
auto input = fully_connected_op.input();
auto input_ty = input.getType().dyn_cast<ShapedType>();
auto output_ty = fully_connected_op.output()[0]
.getType()
.template dyn_cast<ShapedType>();
if (!input_ty.hasStaticShape() ||
fully_connected_op.weights_format() != "DEFAULT" ||
fully_connected_op.keep_num_dims() || !output_ty.hasStaticShape() ||
output_ty.getRank() != 2) {
return failure();
}
auto reshape_op = input.getDefiningOp<TFL::ReshapeOp>();
if (!reshape_op) return failure();
// Check that the last dimension does not change after the reshape.
auto reshape_input = reshape_op.input();
auto reshape_input_ty = reshape_input.getType().dyn_cast<ShapedType>();
if (!reshape_input_ty.hasStaticShape() || input_ty.getRank() == 0 ||
reshape_input_ty.getRank() == 0 ||
input_ty.getDimSize(input_ty.getRank() - 1) !=
reshape_input_ty.getDimSize(reshape_input_ty.getRank() - 1)) {
return failure();
}
// Connect the FullyConnected input directly to the reshape's input.
fully_connected_op.setOperand(0, reshape_input);
return success();
}
};
// Removes Reshape after FullyConnected when `keep_num_dims=false`, the Reshape
// does not alter the last dimension, and it restores the batch dimensions
// collapsed by the FullyConnected op due to `keep_num_dims=false`. For example,
//
// // %input: tensor<4x16x32xf32>
// %fc = tfl.fully_connected(%input, %filter, %bias)
// {keep_num_dims = false, weights_format = "DEFAULT"}
// %shape = constant dense<[4, 16, 32]> : tensor<3xi32>
// %rs = tfl.reshape(%fc, %shape)
//
// can be canonicalized to
//
// %fc = tfl.fully_connected(%input, %filter, %bias)
// {keep_num_dims = true, weights_format = "DEFAULT"}
struct RemoveReshapeAfterFullyConnected
: public OpRewritePattern<TFL::ReshapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::ReshapeOp reshape_op,
PatternRewriter &rewriter) const override {
auto fully_connected_op = llvm::dyn_cast_or_null<TFL::FullyConnectedOp>(
reshape_op.input().getDefiningOp());
if (!fully_connected_op || fully_connected_op.getNumResults() != 1 ||
fully_connected_op.weights_format() != "DEFAULT" ||
fully_connected_op.keep_num_dims())
return failure();
if (!reshape_op.input().hasOneUse()) return failure();
auto input_shape = fully_connected_op.input().getType().cast<ShapedType>();
auto output_shape = fully_connected_op.getType(0).cast<ShapedType>();
auto reshape_shape = reshape_op.getType().cast<ShapedType>();
if (!input_shape.hasStaticShape() || !output_shape.hasStaticShape() ||
!reshape_shape.hasStaticShape())
return failure();
// Check that the reshape doesn't modify the last dimension and that it
// restores the input (batch) dimensions, with the exception of the feature
// (last) dimension.
if (output_shape.getShape().back() != reshape_shape.getShape().back() ||
input_shape.getShape().drop_back() !=
reshape_shape.getShape().drop_back())
return failure();
llvm::SmallVector<Type, 1> output_type{reshape_op.getType()};
rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
reshape_op, output_type, fully_connected_op.input(),
fully_connected_op.filter(), fully_connected_op.bias(),
fully_connected_op.fused_activation_function(),
fully_connected_op.weights_format(), /*keep_num_dims=*/true);
return success();
}
};
// Fuses Unpack with the following Concatenation into a Reshape if the output
// type has a static shape and the activation function is none. For example:
//
// // %input: tensor<1x3x2xf32>
// %unpack:3 = "tfl.unpack"(%input) {axis = 1 : i32, num = 3 : i32}
// %res = "tfl.concatenation"(%unpack#0, %unpack#1, %unpack#2)
// {axis = -1 : i32, fused_activation_function = "NONE"}
//
// can be optimized to
//
// %cst = constant dense<[1, 6]> : tensor<2xi32>
// %res = "tfl.reshape"(%input, %cst)
struct FuseUnpackAndConcatToReshape
: public OpRewritePattern<TFL::ConcatenationOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::ConcatenationOp concat_op,
PatternRewriter &rewriter) const override {
if (concat_op.fused_activation_function() != "NONE") {
return failure();
}
// Checks all operands come from the same unpack op.
auto first_operand = concat_op.values().front();
auto unpack_op =
dyn_cast_or_null<TFL::UnpackOp>(first_operand.getDefiningOp());
if (!unpack_op || unpack_op.getNumResults() != concat_op.getNumOperands()) {
return failure();
}
for (auto &index_and_value : llvm::enumerate(concat_op.values())) {
if (index_and_value.value() !=
unpack_op.getResult(index_and_value.index())) {
return failure();
}
}
auto output_type = concat_op.getType().cast<ShapedType>();
if (!output_type.hasStaticShape()) {
return failure();
}
auto new_shape_array = output_type.getShape();
// This is to work around the unnecessary i64 -> i32 cast.
SmallVector<int32_t, 4> new_shape_array_i32;
for (auto size : new_shape_array) {
new_shape_array_i32.push_back(static_cast<int32_t>(size));
}
auto new_shape = rewriter.create<TFL::ConstOp>(
concat_op.getLoc(),
DenseIntElementsAttr::get(
RankedTensorType::get(new_shape_array_i32.size(),
rewriter.getIntegerType(32)),
new_shape_array_i32));
rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(concat_op, output_type,
unpack_op.input(), new_shape);
return success();
}
};
using FuseBinaryOpToFollowingFullyConnected =
FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
using FuseBinaryOpToFollowingDepthwiseConv2D =
FuseBinaryOpToFollowingAffineOp<DepthwiseConv2DOp>;
using FuseBinaryOpToFollowingConv2D = FuseBinaryOpToFollowingAffineOp<Conv2DOp>;
// Adds canonicalization patterns to the list of patterns.
void AddCanonicalizationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(*patterns, context);
}
void OptimizePass::runOnFunction() {
OwningRewritePatternList patterns(&getContext());
auto *ctx = &getContext();
auto func = getFunction();
// Merge reshapes into fully connected ops before we start moving them past
// binary ops.
OwningRewritePatternList phase_0_patterns(&getContext());
phase_0_patterns.insert<RemoveReshapeAfterFullyConnected,
RemoveReshapeBeforeFullyConnected>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(phase_0_patterns));
// The binary ops might themselves be fused together (e.g., into hard_swish),
// so we explore those fusions first and then fuse the binary ops with the
// following ops in a second pattern match.
TFL::populateWithGenerated(patterns);
patterns.insert<FuseFullyConnectedAndAdd, FuseAddAndFullyConnected,
FuseFullyConnectedAndMul, FuseMulAndFullyConnected,
FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>>(ctx);
if (enable_canonicalization_) AddCanonicalizationPatterns(ctx, &patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
// Fuse the binary ops with the following ops.
OwningRewritePatternList phase_2_patterns(&getContext());
TFL::populateWithGenerated(phase_2_patterns);
phase_2_patterns.insert<
ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
FuseFullyConnectedAndAdd, FuseAddAndFullyConnected,
FuseFullyConnectedAndMul, FuseMulAndFullyConnected,
FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
FuseBinaryOpToFollowingConv2D, FuseBinaryOpToFollowingDepthwiseConv2D,
FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp,
RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected,
FuseUnpackAndConcatToReshape>(ctx);
if (enable_canonicalization_)
AddCanonicalizationPatterns(ctx, &phase_2_patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}
} // namespace
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass(
bool enable_canonicalization) {
return std::make_unique<OptimizePass>(enable_canonicalization);
}
static PassRegistration<OptimizePass> pass(
"tfl-optimize", "Optimize within the TensorFlow Lite dialect");
} // namespace TFL
} // namespace mlir