/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#define DEBUG_TYPE "quantization-context"
namespace mlir {
namespace quant {
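// Builds the quantization context for `func` against the given device target
// and initializes a quantization state for every operand and result of each
// quant.region op, sharing one state per SSA value.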
QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec)
: func_(func), target_spec_(spec) {
llvm::DenseMap<Value, int> value_to_state;
func.walk([&](quant::QuantizeRegionOp op) {
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
states_manager_.InitializeOperandState(op, i, &value_to_state);
}
for (int res = 0, e = op.getNumResults(); res != e; ++res) {
states_manager_.InitializeResultState(op, res, &value_to_state);
}
});
}
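// Collects all the quant.region ops in the function, in walk order.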
std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
  // Return the ops by value; returning a reference/ArrayRef to a local
  // container would dangle once this function returns.
  std::vector<quant::QuantizeRegionOp> all_ops;
  all_ops.reserve(128);
  func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); });
  return all_ops;
}
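// Propagates quantization parameters across `op` according to the scale
// constraint registered for its kernel. Ops created during propagation are
// appended to `new_items`; `changed` reports whether any state was updated.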
LogicalResult QuantizeContext::Handle(
quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
bool *changed) {
auto spec = target_spec_.Get(op);
if (!spec.hasValue()) {
op.emitWarning(
"Couldn't find kernel from the registeration for quantization.");
return success();
}
switch (spec->type) {
case ScaleConstraintType::OutputInputFreeScale: {
// no propagation.
*changed = false;
break;
}
case ScaleConstraintType::CustomScale: {
if (failed(spec->scale_fn(this, op, new_items, changed))) {
return failure();
}
break;
}
default: {
llvm_unreachable("no implementation.");
return failure();
}
}
return success();
}
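// Writes the propagated quantization parameters back into the "input_specs"
// and "output_specs" attributes of every quant.region op, taking any pending
// requantize state into account.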
LogicalResult QuantizeContext::Finalize() {
MLIRContext *context = func_.getContext();
func_.walk([&](quant::QuantizeRegionOp op) {
llvm::SmallVector<Attribute, 4> input_specs;
auto original_input_specs = op.input_specs().getValue();
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
auto &state = states_manager_.GetOperandQuantState(op, i);
auto &requantize = states_manager_.GetOperandRequantizeState(op, i);
if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
input_specs.push_back(original_input_specs[i]);
} else if (requantize.pos == RequantizeState::ON_OUTPUT) {
input_specs.push_back(TypeAttr::get(requantize.params));
} else {
input_specs.push_back(TypeAttr::get(state.params));
}
}
op.setAttr("input_specs", ArrayAttr::get(input_specs, context));
llvm::SmallVector<Attribute, 4> output_specs;
auto original_output_specs = op.output_specs().getValue();
for (int res = 0, e = op.getNumResults(); res != e; ++res) {
auto &state = states_manager_.GetResultQuantState(op, res);
auto &requantize = states_manager_.GetResultRequantizeState(op, res);
if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
output_specs.push_back(original_output_specs[res]);
} else if (requantize.pos == RequantizeState::ON_INPUT) {
output_specs.push_back(TypeAttr::get(requantize.params));
} else {
output_specs.push_back(TypeAttr::get(state.params));
}
}
op.setAttr("output_specs", ArrayAttr::get(output_specs, context));
});
return success();
}
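// Dumps the quantization parameters of every quant.region op to stderr for
// debugging; the op currently being processed (if any) is marked with
// "===>>>".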
void QuantizeContext::DumpStates(QuantizeRegionOp current_op) {
if (current_op) {
llvm::errs() << "\n\n\n" << current_op.logical_kernel() << "\n";
}
func_.walk([&](QuantizeRegionOp op) {
if (current_op == op) llvm::errs() << "===>>>";
llvm::errs() << op.logical_kernel() << " : (";
for (auto i = 0; i < op.getNumOperands(); ++i) {
if (auto params = GetOperandParams(op, i))
params.print(llvm::errs());
else
llvm::errs() << "_";
llvm::errs() << ",";
}
llvm::errs() << ") -> (";
for (auto i = 0; i < op.getNumResults(); ++i) {
if (auto params = GetResultParams(op, i))
params.print(llvm::errs());
else
llvm::errs() << "_";
llvm::errs() << ",";
}
llvm::errs() << ")\n";
});
}
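// Creates a state from the op's existing input/output spec at `index` and
// records it in the operand or result state map. The state is immutable if
// the spec already carries quantization parameters. Returns the index of the
// newly created state.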
int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op,
int index, bool as_result) {
Attribute params_attr;
if (as_result) {
params_attr = op.output_specs()[index];
} else {
params_attr = op.input_specs()[index];
}
QuantParams params =
params_attr.cast<TypeAttr>().getValue().dyn_cast<QuantParams>();
bool immutable = !EmptyParams(params);
int next_state_index = states_.size();
states_.push_back({params, immutable});
if (as_result) {
result_states_.insert({{op, index}, next_state_index});
} else {
operand_states_.insert({{op, index}, next_state_index});
}
return next_state_index;
}
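// Initializes the state of the `index`-th operand, reusing the state cached
// for the same SSA value if it has been initialized before.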
void QuantizeContext::StatesManager::InitializeOperandState(
quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
Value in = op.getOperand(index);
auto cached = cache->insert({in, 0});
if (!cached.second) {
operand_states_.insert({{op, index}, cached.first->second});
return;
}
cached.first->second = InitializeState(op, index, /*as_result=*/false);
}
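// Initializes the state of the `index`-th result, reusing the state cached
// for the same SSA value if it has been initialized before.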
void QuantizeContext::StatesManager::InitializeResultState(
quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
auto res = op.getResult(index);
auto cached = cache->insert({res, 0});
if (!cached.second) {
result_states_.insert({{op, index}, cached.first->second});
return;
}
cached.first->second = InitializeState(op, index, /*as_result=*/true);
}
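// Sets the quantization parameters for the result of a constant op; not
// implemented yet.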
bool QuantizeContext::StatesManager::SetConstantResultParams(Operation *op) {
llvm_unreachable("no implementation.");
return false;
}
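// Sets the quantization parameters of the `res_index`-th result. If the state
// already holds parameters, the new ones are recorded as a pending requantize
// instead. Returns true if the state itself was changed.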
bool QuantizeContext::StatesManager::SetResultParams(Operation *op,
int res_index,
QuantParams params) {
auto &state = GetResultQuantState(op, res_index);
if (state.params == params) {
return false;
}
if (!state.IsEmpty()) {
auto &rescale = GetResultRequantizeState(op, res_index);
rescale.params = params;
rescale.pos = RequantizeState::ON_INPUT;
return false;
}
state.params = params;
return true;
}
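// Sets the quantization parameters of the `index`-th operand. If the state
// already holds parameters, the new ones are recorded as a pending requantize
// instead. Returns true if the state itself was changed.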
bool QuantizeContext::StatesManager::SetOperandParams(Operation *op, int index,
QuantParams params) {
auto &state = GetOperandQuantState(op, index);
if (state.params == params) {
return false;
}
if (!state.IsEmpty()) {
auto &rescale = GetOperandRequantizeState(op, index);
rescale.params = params;
rescale.pos = RequantizeState::ON_OUTPUT;
return false;
}
state.params = params;
return true;
}
} // namespace quant
} // namespace mlir