Add the propagation algorithm based on the quant region and target spec

One simple test is added to demonstrate how the custom scale function is used.

PiperOrigin-RevId: 303865677
Change-Id: Ib9eeff4da4dba090fe7ceaa8c1bac97f1c92894f
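A custom scale function receives the quantize context, the quant region op, a
list collecting the adjacent ops that need to be revisited, and a `changed`
flag. A minimal sketch of such a function (the constraint and the name
`HandleSameScale` are illustrative; `HandleMultiplyAccumulateScale` in
cpu_device_target.cc below is the registered example this patch actually adds):

    // Forces the first result to share the quantization parameters of the
    // first operand (an "output scale == input scale" constraint).
    LogicalResult HandleSameScale(quant::QuantizeContext *ctx, Operation *op,
                                  quant::AdjacentOperations *new_items,
                                  bool *changed) {
      auto params = ctx->GetOperandParams(op, 0);
      if (quant::EmptyParams(params)) return success();  // nothing to propagate
      if (ctx->SetResultParams(op, 0, params)) {
        *changed = true;
        // Revisit the users so they observe the new parameters.
        for (Operation *user : op->getResult(0).getUsers())
          new_items->push_back(user);
      }
      return success();
    }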
diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD
index 69994ae..a75135c 100644
--- a/tensorflow/compiler/mlir/lite/quantization/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/BUILD
@@ -123,3 +123,19 @@
         "@llvm-project//mlir:Support",
     ],
 )
+
+cc_library(
+    name = "quantization_context",
+    srcs = ["quantization_context.cc"],
+    hdrs = ["quantization_context.h"],
+    deps = [
+        ":device_target",
+        ":quantization_lib",
+        "@com_google_absl//absl/memory",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:QuantOps",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc
new file mode 100644
index 0000000..85a988a
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc
@@ -0,0 +1,241 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+
+#define DEBUG_TYPE "quantization-context"
+
+namespace mlir {
+namespace quant {
+
+QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec)
+    : func_(func), target_spec_(spec) {
+  llvm::DenseMap<Value, int> value_to_state;
+  func.walk([&](quant::QuantizeRegionOp op) {
+    for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
+      states_manager_.InitializeOperandState(op, i, &value_to_state);
+    }
+
+    for (int res = 0, e = op.getNumResults(); res != e; ++res) {
+      states_manager_.InitializeResultState(op, res, &value_to_state);
+    }
+  });
+}
+
+// Returns by value: the collected ops would dangle if exposed as an ArrayRef
+// into a local buffer.
+std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
+  std::vector<quant::QuantizeRegionOp> all_ops;
+  func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); });
+  return all_ops;
+}
+
+LogicalResult QuantizeContext::Handle(
+    quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
+    bool *changed) {
+  auto spec = target_spec_.Get(op);
+  if (!spec.hasValue()) {
+    op.emitWarning(
+        "Couldn't find kernel from the registeration for quantization.");
+    return success();
+  }
+  switch (spec->type) {
+    case ScaleConstraintType::OutputInputFreeScale: {
+      // No propagation: the scales are free on both inputs and outputs.
+      *changed = false;
+      break;
+    }
+    case ScaleConstraintType::CustomScale: {
+      if (failed(spec->scale_fn(this, op, new_items, changed))) {
+        return failure();
+      }
+      break;
+    }
+    default: {
+      llvm_unreachable("no implementation.");
+      return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult QuantizeContext::Finalize() {
+  MLIRContext *context = func_.getContext();
+  func_.walk([&](quant::QuantizeRegionOp op) {
+    llvm::SmallVector<Attribute, 4> input_specs;
+    auto original_input_specs = op.input_specs().getValue();
+    for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
+      auto &state = states_manager_.GetOperandQuantState(op, i);
+      auto &requantize = states_manager_.GetOperandRequantizeState(op, i);
+      if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
+        input_specs.push_back(original_input_specs[i]);
+      } else if (requantize.pos == RequantizeState::ON_OUTPUT) {
+        input_specs.push_back(TypeAttr::get(requantize.params));
+      } else {
+        input_specs.push_back(TypeAttr::get(state.params));
+      }
+    }
+    op.setAttr("input_specs", ArrayAttr::get(input_specs, context));
+
+    llvm::SmallVector<Attribute, 4> output_specs;
+    auto original_output_specs = op.output_specs().getValue();
+    for (int res = 0, e = op.getNumResults(); res != e; ++res) {
+      auto &state = states_manager_.GetResultQuantState(op, res);
+      auto &requantize = states_manager_.GetResultRequantizeState(op, res);
+      if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
+        output_specs.push_back(original_output_specs[res]);
+      } else if (requantize.pos == RequantizeState::ON_INPUT) {
+        output_specs.push_back(TypeAttr::get(requantize.params));
+      } else {
+        output_specs.push_back(TypeAttr::get(state.params));
+      }
+    }
+    op.setAttr("output_specs", ArrayAttr::get(output_specs, context));
+  });
+  return success();
+}
+
+void QuantizeContext::DumpStates(QuantizeRegionOp current_op) {
+  if (current_op) {
+    llvm::errs() << "\n\n\n" << current_op.logical_kernel() << "\n";
+  }
+  func_.walk([&](QuantizeRegionOp op) {
+    if (current_op == op) llvm::errs() << "===>>>";
+    llvm::errs() << op.logical_kernel() << " : (";
+    for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
+      if (auto params = GetOperandParams(op, i))
+        params.print(llvm::errs());
+      else
+        llvm::errs() << "_";
+      llvm::errs() << ",";
+    }
+    llvm::errs() << ") -> (";
+    for (int i = 0, e = op.getNumResults(); i != e; ++i) {
+      if (auto params = GetResultParams(op, i))
+        params.print(llvm::errs());
+      else
+        llvm::errs() << "_";
+      llvm::errs() << ",";
+    }
+    llvm::errs() << ")\n";
+  });
+}
+
+int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op,
+                                                    int index, bool as_result) {
+  Attribute params_attr;
+  if (as_result) {
+    params_attr = op.output_specs()[index];
+  } else {
+    params_attr = op.input_specs()[index];
+  }
+  QuantParams params =
+      params_attr.cast<TypeAttr>().getValue().dyn_cast<QuantParams>();
+  bool immutable = !EmptyParams(params);
+  int next_state_index = states_.size();
+  states_.push_back({params, immutable});
+  if (as_result) {
+    result_states_.insert({{op, index}, next_state_index});
+  } else {
+    operand_states_.insert({{op, index}, next_state_index});
+  }
+  return next_state_index;
+}
+
+void QuantizeContext::StatesManager::InitializeOperandState(
+    quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
+  Value in = op.getOperand(index);
+  auto cached = cache->insert({in, 0});
+  if (!cached.second) {
+    operand_states_.insert({{op, index}, cached.first->second});
+    return;
+  }
+  cached.first->second = InitializeState(op, index, /*as_result=*/false);
+}
+
+void QuantizeContext::StatesManager::InitializeResultState(
+    quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
+  auto res = op.getResult(index);
+  auto cached = cache->insert({res, 0});
+  if (!cached.second) {
+    result_states_.insert({{op, index}, cached.first->second});
+    return;
+  }
+  cached.first->second = InitializeState(op, index, /*as_result=*/true);
+}
+
+bool QuantizeContext::StatesManager::SetConstantResultParams(Operation *op) {
+  llvm_unreachable("no implementation.");
+  return false;
+}
+
+bool QuantizeContext::StatesManager::SetResultParams(Operation *op,
+                                                     int res_index,
+                                                     QuantParams params) {
+  auto &state = GetResultQuantState(op, res_index);
+  if (state.params == params) {
+    return false;
+  }
+  if (!state.IsEmpty()) {
+    auto &rescale = GetResultRequantizeState(op, res_index);
+    rescale.params = params;
+    rescale.pos = RequantizeState::ON_INPUT;
+    return false;
+  }
+  state.params = params;
+  return true;
+}
+
+bool QuantizeContext::StatesManager::SetOperandParams(Operation *op, int index,
+                                                      QuantParams params) {
+  auto &state = GetOperandQuantState(op, index);
+  if (state.params == params) {
+    return false;
+  }
+
+  if (!state.IsEmpty()) {
+    auto &rescale = GetOperandRequantizeState(op, index);
+    rescale.params = params;
+    rescale.pos = RequantizeState::ON_OUTPUT;
+    return false;
+  }
+  state.params = params;
+  return true;
+}
+}  // namespace quant
+}  // namespace mlir
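The two Set*Params methods above implement a first-write-wins policy: a
conflicting second write does not overwrite a non-empty state but is recorded
as a requantize. A minimal sketch of the observable behavior, assuming a
populated context `ctx`, a region op `op`, and two distinct parameter values
`params_a` and `params_b` (all hypothetical here):

    // The empty state adopts `params_a`; returns true, meaning the users of
    // the result should be revisited.
    bool revisit = ctx.SetResultParams(op, /*index=*/0, params_a);
    // The state keeps `params_a`; `params_b` is recorded in a RequantizeState
    // with pos == ON_INPUT, and no revisit is requested (returns false).
    bool ignored = ctx.SetResultParams(op, /*index=*/0, params_b);

Finalize() later materializes such an ON_INPUT entry into the op's
output_specs attribute.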
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h
new file mode 100644
index 0000000..35ed1fe
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h
@@ -0,0 +1,219 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
+
+#include "llvm/ADT/DenseMap.h"
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+
+namespace mlir {
+namespace quant {
+
+static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
+
+// The state of a value (an op operand or result) during the quantization
+// parameters propagation.
+struct QuantState {
+  // The quantization parameters propagated to this value.
+  QuantParams params;
+  // A flag indicates this state (the params) shouldn't be changed after it is
+  // initialized. This flag will be set to true if the quantization parameters
+  // are from the quantization-aware training.
+  const bool immutable;
+
+  bool IsEmpty() { return EmptyParams(params); }
+};
+
+// The state for rescaling the propagated quantization parameters. The rescale
+// can be on the input side, to satisfy a constraint of the previous operation,
+// or on the output side, to satisfy a constraint of the next operation.
+struct RequantizeState {
+  // Sometimes, we have to "requantize" the quantization result to satisfy all
+  // the constraints. The "requantize" can happen either on the input or output
+  // of the quantization result.
+  enum RequantizePosition {
+    NO_REQUANTIZE,
+    ON_INPUT,
+    ON_OUTPUT
+  } pos = NO_REQUANTIZE;
+
+  // The quantization parameters to use when the requantize ops are added.
+  QuantParams params;
+};
+
+// This class manages all the intermediate quantization states.
+class QuantizeContext {
+ public:
+  QuantizeContext(FuncOp func, const DeviceTarget &spec);
+
+  // Returns all the quant region ops.
+  std::vector<quant::QuantizeRegionOp> GetAllOps();
+
+  // For each quant region op, propagates its quantization parameters according
+  // to the kernel specification and also returns the adjacent quant region
+  // ops which receive the newly propagated quantization parameters.
+  LogicalResult Handle(quant::QuantizeRegionOp op,
+                       llvm::SmallVectorImpl<Operation *> *new_items,
+                       bool *changed);
+
+  // Updates the port quantization specifications of all the quant region ops
+  // with the propagation results.
+  LogicalResult Finalize();
+
+  // Dumps the states stored in the states manager.
+  void DumpStates(QuantizeRegionOp current_op = {});
+
+  // Updates the quantization parameter for the index-th result of the op. The
+  // new parameter is also propagated to all the users of the result.
+  bool SetResultParams(Operation *op, int index, QuantParams params) {
+    return states_manager_.SetResultParams(op, index, params);
+  }
+
+  // Updates the quantization parameter for the index-th operand of the op. The
+  // new parameter is also propagated to the defining op of the operand.
+  bool SetOperandParams(Operation *op, int index, QuantParams params) {
+    return states_manager_.SetOperandParams(op, index, params);
+  }
+
+  // Returns the quantization parameter of the index-th result of the op.
+  QuantParams GetResultParams(Operation *op, int index) {
+    return states_manager_.GetResultParams(op, index);
+  }
+
+  // Returns the quantization parameter of the index-th operand of the op.
+  QuantParams GetOperandParams(Operation *op, int index) {
+    return states_manager_.GetOperandParams(op, index);
+  }
+
+ private:
+  class StatesManager {
+   public:
+    // Sets the quantization parameters of the constant result according to its
+    // content.
+    //
+    // Always returns true.
+    bool SetConstantResultParams(Operation *op);
+
+    // Sets the quantization parameters of the result to a fixed value. If any
+    // quantization parameters have already been propagated, a `requantize` is
+    // recorded on the input side of the propagated quantization.
+    //
+    // Returns true if the users of the result need to be added to the
+    // worklist.
+    bool SetResultParams(Operation *op, int index, QuantParams params);
+
+    // Sets the quantization parameters of the operand to a fixed value. If any
+    // quantization parameters have already been propagated, a `requantize` is
+    // recorded on the output side of the propagated quantization.
+    //
+    // Returns true if the defining op of the operand needs to be added to the
+    // worklist.
+    bool SetOperandParams(Operation *op, int index, QuantParams params);
+
+    // Returns the quantization parameters of the index-th result of the op.
+    QuantParams GetResultParams(Operation *op, int index) {
+      return states_[result_states_[{op, index}]].params;
+    }
+
+    // Returns the quantization parameters of the index-th operand of the op.
+    QuantParams GetOperandParams(Operation *op, int index) {
+      return states_[operand_states_[{op, index}]].params;
+    }
+
+   private:
+    friend class QuantizeContext;
+
+    // Uses the corresponding spec attribute of the op to set the initial state
+    // of the index-th result if `as_result` is true, or of the index-th
+    // operand if `as_result` is false. The state is immutable if the type is a
+    // quantized type. Returns the index of this new state in the state vector.
+    int InitializeState(quant::QuantizeRegionOp op, int index, bool as_result);
+
+    // Sets the state of the index-th operand of the op. If this operand is
+    // cached, uses the cached result without creating a new entry in the state
+    // vector. Otherwise, allocates a new entry in the state vector.
+    void InitializeOperandState(quant::QuantizeRegionOp op, int index,
+                                llvm::DenseMap<Value, int> *cache);
+
+    // Sets the state of the index-th result of the op. If this result is
+    // cached, uses the cached result without creating a new entry in the state
+    // vector. Otherwise, allocates a new entry in the state vector.
+    void InitializeResultState(quant::QuantizeRegionOp op, int index,
+                               llvm::DenseMap<Value, int> *cache);
+
+    // Returns the state of the index-th operand of the op.
+    QuantState &GetOperandQuantState(Operation *op, int index) {
+      return states_[operand_states_[{op, index}]];
+    }
+
+    // Returns the state of the index-th result of the op.
+    QuantState &GetResultQuantState(Operation *op, int index) {
+      return states_[result_states_[{op, index}]];
+    }
+
+    // Returns the state of the index-th operand of the op.
+    RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
+      return rescale_states_[operand_states_[{op, index}]];
+    }
+
+    // Returns the state of the index-th result of the op.
+    RequantizeState &GetResultRequantizeState(Operation *op, int index) {
+      return rescale_states_[result_states_[{op, index}]];
+    }
+
+   private:
+    // This is used to identify an operand or result of an op. The second
+    // element of this pair is the index of the operand or result.
+    using OpValue = std::pair<mlir::Operation *, int>;
+
+    // The vector contains all the quantization parameters propagated from the
+    // defining operations of the value, or from the quantization aware
+    // training.
+    std::vector<QuantState> states_;
+
+    // The map contains all the requantize parameters which are needed to
+    // satisfy mismatched constraints on a value. The keys of this map are the
+    // state indices stored in `operand_states_` and `result_states_`.
+    std::unordered_map<int, RequantizeState> rescale_states_;
+
+    // Maps from each (op, operand/result index) pair to an index into the
+    // propagation state vector `states_`.
+    llvm::DenseMap<OpValue, int> operand_states_;
+    llvm::DenseMap<OpValue, int> result_states_;
+  };
+
+  FuncOp func_;
+
+  DeviceTarget target_spec_;
+
+  StatesManager states_manager_;
+};
+
+}  // namespace quant
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
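Typical use of this API is the fixed-point worklist loop added to propagate.cc
below; a condensed sketch, assuming a FuncOp `func` and a concrete device
target such as the CpuDeviceTarget from this patch:

    CpuDeviceTarget spec(func.getContext());
    quant::QuantizeContext ctx(func, spec);
    std::vector<quant::QuantizeRegionOp> work_list(ctx.GetAllOps());
    bool changed = false;
    while (!work_list.empty()) {
      quant::QuantizeRegionOp op = work_list.back();
      work_list.pop_back();
      llvm::SmallVector<Operation *, 4> new_items;
      if (failed(ctx.Handle(op, &new_items, &changed))) break;
      // Only quant region ops re-enter the worklist.
      for (Operation *item : new_items)
        if (auto region = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
          work_list.push_back(region);
    }
    if (changed) (void)ctx.Finalize();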
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
index 098739b..2bc1568 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
@@ -35,6 +35,7 @@
     deps = [
         ":cpu_device_target",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_context",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/compiler/mlir/xla:hlo",
         "//tensorflow/compiler/xla/client/lib:quantize",
@@ -60,6 +61,7 @@
     ],
     deps = [
         "//tensorflow/compiler/mlir/lite/quantization:device_target",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_context",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:Support",
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc
index 94ae788..e4bdafa 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.cc
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
 
 namespace mlir {
 namespace xla_hlo {
@@ -36,5 +37,23 @@
                  std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
                            this, ph::_1, ph::_2, ph::_3, ph::_4));
 }
+
+LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale(
+    quant::QuantizeContext* ctx, Operation* op,
+    quant::AdjacentOperations* new_items, bool* changed) {
+  auto bias_params = ctx->GetOperandParams(op, 2);
+  if (!EmptyParams(bias_params)) {
+    return success();
+  }
+  std::vector<quant::QuantParams> op_types{ctx->GetOperandParams(op, 0),
+                                           ctx->GetOperandParams(op, 1)};
+  auto bias_scale = GetUniformQuantizedTypeForBias(op_types);
+  if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) {
+    *changed = true;
+    new_items->push_back(op->getOperand(2).getDefiningOp());
+  }
+  return success();
+}
+
 }  // namespace xla_hlo
 }  // namespace mlir
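HandleMultiplyAccumulateScale only fills in an empty bias spec: it derives the
bias parameters from the two multiply operands via
GetUniformQuantizedTypeForBias (declared in quantization_utils.h). Assuming the
usual multiply-accumulate convention that the bias scale is the product of the
input scales with an i32 storage type, the @mul_add_annotated test below works
out as:

    // scale_bias = scale_lhs * scale_rhs = 1.0 * 1.0 = 1.0
    // => bias spec !quant.uniform<i32:f32, 1.000000e+00>, as in the CHECK line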
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc
index 4087eeb..c4c5904 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc
@@ -26,7 +26,9 @@
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"
 
 // NOLINTNEXTLINE
 static llvm::cl::opt<bool> disable_per_channel(
@@ -59,9 +61,37 @@
 
 void PropagateQuantPass::runOnFunction() {
   FuncOp func = getFunction();
+  // TODO(fengliuai): deprecate this old code generation path.
   // XLA only support uint8/uint16 quantization for now.
   ApplyQuantizationParamsPropagation(func, /*is_signed*/ false,
                                      disable_per_channel, GetOpQuantSpec);
+
+  CpuDeviceTarget spec(&getContext());
+  quant::QuantizeContext ctx(func, spec);
+
+  std::vector<quant::QuantizeRegionOp> work_list(ctx.GetAllOps());
+  bool changed = false;
+  while (!work_list.empty()) {
+    quant::QuantizeRegionOp op = work_list.back();
+    work_list.pop_back();
+
+    llvm::SmallVector<Operation *, 4> new_items;
+    if (failed(ctx.Handle(op, &new_items, &changed))) {
+      // Propagation failed for this op, so signal pass failure; the IR itself
+      // is still valid.
+      signalPassFailure();
+    }
+    for (auto item : new_items) {
+      if (auto reg = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
+        work_list.push_back(reg);
+    }
+  }
+
+  if (!changed) return;
+
+  if (failed(ctx.Finalize())) {
+    signalPassFailure();
+  }
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir b/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir
new file mode 100644
index 0000000..05ac48c
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/region_propagation.mlir
@@ -0,0 +1,54 @@
+// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure
+
+// -----
+
+// CHECK-LABEL: @mul_add_source_no_params
+func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+  %region = "quant.region"(%arg0, %arg1, %arg2) ( {
+    ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>):	// no predecessors
+    %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+    %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+    "quant.return"(%add) : (tensor<4xf32>) -> ()
+  }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} :
+  (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %region : tensor<4xf32>
+
+// CHECK: input_specs = [f32, f32, f32]
+// CHECK-SAME: output_specs = [f32]
+}
+
+// -----
+
+// CHECK-LABEL: @mul_add_annotated_no_narrow_range
+func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+  %region = "quant.region"(%arg0, %arg1, %arg2) ( {
+    ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>):	// no predecessors
+    %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+    %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+    "quant.return"(%add) : (tensor<4xf32>) -> ()
+  }) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8:f32, 1.0:-128>, f32],
+    logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
+  (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %region : tensor<4xf32>
+
+// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32]
+// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
+}
+
+// -----
+
+// CHECK-LABEL: @mul_add_annotated
+func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
+  %region = "quant.region"(%arg0, %arg1, %arg2) ( {
+    ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>):	// no predecessors
+    %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+    %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
+    "quant.return"(%add) : (tensor<4xf32>) -> ()
+  }) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8<-127:127>:f32, 1.0:-128>, f32],
+    logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
+  (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %region : tensor<4xf32>
+
+// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
+// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
+}