Add the propagation algorithm based on the quant region and target spec
One simple test is added to demonstrate how the custom scale function is used.
PiperOrigin-RevId: 303865677
Change-Id: Ib9eeff4da4dba090fe7ceaa8c1bac97f1c92894f
diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD
index 69994ae..a75135c 100644
--- a/tensorflow/compiler/mlir/lite/quantization/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/BUILD
@@ -123,3 +123,19 @@
"@llvm-project//mlir:Support",
],
)
+
+cc_library(
+ name = "quantization_context",
+ srcs = ["quantization_context.cc"],
+ hdrs = ["quantization_context.h"],
+ deps = [
+ ":device_target",
+ ":quantization_lib",
+ "@com_google_absl//absl/memory",
+ "@llvm-project//llvm:support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:QuantOps",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc
new file mode 100644
index 0000000..85a988a
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc
@@ -0,0 +1,239 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/IR/Attributes.h" // from @llvm-project
+#include "mlir/IR/Builders.h" // from @llvm-project
+#include "mlir/IR/Function.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/Matchers.h" // from @llvm-project
+#include "mlir/IR/Operation.h" // from @llvm-project
+#include "mlir/IR/StandardTypes.h" // from @llvm-project
+#include "mlir/IR/Value.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+
+#define DEBUG_TYPE "quantization-context"
+
+namespace mlir {
+namespace quant {
+
+QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec)
+ : func_(func), target_spec_(spec) {
+ llvm::DenseMap<Value, int> value_to_state;
+ func.walk([&](quant::QuantizeRegionOp op) {
+ for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
+ states_manager_.InitializeOperandState(op, i, &value_to_state);
+ }
+
+ for (int res = 0, e = op.getNumResults(); res != e; ++res) {
+ states_manager_.InitializeResultState(op, res, &value_to_state);
+ }
+ });
+}
+
+llvm::ArrayRef<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
+ llvm::SmallVector<quant::QuantizeRegionOp, 64> all_ops;
+ func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); });
+ return all_ops;
+}
+
+LogicalResult QuantizeContext::Handle(
+ quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
+ bool *changed) {
+ auto spec = target_spec_.Get(op);
+ if (!spec.hasValue()) {
+ op.emitWarning(
+ "Couldn't find kernel from the registeration for quantization.");
+ return success();
+ }
+ switch (spec->type) {
+ case ScaleConstraintType::OutputInputFreeScale: {
+ // no propagation.
+ *changed = false;
+ break;
+ }
+ case ScaleConstraintType::CustomScale: {
+ if (failed(spec->scale_fn(this, op, new_items, changed))) {
+ return failure();
+ }
+ break;
+ }
+ default: {
+ llvm_unreachable("no implementation.");
+ return failure();
+ }
+ }
+ return success();
+}
+
+LogicalResult QuantizeContext::Finalize() {
+ MLIRContext *context = func_.getContext();
+ func_.walk([&](quant::QuantizeRegionOp op) {
+ llvm::SmallVector<Attribute, 4> input_specs;
+ auto original_input_specs = op.input_specs().getValue();
+ for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
+ auto &state = states_manager_.GetOperandQuantState(op, i);
+ auto &requantize = states_manager_.GetOperandRequantizeState(op, i);
+ if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
+ input_specs.push_back(original_input_specs[i]);
+ } else if (requantize.pos == RequantizeState::ON_OUTPUT) {
+ input_specs.push_back(TypeAttr::get(requantize.params));
+ } else {
+ input_specs.push_back(TypeAttr::get(state.params));
+ }
+ }
+ op.setAttr("input_specs", ArrayAttr::get(input_specs, context));
+
+ llvm::SmallVector<Attribute, 4> output_specs;
+ auto original_output_specs = op.output_specs().getValue();
+ for (int res = 0, e = op.getNumResults(); res != e; ++res) {
+ auto &state = states_manager_.GetResultQuantState(op, res);
+ auto &requantize = states_manager_.GetResultRequantizeState(op, res);
+ if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
+ output_specs.push_back(original_output_specs[res]);
+ } else if (requantize.pos == RequantizeState::ON_INPUT) {
+ output_specs.push_back(TypeAttr::get(requantize.params));
+ } else {
+ output_specs.push_back(TypeAttr::get(state.params));
+ }
+ }
+ op.setAttr("output_specs", ArrayAttr::get(output_specs, context));
+ });
+ return success();
+}
+
+void QuantizeContext::DumpStates(QuantizeRegionOp current_op) {
+ if (current_op) {
+ llvm::errs() << "\n\n\n" << current_op.logical_kernel() << "\n";
+ }
+ func_.walk([&](QuantizeRegionOp op) {
+ if (current_op == op) llvm::errs() << "===>>>";
+ llvm::errs() << op.logical_kernel() << " : (";
+ for (auto i = 0; i < op.getNumOperands(); ++i) {
+ if (auto params = GetOperandParams(op, i))
+ params.print(llvm::errs());
+ else
+ llvm::errs() << "_";
+ llvm::errs() << ",";
+ }
+ llvm::errs() << ") -> (";
+ for (auto i = 0; i < op.getNumResults(); ++i) {
+ if (auto params = GetResultParams(op, i))
+ params.print(llvm::errs());
+ else
+ llvm::errs() << "_";
+ llvm::errs() << ",";
+ }
+ llvm::errs() << ")\n";
+ });
+}
+
+int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op,
+ int index, bool as_result) {
+ Attribute params_attr;
+ if (as_result) {
+ params_attr = op.output_specs()[index];
+ } else {
+ params_attr = op.input_specs()[index];
+ }
+ QuantParams params =
+ params_attr.cast<TypeAttr>().getValue().dyn_cast<QuantParams>();
+ bool immutable = !EmptyParams(params);
+ int next_state_index = states_.size();
+ states_.push_back({params, immutable});
+ if (as_result) {
+ result_states_.insert({{op, index}, next_state_index});
+ } else {
+ operand_states_.insert({{op, index}, next_state_index});
+ }
+ return next_state_index;
+}
+
+void QuantizeContext::StatesManager::InitializeOperandState(
+ quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
+ Value in = op.getOperand(index);
+ auto cached = cache->insert({in, 0});
+ if (!cached.second) {
+ operand_states_.insert({{op, index}, cached.first->second});
+ return;
+ }
+ cached.first->second = InitializeState(op, index, /*as_result=*/false);
+}
+
+void QuantizeContext::StatesManager::InitializeResultState(
+ quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
+ auto res = op.getResult(index);
+ auto cached = cache->insert({res, 0});
+ if (!cached.second) {
+ result_states_.insert({{op, index}, cached.first->second});
+ return;
+ }
+ cached.first->second = InitializeState(op, index, /*as_result=*/true);
+}
+
+bool QuantizeContext::StatesManager::SetConstantResultParams(Operation *op) {
+ llvm_unreachable("no implementation.");
+ return false;
+}
+
+bool QuantizeContext::StatesManager::SetResultParams(Operation *op,
+ int res_index,
+ QuantParams params) {
+ auto &state = GetResultQuantState(op, res_index);
+ if (state.params == params) {
+ return false;
+ }
+ if (!state.IsEmpty()) {
+ auto &rescale = GetResultRequantizeState(op, res_index);
+ rescale.params = params;
+ rescale.pos = RequantizeState::ON_INPUT;
+ return false;
+ }
+ state.params = params;
+ return true;
+}
+
+bool QuantizeContext::StatesManager::SetOperandParams(Operation *op, int index,
+ QuantParams params) {
+ auto &state = GetOperandQuantState(op, index);
+ if (state.params == params) {
+ return false;
+ }
+
+ if (!state.IsEmpty()) {
+ auto &rescale = GetOperandRequantizeState(op, index);
+ rescale.params = params;
+ rescale.pos = RequantizeState::ON_OUTPUT;
+ return false;
+ }
+ state.params = params;
+ return true;
+}
+} // namespace quant
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h
new file mode 100644
index 0000000..35ed1fe
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h
@@ -0,0 +1,217 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
+
+#include "llvm/ADT/DenseMap.h"
+#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/IR/Function.h" // from @llvm-project
+#include "mlir/IR/Operation.h" // from @llvm-project
+#include "mlir/IR/Value.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+
+namespace mlir {
+namespace quant {
+
+static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
+
+// The state for each op result during the quantization parameters propagation.
+struct QuantState {
+ // Quantization parameters propagated to an op result.
+ QuantParams params;
+ // A flag indicates this state (the params) shouldn't be changed after it is
+ // initialized. This flag will be set to true if the quantization parameters
+ // are from the quantization-aware training.
+ const bool immutable;
+
+ bool IsEmpty() { return EmptyParams(params); }
+};
+
+// The state for rescaling the propagated quantization parameters. This can be
+// on the input side to satisfy the constraint of previous operation, or on the
+// output side to satisfy the constraint of the next operation.
+struct RequantizeState {
+ // Sometimes, we have to "requantize" the quantization result to satisfy all
+ // the constraints. The "requantize" can happen either on the input or output
+ // of the quantization result.
+ enum RequantizePosition {
+ NO_REQUANTIZE,
+ ON_INPUT,
+ ON_OUTPUT
+ } pos = NO_REQUANTIZE;
+
+ // Quantization parameters will be used to add the requantize ops.
+ QuantParams params;
+};
+
+// This class manages all the intermedaite quantization states.
+class QuantizeContext {
+ public:
+ QuantizeContext(FuncOp func, const DeviceTarget &spec);
+
+ // Returns all the quant region ops.
+ ArrayRef<quant::QuantizeRegionOp> GetAllOps();
+
+ // For each quant region op, propagates its quantization parameters according
+ // to the kernel specification and also returns the adjcent quant region ops
+ // which get the new quantization parameters propagated.
+ LogicalResult Handle(quant::QuantizeRegionOp op,
+ llvm::SmallVectorImpl<Operation *> *new_items,
+ bool *changed);
+
+ // Updates the port quantization specifications of all the quant region ops
+ // with the propagation results.
+ LogicalResult Finalize();
+
+ // Dumps the states stores in the state manager.
+ void DumpStates(QuantizeRegionOp current_op = {});
+
+ // Update the quantization parameter for certain result of the op. By this
+ // method, the quantization parameter is propagated to all the users of the
+ // result as well.
+ bool SetResultParams(Operation *op, int index, QuantParams params) {
+ return states_manager_.SetResultParams(op, index, params);
+ }
+
+ // Update the quantization parameter for certain operand of the op. By this
+ // method, the quantization parameter is propagated to the defining op of
+ // operand as well.
+ bool SetOperandParams(Operation *op, int index, QuantParams params) {
+ return states_manager_.SetOperandParams(op, index, params);
+ }
+
+ // Return the quantization parameter of certain result of the op.
+ QuantParams GetResultParams(Operation *op, int index) {
+ return states_manager_.GetResultParams(op, index);
+ }
+
+ // Return the quantization parameter of certain operand of the op.
+ QuantParams GetOperandParams(Operation *op, int index) {
+ return states_manager_.GetOperandParams(op, index);
+ }
+
+ private:
+ class StatesManager {
+ public:
+ // Sets the quantization parameters of the constant result according to its
+ // content.
+ //
+ // Always returns true.
+ bool SetConstantResultParams(Operation *op);
+
+ // Sets the quantization parameters of the result to a fixed value. If any
+ // quantization parameters have been propagated, a `requantize` will happen
+ // on the input of propagated quantization.
+ //
+ // Returns true, if the users of the result needs to be added to the
+ // worklist.
+ bool SetResultParams(Operation *op, int index, QuantParams params);
+
+ // Sets the quantization parameters of the operand to a fixed value. If any
+ // quantization parameters have been propagated, a `requantize` will happen
+ // on the output of propagated quantization.
+ //
+ // Returns true, if the defining op of the operand needs to be added to the
+ // worklist.
+ bool SetOperandParams(Operation *op, int index, QuantParams params);
+
+ // Returns the quantization parameters of the index-th result of the op.
+ QuantParams GetResultParams(Operation *op, int index) {
+ return states_[result_states_[{op, index}]].params;
+ }
+
+ // Returns the quantization parameters of the index-th operand of the op.
+ QuantParams GetOperandParams(Operation *op, int index) {
+ return states_[operand_states_[{op, index}]].params;
+ }
+
+ private:
+ friend class QuantizeContext;
+
+ // Uses the type of `val` to set the initial state of the index-th result if
+ // `as_result` is true or index-th operand if `as_result` is false. The
+ // state is immutable if the type is a quantized type. Returns the index of
+ // this new state in the state vector.
+ int InitializeState(quant::QuantizeRegionOp op, int index, bool as_result);
+
+ // Sets the state of the index-th operand of the op. If this operand is
+ // cached, uses the cached result without creating new entry in the state
+ // vector. Otherwise, allocate a new entry in the state vector.
+ void InitializeOperandState(quant::QuantizeRegionOp op, int index,
+ llvm::DenseMap<Value, int> *cache);
+
+ // Sets the state of the index-th result of the op. If this result is
+ // cached, uses the cached result without creating new entry in the state
+ // vector. Otherwise, allocate a new entry in the state vector.
+ void InitializeResultState(quant::QuantizeRegionOp op, int index,
+ llvm::DenseMap<Value, int> *cache);
+
+ // Returns the state of the index-th operand of the op.
+ QuantState &GetOperandQuantState(Operation *op, int index) {
+ return states_[operand_states_[{op, index}]];
+ }
+
+ // Returns the state of the index-th result of the op.
+ QuantState &GetResultQuantState(Operation *op, int index) {
+ return states_[result_states_[{op, index}]];
+ }
+
+ // Returns the state of the index-th operand of the op.
+ RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
+ return rescale_states_[operand_states_[{op, index}]];
+ }
+
+ // Returns the state of the index-th result of the op.
+ RequantizeState &GetResultRequantizeState(Operation *op, int index) {
+ return rescale_states_[result_states_[{op, index}]];
+ }
+
+ private:
+ // This is used to identify an operand or result of an op. The second
+ // element of this pair is the index of the operand or result.
+ using OpValue = std::pair<mlir::Operation *, int>;
+
+ // The vector contains all the quantization parameters propagated from the
+ // defining operations of the value, or from the quantization aware
+ // training.
+ std::vector<QuantState> states_;
+
+ // The map contains all the quantization parameters which are required to
+ // satisfy the same operands and results constraint. The keys of this map
+ // are the values from `operand_states_` and `result_state_`.
+ std::unordered_map<int, RequantizeState> rescale_states_;
+
+ // Maps of indexes to the propagation state vector from the ops operands,
+ // results and arguments.
+ llvm::DenseMap<OpValue, int> operand_states_;
+ llvm::DenseMap<OpValue, int> result_states_;
+ };
+
+ FuncOp func_;
+
+ DeviceTarget target_spec_;
+
+ StatesManager states_manager_;
+};
+
+} // namespace quant
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
index 098739b..2bc1568 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
@@ -35,6 +35,7 @@
deps = [
":cpu_device_target",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
+ "//tensorflow/compiler/mlir/lite/quantization:quantization_context",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/xla/client/lib:quantize",
@@ -60,6 +61,7 @@
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:device_target",
+ "//tensorflow/compiler/mlir/lite/quantization:quantization_context",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc
index 94ae788..e4bdafa 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
namespace mlir {
namespace xla_hlo {
@@ -36,5 +37,23 @@
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
this, ph::_1, ph::_2, ph::_3, ph::_4));
}
+
+LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale(
+ quant::QuantizeContext* ctx, Operation* op,
+ quant::AdjacentOperations* new_items, bool* changed) {
+ auto bias_params = ctx->GetOperandParams(op, 2);
+ if (!EmptyParams(bias_params)) {
+ return success();
+ }
+ std::vector<quant::QuantParams> op_types{ctx->GetOperandParams(op, 0),
+ ctx->GetOperandParams(op, 1)};
+ auto bias_scale = GetUniformQuantizedTypeForBias(op_types);
+ if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) {
+ *changed = true;
+ new_items->push_back(op->getOperand(2).getDefiningOp());
+ }
+ return success();
+}
+
} // namespace xla_hlo
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc
index 4087eeb..c4c5904 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc
@@ -26,7 +26,9 @@
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"
// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
@@ -59,9 +61,36 @@
void PropagateQuantPass::runOnFunction() {
FuncOp func = getFunction();
+ // TODO(fengliuai): deprecate this old code generation path.
// XLA only support uint8/uint16 quantization for now.
ApplyQuantizationParamsPropagation(func, /*is_signed*/ false,
disable_per_channel, GetOpQuantSpec);
+
+ CpuDeviceTarget spec(&getContext());
+ quant::QuantizeContext ctx(func, spec);
+
+ std::vector<quant::QuantizeRegionOp> work_list(ctx.GetAllOps());
+ bool changed = false;
+ while (!work_list.empty()) {
+ quant::QuantizeRegionOp op = work_list.back();
+ work_list.pop_back();
+
+ llvm::SmallVector<Operation *, 4> new_items;
+ if (failed(ctx.Handle(op, &new_items, &changed))) {
+ // The IR is still valid, thus we shouldn't fail.
+ signalPassFailure();
+ }
+ for (auto item : new_items) {
+ if (auto reg = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
+ work_list.push_back(reg);
+ }
+ }
+
+ if (!changed) return;
+
+ if (failed(ctx.Finalize())) {
+ signalPassFailure();
+ }
}
} // namespace
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir
new file mode 100644
index 0000000..05ac48c
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir
@@ -0,0 +1,54 @@
+// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure
+
+// -----
+
+// CHECK-LABEL: @mul_add_source_no_params
+func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+ %region = "quant.region"(%arg0, %arg1, %arg2) ( {
+ ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
+ %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+ %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+ "quant.return"(%add) : (tensor<4xf32>) -> ()
+ }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} :
+ (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %region : tensor<4xf32>
+
+// CHECK: input_specs = [f32, f32, f32]
+// CHECK-SAME: output_specs = [f32]
+}
+
+// -----
+
+// CHECK-LABEL: @mul_add_annotated_no_narrow_range
+func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+ %region = "quant.region"(%arg0, %arg1, %arg2) ( {
+ ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
+ %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+ %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+ "quant.return"(%add) : (tensor<4xf32>) -> ()
+ }) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8:f32, 1.0:-128>, f32],
+ logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
+ (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %region : tensor<4xf32>
+
+// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32]
+// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
+}
+
+// -----
+
+// CHECK-LABEL: @mul_add_annotated
+func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+ %region = "quant.region"(%arg0, %arg1, %arg2) ( {
+ ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
+ %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+ %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+ "quant.return"(%add) : (tensor<4xf32>) -> ()
+ }) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8<-127:127>:f32, 1.0:-128>, f32],
+ logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
+ (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %region : tensor<4xf32>
+
+// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
+// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
+}