blob: ce628fe4e258d96a083d2f352c3fb7dd0c0000a7 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h"
#include <queue>
#include <stack>
#include <string>
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace quant {
// Name of the string attribute that SetAttributeMap writes onto ops. Its value
// is a comma-separated list of `<identifier>:<attribute name>` entries.
constexpr char kAttrMapAttribute[] = "attr_map";
// This attribute will be set for functions created by this pass.
constexpr char kFusedFunctionAttr[] = "tf_quant.fused_function";
// The keyword to detect if this is a `NullAttribute`.
constexpr char kNullAttributeValue[] = "N/A";
// Checks if the op is inside a lifted function, i.e. whether its enclosing
// `func::FuncOp` carries the `tf_quant.fused_function` attribute.
//
// Returns false when `op` is null or when it is not nested inside any
// `func::FuncOp`, rather than dereferencing a null parent.
bool IsInLiftedFunc(Operation *op) {
  if (op == nullptr) return false;
  // `getParentOfType` yields a null handle when no enclosing function exists.
  auto parent_func = op->getParentOfType<func::FuncOp>();
  if (!parent_func) return false;
  return parent_func->hasAttr(kFusedFunctionAttr);
}
// Inserts `function` into the symbol table of `module` while holding a
// process-wide mutex. If `func_name` is already taken, the suffixes "_1",
// "_2", ... are tried in order until an unused symbol name is found. The chosen
// name is stored in the function's "sym_name" attribute before insertion.
// Returns the result of `SymbolTable::insert`.
StringAttr InsertToSymbolTable(Operation *module, Operation *function,
                               const std::string &func_name) {
  // Intentionally heap-allocated and never freed so the mutex is still valid
  // at any point this function may be called.
  static tensorflow::mutex *table_mutex = new tensorflow::mutex();
  tensorflow::mutex_lock guard(*table_mutex);

  SymbolTable table(module);
  std::string candidate = func_name;
  for (int32_t suffix = 1; table.lookup(candidate) != nullptr; ++suffix) {
    candidate = func_name + "_" + std::to_string(suffix);
  }

  StringAttr name_attr = StringAttr::get(module->getContext(), candidate);
  function->setAttr("sym_name", name_attr);
  return table.insert(function);
}
// Builds a `TF::PartitionedCallOp` at the builder's insertion point that calls
// the function named `func_name` with `args`, producing `output_types`. The
// call is tagged with the fully-quantizable quantization trait attribute.
// Returns the outputs of the created call op.
ValueRange createFusedFnCall(OpBuilder builder, Location location,
                             StringRef func_name, TypeRange output_types,
                             ValueRange args) {
  auto callee = FlatSymbolRefAttr::get(builder.getStringAttr(func_name));
  auto fused_call = builder.create<TF::PartitionedCallOp>(
      location, output_types, args, callee,
      /*config=*/"", /*config_proto=*/"", /*executor_type=*/"");

  const std::string trait_value(
      QuantTraitValues[QuantizationTrait::FullyQuantizable]);
  fused_call->setAttr(kQuantTraitAttrName,
                      builder.getStringAttr(llvm::StringRef(trait_value)));
  return fused_call.output();
}
// Collects the ops that lie on the paths from `arguments` to `results`. The
// returned list is ordered so that no op depends on an op that appears after
// it. Values without a defining op are skipped, and operands that are listed
// in `arguments` are not traversed any further.
llvm::SmallVector<Operation *> FindOpsFromArgumentsToResults(
    const llvm::SmallVector<Value> &arguments,
    const llvm::SmallVector<Value> &results) {
  absl::flat_hash_set<mlir::detail::ValueImpl *> boundary_values;
  for (Value argument : arguments) {
    boundary_values.insert(argument.getImpl());
  }

  std::queue<Value> worklist;
  for (Value result : results) {
    worklist.push(result);
  }

  // Walk backwards from the results towards the arguments. An op may be pushed
  // several times on purpose: its last push is guaranteed to sit above every
  // op that uses it, so the top of the stack never depends on anything below.
  std::stack<Operation *> visited_ops;
  while (!worklist.empty()) {
    Value value = worklist.front();
    worklist.pop();

    Operation *producer = value.getDefiningOp();
    if (producer == nullptr) continue;

    visited_ops.push(producer);
    for (const auto &operand : producer->getOperands()) {
      if (boundary_values.contains(operand.getImpl())) continue;
      worklist.push(operand);
    }
  }

  // Drain the stack, keeping only the first occurrence of each op.
  llvm::SmallVector<Operation *> ordered_ops;
  absl::flat_hash_set<Operation *> emitted;
  for (; !visited_ops.empty(); visited_ops.pop()) {
    Operation *op = visited_ops.top();
    if (emitted.insert(op).second) {
      ordered_ops.push_back(op);
    }
  }
  return ordered_ops;
}
// Records, on the ops that own them, which attribute name each entry of
// `attributes` corresponds to. For every attribute, the owning op receives (or
// has extended) an `attr_map` string attribute with a comma-separated entry of
// the form `<index in attributes>:<attribute name>`. This map is then used to
// set attributes in the quantized functions in the
// QuantizeCompositeFunctionsPass.
//
// Entries that are a StringAttr equal to "N/A" are treated as `NullAttribute`
// and skipped. Returns failure as soon as an attribute is not found on any op
// in `ops`; entries processed before that point remain recorded.
LogicalResult SetAttributeMap(MLIRContext *context,
                              const llvm::SmallVector<Attribute> &attributes,
                              const llvm::SmallVector<Operation *> &ops) {
  // Attribute value -> first op (in `ops` order) carrying that value.
  llvm::SmallDenseMap<Attribute, Operation *> owner_by_attr;
  // Attribute value -> name under which it was first seen.
  llvm::SmallDenseMap<Attribute, llvm::StringRef> name_by_attr;
  for (Operation *op : ops) {
    for (const auto &named_attr : op->getAttrs()) {
      owner_by_attr.insert({named_attr.getValue(), op});
      name_by_attr.insert(
          {named_attr.getValue(), named_attr.getName().getValue()});
    }
  }

  const int num_attributes = attributes.size();
  for (int idx = 0; idx < num_attributes; ++idx) {
    const Attribute &attribute = attributes[idx];

    // A `NullAttribute` placeholder has nothing to record.
    auto placeholder = attribute.dyn_cast_or_null<StringAttr>();
    if (placeholder && placeholder.getValue().equals(kNullAttributeValue)) {
      continue;
    }

    auto owner_it = owner_by_attr.find(attribute);
    if (owner_it == owner_by_attr.end()) {
      return failure();
    }
    Operation *owner = owner_it->second;
    // Both maps are populated together, so the name is always present here.
    llvm::StringRef attr_name = name_by_attr.lookup(attribute);

    // Extend the existing map if the op already has one.
    std::string attr_map_str;
    if (owner->hasAttr(kAttrMapAttribute)) {
      attr_map_str = owner->getAttrOfType<StringAttr>(kAttrMapAttribute).str();
      attr_map_str += ",";
    }
    absl::StrAppend(&attr_map_str, std::to_string(idx), ":", attr_name.str());
    owner->setAttr(kAttrMapAttribute, StringAttr::get(context, attr_map_str));
  }
  return success();
}
// Creates a function to wrap the section between arguments and results.
//
// The ops found between `arguments` and `results` are cloned into a new
// private function marked with `tf_quant.fused_function`, which is placed
// right after the function containing the op that defines `results[0]` and is
// registered in the module's symbol table under a name derived from
// `func_name`. A call to that function is then created right after the op
// defining `results[0]`. `attributes` is forwarded to SetAttributeMap to
// record attribute names on the original ops before they are cloned.
//
// Returns the results of the new call. Emits an error and returns an empty
// vector when `results` is empty, when `results[0]` has no defining op (e.g.
// it is a block argument), or when that op is not nested inside a function
// within a module.
llvm::SmallVector<Value, 4> LiftAsFunctionCall(
    OpBuilder builder, Location location, StringRef func_name,
    const llvm::SmallVector<Value> &arguments,
    const llvm::SmallVector<Value> &results,
    const llvm::SmallVector<Attribute> &attributes) {
  MLIRContext *context = builder.getContext();
  if (results.empty()) {
    mlir::emitError(UnknownLoc::get(context), "No result values specified");
    return {};
  }
  // The first result anchors both the enclosing function lookup and the
  // insertion point of the call, so it must come from an actual operation.
  Operation *result_op = results[0].getDefiningOp();
  if (result_op == nullptr) {
    mlir::emitError(location,
                    "The first result value must be defined by an operation");
    return {};
  }
  auto module = result_op->getParentOfType<ModuleOp>();
  auto current_func = result_op->getParentOfType<func::FuncOp>();
  if (!module || !current_func) {
    result_op->emitError()
        << "The result op must be nested in a function inside a module";
    return {};
  }

  // Create a private function and copy all ops between arguments and results.
  auto guard = OpBuilder::InsertionGuard(builder);
  builder.setInsertionPointAfter(current_func);
  TypeRange arg_types(
      llvm::ArrayRef<Value>(arguments.begin(), arguments.end()));
  TypeRange result_types(llvm::ArrayRef<Value>(results.begin(), results.end()));
  auto func_type = FunctionType::get(context, arg_types, result_types);

  llvm::SmallVector<Location> arg_locs;
  for (const auto &arg : arguments) {
    arg_locs.push_back(arg.getLoc());
  }
  auto wrap_func = builder.create<func::FuncOp>(location, func_name, func_type);
  wrap_func.setVisibility(SymbolTable::Visibility::Private);
  wrap_func->setAttr(kFusedFunctionAttr, builder.getUnitAttr());
  builder.createBlock(&wrap_func.getBody(), wrap_func.begin(), arg_types,
                      arg_locs);

  // Inside the new function, the original arguments become block arguments.
  BlockAndValueMapping mapping;
  for (int32_t i : llvm::seq<int32_t>(0, arguments.size())) {
    mapping.map(arguments[i], wrap_func.getArgument(i));
  }

  auto cloning_ops = FindOpsFromArgumentsToResults(arguments, results);
  // A missing attribute is reported, but lifting still proceeds.
  if (failed(SetAttributeMap(context, attributes, cloning_ops))) {
    current_func.emitError() << "Some attributes couldn't be found.";
  }
  for (Operation *op : cloning_ops) {
    builder.clone(*op, mapping);
  }

  llvm::SmallVector<Value> return_values;
  for (Value result : results) {
    return_values.push_back(mapping.lookupOrNull(result));
  }
  builder.create<mlir::func::ReturnOp>(location, return_values);

  // Create a function call to the newly created function.
  StringAttr new_func_name =
      InsertToSymbolTable(module, wrap_func, func_name.str());
  builder.setInsertionPointAfter(result_op);
  ValueRange new_results = createFusedFnCall(
      builder, location, new_func_name.getValue(), result_types, arguments);
  return llvm::SmallVector<Value, 4>(new_results.begin(), new_results.end());
}
// Convenience overload that lifts the section between `arguments` and
// `results` into a function call without recording any attributes: it
// forwards to the six-argument overload with an empty attribute list.
llvm::SmallVector<Value, 4> LiftAsFunctionCall(
    OpBuilder builder, Location location, StringRef func_name,
    const llvm::SmallVector<Value> &arguments,
    const llvm::SmallVector<Value> &results) {
  return LiftAsFunctionCall(builder, location, func_name, arguments, results,
                            /*attributes=*/llvm::SmallVector<Attribute>());
}
} // namespace quant
} // namespace mlir