/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
#include <algorithm>
#include <iterator>
#include <string>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/FunctionImplementation.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
namespace mlir {
namespace TFR {
//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//
namespace {
/// This class defines the interface for inlining within the TFR dialect.
class TFRInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
public:
// Allow all call operations to be inlined.
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
return true;
}
// Returns true if the given region 'src' can be inlined into the region
// 'dest' that is attached to an operation registered to the current dialect.
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &) const final {
return true;
}
// Returns true if the given operation 'op', that is registered to this
// dialect, can be inlined into the region 'dest' that is attached to an
// operation registered to the current dialect.
bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
BlockAndValueMapping &) const final {
return true;
}
// Handle the given inlined terminator by replacing it with a new operation
// as necessary. Required when the region has only one block.
void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const final {
auto retValOp = dyn_cast<TFRReturnOp>(op);
if (!retValOp) return;
for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) {
std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value));
}
}
  // Attempts to materialize a conversion for a type mismatch between a call
  // from this dialect and a callable region. This method should generate an
  // operation that takes 'input' as the only operand and produces a single
  // result of 'resultType'. If a conversion cannot be generated, nullptr
  // should be returned.
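  // For example (illustrative), widening a call result from i32 to i64
  // materializes a sign-extension, and the reverse mismatch materializes a
  // truncation:
  //   %wide = arith.extsi %v : i32 to i64
  //   %narrow = arith.trunci %w : i64 to i32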
Operation *materializeCallConversion(OpBuilder &builder, Value input,
Type result_type,
Location conversion_loc) const final {
if (!input.getType().isa<IntegerType>() ||
!result_type.isa<IntegerType>()) {
return nullptr;
}
auto input_itype = input.getType().cast<IntegerType>();
auto result_itype = result_type.cast<IntegerType>();
if (input_itype.getWidth() == result_itype.getWidth()) return nullptr;
if (input_itype.getWidth() > result_itype.getWidth()) {
return builder.create<arith::TruncIOp>(conversion_loc, result_type,
input);
} else {
return builder.create<arith::ExtSIOp>(conversion_loc, result_type, input);
}
}
};
} // namespace
//===----------------------------------------------------------------------===//
// TFR Dialect
//===----------------------------------------------------------------------===//
TFRDialect::TFRDialect(MLIRContext *context)
: Dialect(/*name=*/"tfr", context, TypeID::get<TFRDialect>()) {
// TFR depends on TensorFlow for its canonicalization
context->getOrLoadDialect<TF::TensorFlowDialect>();
addTypes<TFRTensorType, TFRTensorListType, TFRAttrType>();
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
>();
addInterfaces<TFRInlinerInterface>();
}
Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, type, value);
if (func::ConstantOp::isBuildableWith(value, type))
return builder.create<func::ConstantOp>(loc, type,
value.cast<FlatSymbolRefAttr>());
return nullptr;
}
bool TFRType::classof(Type type) {
return llvm::isa<TFRDialect>(type.getDialect());
}
//===----------------------------------------------------------------------===//
// Custom op methods
//===----------------------------------------------------------------------===//
LogicalResult ConstantTensorOp::verify() {
ConstantTensorOp op = *this;
auto input_type = op.arg().getType();
auto output_type = op.out().getType();
if (auto output_tensor_type = output_type.dyn_cast<TFRTensorType>()) {
return success();
}
auto output_tensor_type = output_type.dyn_cast<RankedTensorType>();
if (!output_tensor_type || !output_tensor_type.hasStaticShape()) {
op.emitError("output type should be static and ranked.");
return failure();
}
if (output_tensor_type.getRank() == 0) {
bool same_scalar = output_tensor_type.getElementType() == input_type;
if (!same_scalar) {
op.emitError("input and output should have the same scalar types.");
}
return success(same_scalar);
}
if (auto input_vector_type = input_type.dyn_cast<VectorType>()) {
bool same_element_type = output_tensor_type.getElementType() ==
input_vector_type.getElementType();
bool same_shape =
output_tensor_type.getShape() == input_vector_type.getShape();
if (!same_element_type || !same_shape) {
op.emitError("input and output should have same shape and element type.");
}
return success(same_element_type && same_shape);
}
op.emitError("input can not be converted to an output tensor.");
return failure();
}
LogicalResult TFRFuncOp::verify() {
TFRFuncOp func = *this;
// Collect all attribute names used by the tensor and tensor list arguments
// and returns. Also, collect the names of all the attribute arguments as the
// defined list. Later on, the used attribute names will be verified to be in
// the defined list.
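  // For example (illustrative), a well-formed signature orders arguments as
  // tensors, then tensor lists, then attributes, where each attribute
  // argument carries a tfr.name argument attribute:
  //   tfr.func @tf__my_op(%x: !tfr.tensor, %xs: !tfr.tensor_list,
  //                       %axis: i32 {tfr.name = "axis"}) -> !tfr.tensor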
llvm::SmallVector<StringAttr, 4> input_used_attrs, output_used_attrs;
// While scanning the arguments, record the start/end indices of each argument
// type, so the order can be verified as well.
// TODO(fengliuai): the attribute arguments with default values need to be
// at the end?
int first_tensor = -1, last_tensor = -1, first_tensor_list = -1,
last_tensor_list = -1, first_attr = -1;
for (auto arg : llvm::enumerate(func.getFunctionType().getInputs())) {
Type arg_type = arg.value();
if (auto tensor = arg_type.dyn_cast<TFRTensorType>()) {
if (first_tensor == -1) {
first_tensor = arg.index();
}
last_tensor = arg.index();
auto used = tensor.getAttrKeys();
input_used_attrs.append(used.begin(), used.end());
continue;
}
if (auto tensor_list = arg_type.dyn_cast<TFRTensorListType>()) {
if (first_tensor_list == -1) {
first_tensor_list = arg.index();
}
last_tensor_list = arg.index();
auto used = tensor_list.getAttrKeys();
input_used_attrs.append(used.begin(), used.end());
continue;
}
if (!arg_type.isa<TensorType>()) {
if (first_attr == -1) {
first_attr = arg.index();
}
auto name =
func.getArgAttrOfType<StringAttr>(arg.index(), kAttrArgumentNameAttr);
if (!name) {
func.emitError(
llvm::Twine(arg.index()) +
" attribute argument doesn't have a tfr.name attribute.");
return failure();
}
continue;
}
func.emitError("Builtin TensorType isn't allowed as the argument.");
return failure();
}
// Collect all the undefined attributes used in the inputs.
llvm::SmallVector<StringAttr, 4> undefined_attrs;
for (auto attr : input_used_attrs) {
if (!func->getAttr(attr.getValue())) {
undefined_attrs.push_back(attr);
}
}
// Verify the argument order: tensors, tensor list, attributes; and also
// verify there is at most one tensor list argument.
if (first_attr != -1 &&
(first_attr < last_tensor_list || first_attr < last_tensor)) {
    func.emitError(
        "tfr.tensor/tfr.tensor_list argument should be before non-tensor "
        "arguments.");
return failure();
}
  // The order between tensor and tensor list arguments, and the number of
  // tensor list arguments, are verified only when they can't be derived from
  // the attributes.
if (!undefined_attrs.empty()) {
if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
func.emitError(
"tfr.tensor argument should be before tfr.tensor_list argument.");
return failure();
}
if (first_tensor_list != last_tensor_list) {
func.emitError("More than one tfr.tensor_list argument isn't allowed.");
return failure();
}
}
  // Verify the result order (tensors before tensor lists), and that there is
  // at most one tensor list result.
int undefined_input_attrs_number = undefined_attrs.size();
bool seen_tensor_list = false, has_tensor_list_order_error = false,
has_multiple_tensor_lists_error = false;
for (auto result_type : func.getFunctionType().getResults()) {
if (auto tensor = result_type.dyn_cast<TFRTensorType>()) {
if (seen_tensor_list) {
has_tensor_list_order_error = true;
} else {
auto used = tensor.getAttrKeys();
output_used_attrs.append(used.begin(), used.end());
}
continue;
}
if (auto tensor_list = result_type.dyn_cast<TFRTensorListType>()) {
if (seen_tensor_list) {
has_multiple_tensor_lists_error = true;
} else {
seen_tensor_list = true;
auto used = tensor_list.getAttrKeys();
output_used_attrs.append(used.begin(), used.end());
}
continue;
}
    func.emitError(
        "Only tfr.tensor/tfr.tensor_list types are allowed as results.");
return failure();
}
// Collect all the undefined attributes used in the outputs.
for (auto attr : output_used_attrs) {
if (!func->getAttr(attr.getValue())) {
undefined_attrs.push_back(attr);
}
}
  // Report the result-order and multiple-tensor_list result errors only when
  // the results use attributes that aren't defined on the function.
if (undefined_input_attrs_number != undefined_attrs.size()) {
if (has_tensor_list_order_error) {
func.emitError(
"tfr.tensor result should be before tfr.tensor_list result.");
return failure();
} else if (has_multiple_tensor_lists_error) {
func.emitError("More than one tfr.tensor_list result isn't allowed.");
return failure();
}
}
// TODO(fengliuai): We might want to refine this constraint because the
// tensor element type can be derived.
if (!undefined_attrs.empty()) {
llvm::SmallVector<std::string, 4> attr_names(undefined_attrs.size());
std::transform(undefined_attrs.begin(), undefined_attrs.end(),
attr_names.begin(),
[](StringAttr attr) { return attr.getValue().str(); });
func.emitError(llvm::Twine("Undefined attributes are used: ",
llvm::join(attr_names, ",")));
return failure();
}
return success();
}
ParseResult TFRFuncOp::parse(OpAsmParser &parser, OperationState &result) {
auto build_func_type =
[](Builder &builder, ArrayRef<Type> arg_types, ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(arg_types, results); };
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, build_func_type);
}
void TFRFuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
} // namespace TFR
} // namespace mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
namespace mlir {
namespace TFR {
namespace {
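// Rewrites a tfr.constant_tensor whose argument is a compile-time constant
// into a tf.Const, inserting a tfr.cast when the result is a !tfr.tensor.
// A sketch of the array case (illustrative, generic op form):
//   %a = "tfr.constant"() {value = [1 : i32, 2 : i32]} : () -> !tfr.attr
//   %t = "tfr.constant_tensor"(%a) : (!tfr.attr) -> !tfr.tensor
// becomes
//   %c = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>}
//          : () -> tensor<2xi32>
//   %t = "tfr.cast"(%c) : (tensor<2xi32>) -> !tfr.tensor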
class ConvertConstToTensorConst : public OpRewritePattern<ConstantTensorOp> {
using OpRewritePattern<ConstantTensorOp>::OpRewritePattern;
public:
LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op,
PatternRewriter &rewriter) const override {
Location loc = cst_tensor_op.getLoc();
Type out_type = cst_tensor_op.getType();
Operation *new_cst = nullptr;
ArrayAttr array;
if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) {
llvm::DenseSet<Type> all_types;
for (auto it : array) {
TypedAttr typed_attr = it.dyn_cast<TypedAttr>();
if (!typed_attr) return failure();
all_types.insert(typed_attr.getType());
}
if (all_types.size() != 1) return failure();
ShapedType new_out_type = RankedTensorType::get(
{static_cast<int64_t>(array.size())}, *all_types.begin());
DenseElementsAttr attr =
DenseElementsAttr::get(new_out_type, array.getValue());
new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, attr);
if (out_type.isa<TFRTensorType>()) {
new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
}
rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
return success();
}
TypedAttr scalar;
if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) {
Type new_out_type = RankedTensorType::get({}, scalar.getType());
new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, scalar);
if (out_type.isa<TFRTensorType>()) {
new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
}
rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
return success();
}
return failure();
}
};
inline bool isQuantizedType(Type type) {
auto tensor_type = type.dyn_cast<TensorType>();
return (tensor_type &&
tensor_type.getElementType().isa<quant::QuantizedType>());
}
class RemoveRedundantCast : public OpRewritePattern<CastOp> {
using OpRewritePattern<CastOp>::OpRewritePattern;
public:
LogicalResult matchAndRewrite(CastOp cast_op,
PatternRewriter &rewriter) const override {
auto preceding_cast =
llvm::dyn_cast_or_null<CastOp>(cast_op.arg().getDefiningOp());
if (!preceding_cast) {
return failure();
}
Value input = preceding_cast.arg();
Type input_type = input.getType();
Type output_type = cast_op.getType();
// Preserve quantization information for intermediate tensors.
auto intermediate_type = preceding_cast.getType();
if (isQuantizedType(intermediate_type) || isQuantizedType(output_type)) {
return failure();
}
auto input_tensor_type = input_type.dyn_cast<TensorType>();
auto output_tensor_type = output_type.dyn_cast<TensorType>();
if (!input_tensor_type || !output_tensor_type) {
return failure();
}
    // Canonicalize a back-to-back tfr.cast pair whose element types differ
    // into a pair with matching element types followed by a tf.Cast.
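    // For example (illustrative):
    //   %t = "tfr.cast"(%arg0) : (tensor<2xi32>) -> !tfr.tensor
    //   %r = "tfr.cast"(%t) : (!tfr.tensor) -> tensor<2xf32>
    // becomes
    //   %t = "tfr.cast"(%arg0) : (tensor<2xi32>) -> !tfr.tensor
    //   %c = "tfr.cast"(%t) : (!tfr.tensor) -> tensor<2xi32>
    //   %r = "tf.Cast"(%c) : (tensor<2xi32>) -> tensor<2xf32>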
if ((input_tensor_type.getElementType() !=
output_tensor_type.getElementType()) &&
!isQuantizedType(input_type) && !isQuantizedType(output_type)) {
auto new_tfr_cast = rewriter.create<TFR::CastOp>(
cast_op.getLoc(),
output_tensor_type.clone(input_tensor_type.getElementType()),
cast_op.arg());
rewriter.replaceOpWithNewOp<TF::CastOp>(cast_op, output_type,
new_tfr_cast);
return success();
}
// If the two types are the same, the back-to-back tfr.cast ops can be
// removed.
if (input_type == output_type || output_type.isa<UnrankedTensorType>()) {
rewriter.replaceOp(cast_op, {input});
return success();
}
    // If the input tensor isn't ranked, replace the pair with a
    // tf.EnsureShape op so it can be removed after shape inference or
    // checked at runtime.
if (input_type.isa<UnrankedTensorType>()) {
auto shape = output_type.cast<ShapedType>().getShape();
auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape);
rewriter.replaceOpWithNewOp<TF::EnsureShapeOp>(cast_op, output_type,
input, shape_attr);
return success();
}
return failure();
}
};
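// Rewrites tfr.get_shape(tfr.cast(%tensor)) into shape.shape_of on the
// original tensor so that the existing shape.shape_of foldings apply. A
// sketch (illustrative):
//   %t = "tfr.cast"(%arg0) : (tensor<1x2xf32>) -> !tfr.tensor
//   %s = "tfr.get_shape"(%t) : (!tfr.tensor) -> !shape.shape
// becomes
//   %s = shape.shape_of %arg0 : tensor<1x2xf32> -> tensor<2xindex>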
class GetTensorShape : public OpRewritePattern<GetShapeOp> {
using OpRewritePattern<GetShapeOp>::OpRewritePattern;
public:
LogicalResult matchAndRewrite(GetShapeOp shape_op,
PatternRewriter &rewriter) const override {
Operation *preceding_op = shape_op.arg().getDefiningOp();
if (auto cast_op = llvm::dyn_cast_or_null<CastOp>(preceding_op)) {
      // Replace this pair with shape.shape_of so that existing foldings
      // apply.
rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(shape_op, cast_op.arg());
return success();
}
return failure();
}
};
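// Folds a tfr.get_element with a constant index into the corresponding
// operand of the tfr.build_list that feeds it. A sketch (illustrative,
// generic op form), where %i is a constant 1:
//   %l = "tfr.build_list"(%a, %b)
//          : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
//   %e = "tfr.get_element"(%l, %i) : (!tfr.tensor_list, index) -> !tfr.tensor
// folds to %b, provided the types match or the result type is unranked.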
class RemoveRedundantGetElement : public OpRewritePattern<GetElementOp> {
using OpRewritePattern<GetElementOp>::OpRewritePattern;
public:
LogicalResult matchAndRewrite(GetElementOp ge_op,
PatternRewriter &rewriter) const override {
IntegerAttr index;
if (!matchPattern(ge_op.index(), m_Constant(&index))) {
return failure();
}
auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
ge_op.tensor_list().getDefiningOp());
if (!preceding_build_list ||
preceding_build_list.getNumOperands() <= index.getInt()) {
return failure();
}
Value input = preceding_build_list.getOperand(index.getInt());
Type output_type = ge_op.getType();
if (input.getType() != output_type &&
!output_type.isa<UnrankedTensorType>()) {
return failure();
}
rewriter.replaceOp(ge_op, {input});
return success();
}
};
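// Folds tfr.get_length(tfr.build_list(...)) to a constant index equal to the
// number of list operands; e.g. (illustrative) the length of a two-element
// build_list folds to arith.constant 2 : index.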
class RemoveRedundantGetLength : public OpRewritePattern<GetLengthOp> {
using OpRewritePattern<GetLengthOp>::OpRewritePattern;
public:
LogicalResult matchAndRewrite(GetLengthOp gl_op,
PatternRewriter &rewriter) const override {
auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
gl_op.tensor_list().getDefiningOp());
if (!preceding_build_list) {
return failure();
}
int64_t num_tensors = preceding_build_list.getNumOperands();
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
gl_op, rewriter.getIndexAttr(num_tensors));
return success();
}
};
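// Converts a tfr.build_list whose operands are all constants into a single
// tfr.const holding the operands' attributes as one array attribute. This is
// a sketch of the intent; the pattern only fires when every operand matches
// m_Constant.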
class BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
using OpRewritePattern<BuildListOp>::OpRewritePattern;
public:
LogicalResult matchAndRewrite(BuildListOp bl_op,
PatternRewriter &rewriter) const override {
SmallVector<Attribute, 4> array_list;
array_list.reserve(bl_op.getNumOperands());
for (const auto &operand : bl_op.getOperands()) {
Attribute array_elt;
if (!matchPattern(operand, m_Constant(&array_elt))) {
return failure();
}
array_list.push_back(array_elt);
}
auto array_attr = rewriter.getArrayAttr(array_list);
rewriter.replaceOpWithNewOp<TFR::ConstOp>(bl_op, array_attr);
return success();
}
};
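// Returns the quantized element type recorded on `cast_op`, or a null type
// if the cast is null or doesn't carry a quantized input element type.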
quant::QuantizedType getQuantizedElementType(CastOp cast_op) {
if (!cast_op || !cast_op.getInputElementType()) {
return {};
}
return cast_op.getInputElementType()
.cast<TypeAttr>()
.getValue()
.dyn_cast<quant::QuantizedType>();
}
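// Removes tfr.quant_raw_data by forwarding the quantized value(s) feeding it
// past the op. Handles a single tfr.cast input as well as a tfr.build_list
// of casts; in both cases the casts must carry a quantized element type.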
class RemoveRawDataOp : public OpRewritePattern<TFRQuantRawDataOp> {
using OpRewritePattern<TFRQuantRawDataOp>::OpRewritePattern;
public:
LogicalResult matchAndRewrite(TFRQuantRawDataOp raw_data_op,
PatternRewriter &rewriter) const override {
auto preceding_op = raw_data_op.input().getDefiningOp();
    if (llvm::isa_and_nonnull<BuildListOp>(preceding_op)) {
return rewritePrecedingListOp(raw_data_op, rewriter);
}
auto preceding_cast = dyn_cast_or_null<CastOp>(preceding_op);
if (!preceding_cast || !getQuantizedElementType(preceding_cast)) {
return failure();
}
    // If there is a redundant cast pair, forward the value that feeds the
    // pair directly to the raw_data op's users.
if (preceding_cast.arg().getDefiningOp()) {
auto redundant_cast = preceding_cast.arg().getDefiningOp<CastOp>();
if (!redundant_cast ||
redundant_cast.arg().getType() != preceding_cast.out().getType()) {
return failure();
}
raw_data_op.output().replaceAllUsesWith(redundant_cast.arg());
} else {
      // Otherwise the cast's argument is a block argument, so simply forward
      // the cast's output to the raw_data op's users.
raw_data_op.output().replaceAllUsesWith(preceding_cast.out());
}
return success();
}
LogicalResult rewritePrecedingListOp(TFRQuantRawDataOp raw_data_op,
PatternRewriter &rewriter) const {
llvm::SmallVector<Value> new_list_values;
auto preceding_list = raw_data_op.input().getDefiningOp<BuildListOp>();
for (Value operand : preceding_list.tensors()) {
auto preceding_cast = operand.getDefiningOp<CastOp>();
if (!preceding_cast || !getQuantizedElementType(preceding_cast)) {
return failure();
}
// This function currently only supports the case with redundant casts.
auto redundant_cast = preceding_cast.arg().getDefiningOp<CastOp>();
if (!redundant_cast ||
redundant_cast.arg().getType() != preceding_cast.out().getType()) {
return failure();
}
new_list_values.push_back(redundant_cast.arg());
}
auto new_list = rewriter.create<BuildListOp>(
raw_data_op.getLoc(), preceding_list.getType(), new_list_values);
raw_data_op.output().replaceAllUsesWith(new_list.out());
return success();
}
};
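// Replaces tfr.quant_qparams with constants holding the scale and zero point
// taken from the input's quantized element type. For example (illustrative),
// an input cast with element type !quant.uniform<i8:f32, 0.1:42> yields a
// scalar f32 tf.Const of 0.1 and a scalar i32 tf.Const of 42, each wrapped
// in a tfr.cast to the op's result types.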
class RemoveQParamsOp : public OpRewritePattern<TFRQuantQParamsOp> {
using OpRewritePattern<TFRQuantQParamsOp>::OpRewritePattern;
public:
LogicalResult matchAndRewrite(TFRQuantQParamsOp qparams_op,
PatternRewriter &rewriter) const override {
    auto cast_op =
        dyn_cast_or_null<TFR::CastOp>(qparams_op.input().getDefiningOp());
auto cast_qtype = getQuantizedElementType(cast_op);
if (!cast_qtype) {
return failure();
}
TF::ConstOp scale_op;
TF::ConstOp zp_op;
// Reads quantization parameters from the quantized type, and converts
// them to constants.
rewriter.setInsertionPoint(qparams_op);
Location loc = qparams_op->getLoc();
if (auto qtype = cast_qtype.dyn_cast<quant::UniformQuantizedType>()) {
scale_op = rewriter.create<TF::ConstOp>(
loc, RankedTensorType::get({}, rewriter.getF32Type()),
rewriter.getF32FloatAttr(qtype.getScale()));
zp_op = rewriter.create<TF::ConstOp>(
loc, RankedTensorType::get({}, rewriter.getI32Type()),
rewriter.getI32IntegerAttr(qtype.getZeroPoint()));
} else if (auto qtype =
cast_qtype.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
SmallVector<float> scales(qtype.getScales().begin(),
qtype.getScales().end());
SmallVector<int32_t> zps(qtype.getZeroPoints().begin(),
qtype.getZeroPoints().end());
const size_t num_channels = qtype.getScales().size();
auto scales_type = RankedTensorType::get(
{static_cast<int64_t>(num_channels)}, rewriter.getF32Type());
auto scales_attr =
DenseElementsAttr::get(scales_type, llvm::makeArrayRef(scales));
scale_op = rewriter.create<TF::ConstOp>(loc, scales_attr);
auto zps_type = RankedTensorType::get(
{static_cast<int64_t>(num_channels)}, rewriter.getI32Type());
auto zps_attr = DenseElementsAttr::get(zps_type, llvm::makeArrayRef(zps));
zp_op = rewriter.create<TF::ConstOp>(loc, zps_attr);
}
if (!scale_op || !zp_op) {
return failure();
}
auto scale_cast = rewriter.create<CastOp>(loc, qparams_op.scale().getType(),
scale_op.output());
auto zp_cast =
rewriter.create<CastOp>(loc, qparams_op.zp().getType(), zp_op.output());
qparams_op.scale().replaceAllUsesWith(scale_cast.out());
qparams_op.zp().replaceAllUsesWith(zp_cast.out());
return success();
}
};
// TODO(b/193731721): Migrate tfr_ builtin canonicalizations to LowerTFROpPass
class RemoveScaleFactorOp : public OpRewritePattern<TFRQuantScaleFactorOp> {
using OpRewritePattern<TFRQuantScaleFactorOp>::OpRewritePattern;
public:
  // Replaces quant_scale_factor with a constant tensor holding
  //   in_scale[0] * in_scale[1] / out_scale
  // (the equivalent of a TFR_ConstantTensorOp wrapping a float ConstantOp).
// Currently, all decompositions using this pattern (Conv2D, FC) have the
// following preconditions:
// * out_scale: float scalar attribute
// * in_scale[0] (input scale): float scalar, given by tf.Const -> tfr.cast
// * in_scale[1] (filter scale): float scalar/vector
// (per-tensor vs per-channel) quantization, given by tf.Const -> tfr.cast
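  // A worked example (illustrative values): with in_scale[0] = 0.5,
  // in_scale[1] = [0.1, 0.2] (per-channel), and out_scale = 2.0, the
  // resulting constant is dense<[0.025, 0.05]> : tensor<2xf32>.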
LogicalResult matchAndRewrite(TFRQuantScaleFactorOp scale_factor_op,
PatternRewriter &rewriter) const override {
auto out_scale_op =
scale_factor_op.out_scale().getDefiningOp<arith::ConstantOp>();
if (!out_scale_op) {
return failure();
}
const double out_scale =
out_scale_op.getValue().cast<FloatAttr>().getValueAsDouble();
auto in_scales_op =
scale_factor_op.in_scales().getDefiningOp<BuildListOp>();
if (!in_scales_op || in_scales_op.getNumOperands() != 2) {
// BuildListOp is variadic, but we require two values: input_scale
// and filter_scale.
return failure();
}
auto in_scale_op = in_scales_op.getOperand(0).getDefiningOp<CastOp>();
if (!in_scale_op) {
return failure();
}
DenseFPElementsAttr in_scale_attr;
if (!matchPattern(in_scale_op.arg(), m_Constant(&in_scale_attr)) ||
in_scale_attr.size() != 1) {
return failure();
}
const float in_scale = in_scale_attr.getValues<float>()[0];
auto filter_scale_op = in_scales_op.getOperand(1).getDefiningOp<CastOp>();
if (!filter_scale_op) {
return failure();
}
DenseFPElementsAttr filter_scale_attr;
if (!matchPattern(filter_scale_op.arg(), m_Constant(&filter_scale_attr))) {
return failure();
}
// The shape of scale_type is {} (rank 0) for per-tensor quantized tensor,
// and {num_channels} (rank 1) for per-channel quantized one.
auto scale_type = filter_scale_attr.getType().dyn_cast<RankedTensorType>();
    if (!scale_type ||
        (scale_type.getRank() != 0 && scale_type.getRank() != 1)) {
return failure();
}
SmallVector<float> scale_factors;
scale_factors.reserve(filter_scale_attr.size());
for (auto value : filter_scale_attr.getValues<APFloat>()) {
scale_factors.push_back(in_scale * value.convertToFloat() / out_scale);
}
rewriter.setInsertionPoint(scale_factor_op);
const Location loc = scale_factor_op->getLoc();
auto result_scale_op = rewriter.create<TF::ConstOp>(
loc,
DenseElementsAttr::get(scale_type, llvm::makeArrayRef(scale_factors)));
auto result_scale_cast_op = rewriter.create<CastOp>(
loc, scale_factor_op.getType(), result_scale_op.output());
scale_factor_op.scale_factor().replaceAllUsesWith(
result_scale_cast_op.out());
return success();
}
};
class RemoveRescaleOp : public OpRewritePattern<TFRQuantRescaleOp> {
using OpRewritePattern<TFRQuantRescaleOp>::OpRewritePattern;
public:
// Replace quant_rescale (input, scale, zp) with
// tf.Cast(tf.Round(tf.Cast(input, f32) * scale) + tf.Cast(zp, f32), i32)
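  // The rewrite is emitted as a chain of tfr.call ops (a sketch; the printed
  // form of tfr.call below is assumed for illustration):
  //   %f   = tfr.call @tf__cast(%input, %f32_attr, %false)
  //   %m   = tfr.call @tf__mul(%f, %scale)
  //   %r   = tfr.call @tf__round(%m)
  //   %zpf = tfr.call @tf__cast(%zp_tensor, %f32_attr, %false)
  //   %a   = tfr.call @tf__add(%r, %zpf)
  //   %out = tfr.call @tf__cast(%a, %i32_attr, %false)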
LogicalResult matchAndRewrite(TFRQuantRescaleOp rescale_op,
PatternRewriter &rewriter) const override {
Value input = rescale_op.input();
Value scale = rescale_op.scale();
Value zp = rescale_op.zp();
const Location loc = rescale_op->getLoc();
const auto result_types = rescale_op->getResultTypes();
auto c_false =
rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false));
TypeAttr f32_attr = TypeAttr::get(rewriter.getF32Type());
TFRAttrType output_type = TFRAttrType::get(rewriter.getContext());
auto constant_f32_op = rewriter.create<ConstOp>(loc, output_type, f32_attr);
TypeAttr i32_attr = TypeAttr::get(rewriter.getI32Type());
auto constant_i32_op = rewriter.create<ConstOp>(loc, output_type, i32_attr);
IntegerAttr zp_attr;
if (!matchPattern(zp, m_Constant(&zp_attr))) {
return failure();
}
rewriter.setInsertionPoint(zp.getDefiningOp());
auto zp_tensor = rewriter.create<TF::ConstOp>(
loc, RankedTensorType::get({}, zp.getType()), zp_attr);
auto zp_cast = rewriter.create<CastOp>(
loc, rewriter.getType<TFRTensorType>(), zp_tensor.output());
rewriter.setInsertionPoint(rescale_op);
auto cast_input_to_float_op = rewriter.create<CallOp>(
loc, result_types,
SymbolRefAttr::get(rewriter.getContext(), "tf__cast"),
ArrayRef<Value>{input, constant_f32_op, c_false});
auto input_x_scale_op = rewriter.create<CallOp>(
loc, result_types, SymbolRefAttr::get(rewriter.getContext(), "tf__mul"),
ArrayRef<Value>{cast_input_to_float_op.getResult(0), scale});
auto round_rescaled_op = rewriter.create<CallOp>(
loc, result_types,
SymbolRefAttr::get(rewriter.getContext(), "tf__round"),
ArrayRef<Value>{input_x_scale_op->getResult(0)});
auto cast_zp_to_float_op = rewriter.create<CallOp>(
loc, result_types,
SymbolRefAttr::get(rewriter.getContext(), "tf__cast"),
ArrayRef<Value>{zp_cast, constant_f32_op, c_false});
auto recentered_op = rewriter.create<CallOp>(
loc, result_types, SymbolRefAttr::get(rewriter.getContext(), "tf__add"),
ArrayRef<Value>{round_rescaled_op->getResult(0),
cast_zp_to_float_op->getResult(0)});
auto cast_output_to_i32 = rewriter.create<CallOp>(
loc, result_types,
SymbolRefAttr::get(rewriter.getContext(), "tf__cast"),
ArrayRef<Value>{recentered_op->getResult(0), constant_i32_op, c_false});
rescale_op.output().replaceAllUsesWith(cast_output_to_i32.getResult(0));
return success();
}
};
} // namespace
void ConstantTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConvertConstToTensorConst>(context);
}
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RemoveRedundantCast>(context);
}
void GetShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<GetTensorShape>(context);
}
void GetElementOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RemoveRedundantGetElement>(context);
}
void GetLengthOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RemoveRedundantGetLength>(context);
}
void BuildListOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<BuildConstantListAsAttr>(context);
}
void TFRQuantRawDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RemoveRawDataOp>(context);
}
void TFRQuantQParamsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RemoveQParamsOp>(context);
}
void TFRQuantRescaleOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RemoveRescaleOp>(context);
}
void TFRQuantScaleFactorOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<RemoveScaleFactorOp>(context);
}
OpFoldResult TFR::EqualOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "equal op has two operands");
  return BoolAttr::get(getContext(), operands[0] == operands[1]);
}
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
// Return the held attribute value.
return value();
}
// CallableOpInterface
Region *TFRFuncOp::getCallableRegion() {
return isExternal() ? nullptr : &body().front();
}
// CallableOpInterface
ArrayRef<Type> TFRFuncOp::getCallableResults() {
return getFunctionType().getResults();
}
//===----------------------------------------------------------------------===//
// Dialect type definitions
//===----------------------------------------------------------------------===//
// Parses a TFR type.
// tfr_type ::= tensor_type | tensor_list_type | attr_type
// string_list ::= `[` string-literal (, string-literal)+ `]`
// tensor_type ::= `tensor`
// | `tensor<` (string-literal | string_list) '>'
// tensor_list_type ::= `tensor_list`
// | `tensor_list<` (string-literal | string_list) '>'
// attr_type ::= `attr`
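// Examples (illustrative): !tfr.tensor, !tfr.tensor<T>,
// !tfr.tensor_list<[T1, T2]>, !tfr.attr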
Type TFRDialect::parseType(DialectAsmParser &parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
MLIRContext *ctx = loc.getContext();
StringRef typeNameSpelling;
if (failed(parser.parseKeyword(&typeNameSpelling))) return {};
llvm::SmallVector<StringAttr, 4> attrs;
if (succeeded(parser.parseOptionalLess())) {
bool l_square_parsed = false;
if (succeeded(parser.parseOptionalLSquare())) {
l_square_parsed = true;
}
do {
StringRef attr;
if (failed(parser.parseKeyword(&attr))) return {};
attrs.push_back(StringAttr::get(ctx, attr));
} while (succeeded(parser.parseOptionalComma()));
if (l_square_parsed && failed(parser.parseRSquare())) {
parser.emitError(parser.getNameLoc(), "expected ']'");
}
if (failed(parser.parseGreater())) {
parser.emitError(parser.getNameLoc(), "expected '>'");
}
}
if (typeNameSpelling == "tensor") {
return TFRTensorType::getChecked(attrs, loc);
} else if (typeNameSpelling == "tensor_list") {
return TFRTensorListType::getChecked(attrs, loc);
} else if (typeNameSpelling == "attr") {
return TFRAttrType::getChecked(loc, loc.getContext());
} else {
parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling);
return {};
}
}
void TFRDialect::printType(Type type, DialectAsmPrinter &os) const {
llvm::ArrayRef<StringAttr> attrs;
if (type.isa<TFRAttrType>()) {
os << "attr";
return;
}
if (auto tensor_ty = type.dyn_cast<TFRTensorType>()) {
attrs = tensor_ty.getAttrKeys();
os << "tensor";
} else if (auto tensor_list_ty = type.dyn_cast<TFRTensorListType>()) {
attrs = tensor_list_ty.getAttrKeys();
os << "tensor_list";
} else {
llvm_unreachable("Unhandled tfr type");
}
if (attrs.empty()) return;
os << "<";
if (attrs.size() > 1) {
os << "[";
}
llvm::interleaveComma(attrs, os,
[&](StringAttr attr) { os << attr.getValue(); });
if (attrs.size() > 1) {
os << "]";
}
os << ">";
}
} // namespace TFR
} // namespace mlir