Add support for tf.WhileRegion in LegalizeTFControlFlow pass.

This adds a legalization for tf.WhileRegion -> mhlo.while with special handling for implicitly captured/used inputs.

PiperOrigin-RevId: 333071632
Change-Id: If464cf760e063b9f00b2fee6509334aabc13152c
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir
index 5ac9786..9133a2e 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir
@@ -126,3 +126,120 @@
   %2 = mhlo.add %arg0, %0 : tensor<i32>
   return %2, %arg1, %1 : tensor<i32>, tensor<i32>, tensor<i32>
 }
+
+// Verifies tf.WhileRegion with explicit operands becomes a tupled mhlo.while.
+// CHECK-LABEL: func @whileRegion
+func @whileRegion() -> tensor<i32> {
+  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
+  %0 = mhlo.constant dense<0> : tensor<i32>
+  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
+  %1 = mhlo.constant dense<-1> : tensor<i32>
+  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]])
+  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
+  %2:3 = "tf.WhileRegion"(%0, %1, %0) ( {
+  // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
+  ^cond(%carg0: tensor<i32>, %carg1: tensor<i32>, %carg2: tensor<i32>):
+    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
+    // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
+    // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32}
+    // CHECK: [[VAL10:%.+]] = mhlo.constant dense<10>
+    %3 = mhlo.constant dense<10> : tensor<i32>
+    // CHECK: [[VAL11:%.+]] = "mhlo.compare"([[VAL9]], [[VAL10]]) {comparison_direction = "LT"}
+    %4 = "mhlo.compare"(%carg2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    // CHECK: "mhlo.return"([[VAL11]])
+    "tf.Yield"(%4) : (tensor<i1>) -> ()
+  }, {
+  // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
+  ^body(%barg0: tensor<i32>, %barg1: tensor<i32>, %barg2: tensor<i32>):
+    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
+    // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
+    // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
+    // CHECK: [[VAL10:%.+]] = mhlo.constant dense<1>
+    %5 = mhlo.constant dense<1> : tensor<i32>
+    // CHECK: [[VAL11:%.+]] = mhlo.add [[VAL9]], [[VAL10]]
+    %6 = mhlo.add %barg2, %5 : tensor<i32>
+    // CHECK: [[VAL12:%.+]] = mhlo.add [[VAL7]], [[VAL10]]
+    %7 = mhlo.add %barg0, %5 : tensor<i32>
+    // CHECK: [[VAL13:%.+]] = "mhlo.tuple"([[VAL12]], [[VAL8]], [[VAL11]])
+    // CHECK: "mhlo.return"([[VAL13]])
+    "tf.Yield"(%7, %barg1, %6) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
+  }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
+  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
+  // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32}
+  // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32}
+  // CHECK: return [[VAL6]]
+  return %2#2 : tensor<i32>
+}
+
+// Verifies captured values (%0, %1) are appended after the explicit operand.
+// CHECK-LABEL: func @whileRegionImplicitInputs
+// CHECK-SAME:  ([[ARG0:%.+]]: tensor<i32>)
+func @whileRegionImplicitInputs(%arg0: tensor<i32>) -> tensor<i32> {
+  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
+  %0 = mhlo.constant dense<0> : tensor<i32>
+  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
+  %1 = mhlo.constant dense<-1> : tensor<i32>
+  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG0]], [[VAL0]], [[VAL1]])
+  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
+  %2 = "tf.WhileRegion"(%arg0) ( {
+  // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
+  ^cond(%carg0: tensor<i32>):
+    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
+    // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
+    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32}
+    // CHECK: [[VAL8:%.+]] = "mhlo.compare"([[VAL5]], [[VAL6]]) {comparison_direction = "LT"}
+    %3 = "mhlo.compare"(%carg0, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    // CHECK: "mhlo.return"([[VAL8]])
+    "tf.Yield"(%3) : (tensor<i1>) -> ()
+  }, {
+  // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
+  ^body(%barg0: tensor<i32>):
+    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
+    // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
+    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
+    // CHECK: [[VAL8:%.+]] = mhlo.add [[VAL5]], [[VAL7]]
+    %3 = mhlo.add %barg0, %1 : tensor<i32>
+    // CHECK: [[VAL9:%.+]] = mhlo.add [[VAL5]], [[VAL8]]
+    %4 = mhlo.add %barg0, %3 : tensor<i32>
+    // CHECK: [[VAL10:%.+]] = "mhlo.tuple"([[VAL9]], [[VAL6]], [[VAL7]])
+    // CHECK: "mhlo.return"([[VAL10]])
+    "tf.Yield"(%4) : (tensor<i32>) -> ()
+  }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
+  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
+  // CHECK: return [[VAL4]]
+  return %2 : tensor<i32>
+}
+
+// Verifies a tf.WhileRegion with no operands: captures form the whole tuple.
+// CHECK-LABEL: func @whileRegionMultipleImplicitInputs
+func @whileRegionMultipleImplicitInputs() {
+  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
+  %0 = mhlo.constant dense<0> : tensor<i32>
+  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
+  %1 = mhlo.constant dense<-1> : tensor<i32>
+  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]])
+  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
+  "tf.WhileRegion"() ( {
+  // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>>):
+    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
+    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
+    // CHECK: [[VAL6:%.+]] = "mhlo.compare"([[VAL4]], [[VAL5]]) {comparison_direction = "LT"}
+    %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    // CHECK: "mhlo.return"([[VAL6]])
+    "tf.Yield"(%2) : (tensor<i1>) -> ()
+  }, {
+  // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>>):
+    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
+    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
+    // CHECK: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]]
+    %2 = mhlo.add %0, %1 : tensor<i32>
+    // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL4]], [[VAL5]])
+    // CHECK: "mhlo.return"([[VAL7]])
+    "tf.Yield"() : () -> ()
+  }) {is_stateless = true, parallel_iterations = 10 : i64} : () -> ()
+  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>>
+  // CHECK: return
+  return
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
index 567a135..c02675f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
@@ -26,6 +26,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/iterator_range.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
@@ -35,12 +36,14 @@
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/core/util/tensor_format.h"
@@ -64,7 +67,7 @@
 
 namespace {
 
-void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) {
+void Detuple(Value tuple, ValueRange replace, OpBuilder* builder) {
   // De-tuple the results of the xla hlo if result.
   for (auto result_it : llvm::enumerate(replace)) {
     auto get_tuple_value = builder->create<mhlo::GetTupleElementOp>(
@@ -179,6 +182,122 @@
   Detuple(while_op.getResult(), op.getResults(), &builder);
   op.erase();
 }
+
+// Replaces all arguments of `block` with a single argument of tuple type
+// `tuple_type`. The original arguments are remapped to
+// get_tuple_element(tuple_arg, index) and then erased.
+void ReplaceBlockArgs(Block* block, Type tuple_type, OpBuilder* builder) {
+  auto tuple_arg = block->addArgument(tuple_type);
+  Detuple(tuple_arg, block->getArguments().drop_back(1), builder);
+  for (int i = block->getNumArguments() - 2; i >= 0; --i)
+    block->eraseArgument(i);
+}
+
+// Rewrites uses of `implicit_inputs` nested under `block`'s region to read from
+// the block's tuple argument. One get_tuple_element is created per input at
+// index `offset` + i; their results are returned in `implicit_inputs` order.
+llvm::SmallVector<Value, 4> ReplaceImplicitInputs(
+    Block* block, int offset, ArrayRef<Value> implicit_inputs,
+    OpBuilder* builder) {
+  llvm::SmallVector<Value, 4> implicit_input_elements;
+  implicit_input_elements.reserve(implicit_inputs.size());
+
+  Region* region = block->getParent();
+  assert(block->getNumArguments() == 1);
+
+  BlockArgument tuple_arg = block->getArgument(0);
+  for (auto& implicit_input : llvm::enumerate(implicit_inputs)) {
+    Value implicit_input_value = implicit_input.value();
+    auto get_tuple_element = builder->create<mhlo::GetTupleElementOp>(
+        implicit_input_value.getLoc(), tuple_arg,
+        implicit_input.index() + offset);
+    implicit_input_elements.emplace_back(get_tuple_element.getResult());
+    for (auto& use :
+         llvm::make_early_inc_range(implicit_input_value.getUses())) {
+      if (!region->isAncestor(use.getOwner()->getParentRegion())) continue;
+      use.set(get_tuple_element.getResult());
+    }
+  }
+
+  return implicit_input_elements;
+}
+
+// Replaces the tf.Yield terminator of `block` with `mhlo.return`. The returned
+// values are the yield operands followed by `extra_results`. If `tuple_return`
+// is set they are packed into one mhlo.tuple operand, else passed directly.
+void ReplaceTerminator(Block* block, ArrayRef<Value> extra_results,
+                       OpBuilder* builder, bool tuple_return = true) {
+  Operation* terminator = block->getTerminator();
+  assert(isa<TF::YieldOp>(terminator));
+  Location loc = terminator->getLoc();
+
+  builder->setInsertionPoint(terminator);
+  auto results = llvm::to_vector<4>(terminator->getOperands());
+  results.append(extra_results.begin(), extra_results.end());
+  if (tuple_return) {
+    auto tuple_results = builder->create<mhlo::TupleOp>(loc, results);
+    builder->create<mhlo::ReturnOp>(loc, tuple_results.getResult());
+  } else {
+    builder->create<mhlo::ReturnOp>(loc, results);
+  }
+
+  terminator->erase();
+}
+
+void LowerWhileRegion(TF::WhileRegionOp op) {
+  Location loc = op.getLoc();
+  OpBuilder builder(op);
+
+  // XLA control flow takes a single tuple operand since it lacks multi-result
+  // support. Values captured from above are appended after explicit inputs.
+  SmallVector<Value, 3> inputs(op.input());
+  const int inputs_size = inputs.size();
+  llvm::SetVector<Value> implicit_inputs;
+  getUsedValuesDefinedAbove(op.getOperation()->getRegions(), implicit_inputs);
+  inputs.append(implicit_inputs.begin(), implicit_inputs.end());
+
+  builder.setInsertionPoint(op);
+  Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
+
+  // Create the mhlo.while over the tuple input. Its result tuple holds the
+  // original result types followed by the types of the implicit inputs.
+  auto while_result_types = llvm::to_vector<4>(op.getResultTypes());
+  while_result_types.reserve(while_result_types.size() +
+                             implicit_inputs.size());
+  for (const auto& implicit_input : implicit_inputs)
+    while_result_types.emplace_back(implicit_input.getType());
+  auto while_op = builder.create<mhlo::WhileOp>(
+      loc, builder.getTupleType(while_result_types), tuple_input);
+
+  // Move the cond region over, then rewrite its block arguments and terminator.
+  Region& cond = while_op.cond();
+  cond.takeBody(op.cond());
+  Block& cond_block = cond.front();
+  builder.setInsertionPointToStart(&cond_block);
+  ReplaceBlockArgs(&cond_block, tuple_input.getType(), &builder);
+  ReplaceImplicitInputs(&cond_block, inputs_size, implicit_inputs.getArrayRef(),
+                        &builder);
+  // Cond always returns a single result of bool type, so it is not tupled.
+  ReplaceTerminator(&cond_block, /*extra_results=*/{}, &builder,
+                    /*tuple_return=*/false);
+
+  // Move the body region over, then rewrite its block arguments and terminator.
+  Region& body = while_op.body();
+  body.takeBody(op.body());
+  Block& body_block = body.front();
+  builder.setInsertionPointToStart(&body_block);
+  ReplaceBlockArgs(&body_block, tuple_input.getType(), &builder);
+  // Collect the tuple elements that stand in for the implicit inputs. The body
+  // returns them after its explicit results, matching the while result tuple.
+  auto implicit_input_elements = ReplaceImplicitInputs(
+      &body_block, inputs_size, implicit_inputs.getArrayRef(), &builder);
+  ReplaceTerminator(&body_block, implicit_input_elements, &builder);
+
+  // De-tuple the mhlo.while result for the original op's results, then drop op.
+  builder.setInsertionPoint(op);
+  Detuple(while_op.getResult(), op.getResults(), &builder);
+  op.erase();
+}
 }  // namespace
 
 void LegalizeTFControlFlow::runOnOperation() {
@@ -189,6 +308,10 @@
       LowerWhile(while_op, module);
       return;
     }
+    if (auto while_region_op = dyn_cast<TF::WhileRegionOp>(op)) {
+      LowerWhileRegion(while_region_op);
+      return;
+    }
     if (auto if_op = dyn_cast<TF::IfOp>(op)) {
       LowerIf(if_op, module);
       return;