blob: 35ed1feaaab5b84cdb1361f20e53273cd8fa40bf [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
namespace mlir {
namespace quant {
// Returns true if `p` compares equal to a default-constructed
// `quant::QuantizedType`, which is what this file treats as "empty"
// quantization parameters.
static bool EmptyParams(QuantParams p) {
  const auto unset_params = quant::QuantizedType();
  return p == unset_params;
}
// The state for each op result during the quantization parameters propagation.
struct QuantState {
  // Quantization parameters propagated to an op result.
  QuantParams params;

  // A flag indicates this state (the params) shouldn't be changed after it is
  // initialized. This flag will be set to true if the quantization parameters
  // are from the quantization-aware training.
  const bool immutable;

  // Returns true if `params` is empty according to `EmptyParams`. The method
  // only reads `params`, so it is const-qualified and can be called through a
  // const `QuantState` as well.
  bool IsEmpty() const { return EmptyParams(params); }
};
// Describes how the propagated quantization parameters have to be rescaled.
// The rescale can be on the input side to satisfy the constraint of the
// previous operation, or on the output side to satisfy the constraint of the
// next operation.
struct RequantizeState {
  // Where the "requantize" is placed. Sometimes the quantization result has to
  // be requantized to satisfy all the constraints, and this can happen either
  // on the input or on the output of the quantization result.
  enum RequantizePosition {
    NO_REQUANTIZE,  // Initial value of `pos`.
    ON_INPUT,
    ON_OUTPUT
  };

  // Position of the requantize; starts as NO_REQUANTIZE.
  RequantizePosition pos = NO_REQUANTIZE;

  // Quantization parameters that will be used to add the requantize ops.
  QuantParams params;
};
// This class manages all the intermediate quantization states.
class QuantizeContext {
 public:
  QuantizeContext(FuncOp func, const DeviceTarget &spec);

  // Returns all the quant region ops.
  ArrayRef<quant::QuantizeRegionOp> GetAllOps();

  // Propagates the quantization parameters of the given quant region op
  // according to the kernel specification. The adjacent quant region ops which
  // get the new quantization parameters propagated are returned through
  // `new_items`.
  // NOTE(review): `changed` is an out-parameter; presumably it reports whether
  // any state was updated -- confirm against the implementation.
  LogicalResult Handle(quant::QuantizeRegionOp op,
                       llvm::SmallVectorImpl<Operation *> *new_items,
                       bool *changed);

  // Updates the port quantization specifications of all the quant region ops
  // with the propagation results.
  LogicalResult Finalize();

  // Dumps the states stored in the state manager.
  void DumpStates(QuantizeRegionOp current_op = {});

  // Updates the quantization parameter of the `index`-th result of `op`. By
  // this method, the quantization parameter is propagated to all the users of
  // the result as well. Forwards to the state manager and returns its result.
  bool SetResultParams(Operation *op, int index, QuantParams params) {
    return states_manager_.SetResultParams(op, index, params);
  }

  // Updates the quantization parameter of the `index`-th operand of `op`. By
  // this method, the quantization parameter is propagated to the defining op
  // of the operand as well. Forwards to the state manager and returns its
  // result.
  bool SetOperandParams(Operation *op, int index, QuantParams params) {
    return states_manager_.SetOperandParams(op, index, params);
  }

  // Returns the quantization parameter of the `index`-th result of `op`.
  QuantParams GetResultParams(Operation *op, int index) {
    return states_manager_.GetResultParams(op, index);
  }

  // Returns the quantization parameter of the `index`-th operand of `op`.
  QuantParams GetOperandParams(Operation *op, int index) {
    return states_manager_.GetOperandParams(op, index);
  }

 private:
  class StatesManager {
   public:
    // Sets the quantization parameters of the constant result according to its
    // content.
    //
    // Always returns true.
    bool SetConstantResultParams(Operation *op);

    // Sets the quantization parameters of the result to a fixed value. If any
    // quantization parameters have been propagated, a `requantize` will happen
    // on the input of propagated quantization.
    //
    // Returns true, if the users of the result needs to be added to the
    // worklist.
    bool SetResultParams(Operation *op, int index, QuantParams params);

    // Sets the quantization parameters of the operand to a fixed value. If any
    // quantization parameters have been propagated, a `requantize` will happen
    // on the output of propagated quantization.
    //
    // Returns true, if the defining op of the operand needs to be added to the
    // worklist.
    bool SetOperandParams(Operation *op, int index, QuantParams params);

    // Returns the quantization parameters of the index-th result of the op.
    // The lookup goes through `operator[]` on `result_states_`.
    QuantParams GetResultParams(Operation *op, int index) {
      const OpValue key(op, index);
      const int state_index = result_states_[key];
      return states_[state_index].params;
    }

    // Returns the quantization parameters of the index-th operand of the op.
    // The lookup goes through `operator[]` on `operand_states_`.
    QuantParams GetOperandParams(Operation *op, int index) {
      const OpValue key(op, index);
      const int state_index = operand_states_[key];
      return states_[state_index].params;
    }

   private:
    friend class QuantizeContext;

    // Uses the type of `val` to set the initial state of the index-th result
    // if `as_result` is true or index-th operand if `as_result` is false. The
    // state is immutable if the type is a quantized type. Returns the index of
    // this new state in the state vector.
    int InitializeState(quant::QuantizeRegionOp op, int index, bool as_result);

    // Sets the state of the index-th operand of the op. If this operand is
    // cached, uses the cached result without creating new entry in the state
    // vector. Otherwise, allocate a new entry in the state vector.
    void InitializeOperandState(quant::QuantizeRegionOp op, int index,
                                llvm::DenseMap<Value, int> *cache);

    // Sets the state of the index-th result of the op. If this result is
    // cached, uses the cached result without creating new entry in the state
    // vector. Otherwise, allocate a new entry in the state vector.
    void InitializeResultState(quant::QuantizeRegionOp op, int index,
                               llvm::DenseMap<Value, int> *cache);

    // Returns a mutable reference to the state of the index-th operand of the
    // op.
    QuantState &GetOperandQuantState(Operation *op, int index) {
      const OpValue key(op, index);
      const int state_index = operand_states_[key];
      return states_[state_index];
    }

    // Returns a mutable reference to the state of the index-th result of the
    // op.
    QuantState &GetResultQuantState(Operation *op, int index) {
      const OpValue key(op, index);
      const int state_index = result_states_[key];
      return states_[state_index];
    }

    // Returns a mutable reference to the requantize state of the index-th
    // operand of the op. The lookup goes through `operator[]` on
    // `rescale_states_`.
    RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
      const OpValue key(op, index);
      const int state_index = operand_states_[key];
      return rescale_states_[state_index];
    }

    // Returns a mutable reference to the requantize state of the index-th
    // result of the op. The lookup goes through `operator[]` on
    // `rescale_states_`.
    RequantizeState &GetResultRequantizeState(Operation *op, int index) {
      const OpValue key(op, index);
      const int state_index = result_states_[key];
      return rescale_states_[state_index];
    }

   private:
    // This is used to identify an operand or result of an op. The second
    // element of this pair is the index of the operand or result.
    using OpValue = std::pair<mlir::Operation *, int>;

    // The vector contains all the quantization parameters propagated from the
    // defining operations of the value, or from the quantization aware
    // training.
    std::vector<QuantState> states_;

    // The map contains all the quantization parameters which are required to
    // satisfy the same operands and results constraint. The keys of this map
    // are the values from `operand_states_` and `result_states_`.
    std::unordered_map<int, RequantizeState> rescale_states_;

    // Maps of indexes to the propagation state vector from the ops operands,
    // results and arguments.
    llvm::DenseMap<OpValue, int> operand_states_;
    llvm::DenseMap<OpValue, int> result_states_;
  };

  FuncOp func_;
  DeviceTarget target_spec_;
  StatesManager states_manager_;
};
} // namespace quant
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_