/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
// NOLINTNEXTLINE
static llvm::cl::list<std::string> io_node_types(
"tfl-test-io-types", llvm::cl::value_desc("list"),
llvm::cl::desc("comma separated type strings. Allowed values: "
"'int8', 'uint8', 'float32']"),
llvm::cl::CommaSeparated);
namespace mlir {
namespace TFL {
namespace {
// This transformation pass modifies the input and output types of the entry
// function to the user-specified types. Rather than simply inserting cast
// operations, it relies on the existing tfl.quantize and tfl.dequantize ops
// so the boundary tensors stay correctly scaled.
struct ModifyIONodesPass
: public PassWrapper<ModifyIONodesPass, OperationPass<func::FuncOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ModifyIONodesPass)
explicit ModifyIONodesPass() {}
explicit ModifyIONodesPass(mlir::Type input_type, mlir::Type output_type)
: input_type(input_type), output_type(output_type) {}
void runOnOperation() override;
StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (e.g. on the command line).
return "tfl-modify-io-nodes";
}
StringRef getDescription() const final {
// This is a brief description of the pass.
return "Modify the type of the model io nodes.";
}
private:
  // Assigns the I/O types from the command-line flag when they were not
  // supplied to the constructor. This is only needed for tests.
LogicalResult SetupInputOutputTypesIfNull(OpBuilder builder);
  // Modifies the element types of the entry block arguments to the
  // user-specified input type and records the new argument types in
  // `new_input_types`.
LogicalResult ModifyInputNodes(func::FuncOp func,
llvm::SmallVectorImpl<Type>& new_input_types,
OpBuilder builder);
  // Modifies the element types of the values returned from the entry block to
  // the user-specified output type and records the new return types in
  // `new_output_types`.
LogicalResult ModifyOutputNodes(func::FuncOp func,
llvm::SmallVectorImpl<Type>& new_output_types,
OpBuilder builder);
mlir::Type input_type;
mlir::Type output_type;
};
LogicalResult ModifyIONodesPass::SetupInputOutputTypesIfNull(
OpBuilder builder) {
if (input_type && output_type) return success();
auto convert_str_to_type = [&builder](absl::string_view str) -> Type {
if (str == "int8") {
return builder.getIntegerType(8);
} else if (str == "uint8") {
return builder.getIntegerType(8, /*isSigned=*/false);
} else if (str == "float32") {
return builder.getF32Type();
} else {
return {};
}
};
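  // The flag is expected to carry two entries: the input type first, then the
  // output type.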
if (io_node_types.size() < 2) return failure();
if (!input_type) input_type = convert_str_to_type(io_node_types[0]);
if (!output_type) output_type = convert_str_to_type(io_node_types[1]);
return success();
}
LogicalResult ModifyIONodesPass::ModifyInputNodes(
func::FuncOp func, llvm::SmallVectorImpl<Type>& new_input_types,
OpBuilder builder) {
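  // Nothing to change if float inputs are requested; the entry arguments are
  // left untouched.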
if (input_type.isa<FloatType>()) {
return success();
}
Block& block = func.front();
builder.setInsertionPointToStart(&block);
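  // Each iteration appends the (possibly retyped) argument at the end of the
  // block and erases argument 0, so the original i-th argument is always found
  // at index 0 and the total argument count stays constant.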
for (int i = 0; i != block.getNumArguments(); ++i) {
Value arg = block.getArgument(0);
Type arg_type = arg.getType();
Value new_arg = arg;
Location loc = func.getLoc();
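    // If the argument feeds exactly one tfl.quantize op, fold that
    // quantization into the argument: keep the quantized type when its storage
    // type already matches the requested input type, or convert it to the
    // unsigned variant when a uint8 input is requested for an int8-quantized
    // argument.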
if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
auto quantize_output = quantize_op.output();
auto current_type = quant::QuantizedType::getQuantizedElementType(
quantize_output.getType())
.getStorageType();
if (current_type == input_type) { // int8 == int8
arg_type = quantize_output.getType();
new_arg = block.addArgument(arg_type, loc);
quantize_output.replaceAllUsesWith(new_arg);
} else if (input_type.isUnsignedInteger(
current_type.getIntOrFloatBitWidth())) { // int8 != uint8
arg_type = quant::ConvertSignedQuantizedToUnsigned(
quantize_output.getType(), loc);
new_arg = block.addArgument(arg_type, loc);
quantize_op.setOperand(new_arg);
} else {
input_type.print(llvm::errs() << "Requested input type ");
quantize_op.emitError(" Couldn't be modified to the requested type.");
return failure();
}
new_input_types[i] = arg_type;
arg.dropAllUses();
if (quantize_op.use_empty()) {
quantize_op.erase();
}
} else {
      // `arg` has multiple uses or its user isn't a quantize op, so we can't
      // rewrite it to a different type. Add a new argument of the same type
      // and redirect all uses of `arg` to it.
new_arg = block.addArgument(arg_type, loc);
arg.replaceAllUsesWith(new_arg);
}
block.eraseArgument(0);
}
return success();
}
LogicalResult ModifyIONodesPass::ModifyOutputNodes(
func::FuncOp func, llvm::SmallVectorImpl<Type>& new_output_types,
OpBuilder builder) {
Block& block = func.front();
auto* terminator = block.getTerminator();
builder.setInsertionPoint(terminator);
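  // Nothing to change if float outputs are requested; the returned values are
  // left untouched.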
if (output_type.isa<FloatType>()) {
return success();
}
int num_return_operands = terminator->getNumOperands();
new_output_types.reserve(num_return_operands);
for (int i = 0; i != num_return_operands; ++i) {
auto returned_value = terminator->getOperand(i);
Type returned_type = returned_value.getType();
Operation* returned_op = returned_value.getDefiningOp();
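    // Mirror the input handling: if the returned value is produced by a
    // tfl.dequantize op, return its quantized operand directly when the
    // storage type matches the requested output type, or insert a tfl.quantize
    // op to the unsigned variant when a uint8 output is requested.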
if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
auto dequantize_input = dequantize_op.input();
Type current_type = quant::QuantizedType::getQuantizedElementType(
dequantize_input.getType())
.getStorageType();
if (current_type == output_type) { // int8 == int8
returned_type = dequantize_input.getType();
returned_value = dequantize_input;
} else if (output_type.isUnsignedInteger(
current_type.getIntOrFloatBitWidth())) { // int8 != uint8
returned_type = quant::ConvertSignedQuantizedToUnsigned(
dequantize_input.getType(), dequantize_op.getLoc());
        // Replace the dequantize op with a quantize op producing the unsigned
        // quantized type.
TypeAttr type_attr = TypeAttr::get(returned_type);
auto quantize_op = builder.create<QuantizeOp>(
dequantize_op.getLoc(), returned_type, dequantize_input, type_attr);
returned_value = quantize_op.output();
} else {
output_type.print(llvm::errs() << "Requested output type ");
dequantize_op.emitError(" Couldn't be modified to the requested type.");
return failure();
}
new_output_types[i] = returned_type;
terminator->setOperand(i, returned_value);
if (dequantize_op.use_empty()) {
dequantize_op.erase();
}
}
}
return success();
}
void ModifyIONodesPass::runOnOperation() {
auto func = getOperation();
auto attrs = func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
// Handle the entry functions only.
if (func.getName() != "main" && (!attrs || attrs.empty())) {
return;
}
OpBuilder builder(func);
FunctionType func_type = func.getFunctionType();
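  // Start from the current signature; ModifyInputNodes/ModifyOutputNodes
  // overwrite, in place, the entries they rewrite.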
llvm::SmallVector<Type, 4> new_input_types(func_type.getInputs().begin(),
func_type.getInputs().end());
llvm::SmallVector<Type, 4> new_output_types(func_type.getResults().begin(),
func_type.getResults().end());
if (failed(SetupInputOutputTypesIfNull(builder))) {
return;
}
if (failed(ModifyInputNodes(func, new_input_types, builder))) {
return;
}
if (failed(ModifyOutputNodes(func, new_output_types, builder))) {
return;
}
auto new_func_type =
builder.getFunctionType(new_input_types, new_output_types);
func.setType(new_func_type);
}
} // namespace
// Creates an instance of the TensorFlow Lite modify io nodes pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreateModifyIONodesPass(
Type input_type, Type output_type) {
return std::make_unique<ModifyIONodesPass>(input_type, output_type);
}
static PassRegistration<ModifyIONodesPass> pass;
} // namespace TFL
} // namespace mlir