Add support for tf.WhileRegion in LegalizeTFControlFlow pass.
This adds a legalization for tf.WhileRegion -> mhlo.while with special handling for implicitly captured/used inputs.
PiperOrigin-RevId: 333071632
Change-Id: If464cf760e063b9f00b2fee6509334aabc13152c
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir
index 5ac9786..9133a2e 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir
@@ -126,3 +126,120 @@
%2 = mhlo.add %arg0, %0 : tensor<i32>
return %2, %arg1, %1 : tensor<i32>, tensor<i32>, tensor<i32>
}
+
+
+// CHECK-LABEL: func @whileRegion
+func @whileRegion() -> tensor<i32> {
+ // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
+ %0 = mhlo.constant dense<0> : tensor<i32>
+ // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
+ %1 = mhlo.constant dense<-1> : tensor<i32>
+ // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]])
+ // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
+ %2:3 = "tf.WhileRegion"(%0, %1, %0) ( {
+ // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
+ ^cond(%carg0: tensor<i32>, %carg1: tensor<i32>, %carg2: tensor<i32>):
+ // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
+ // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
+ // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32}
+ // CHECK: [[VAL10:%.+]] = mhlo.constant dense<10>
+ %3 = mhlo.constant dense<10> : tensor<i32>
+ // CHECK: [[VAL11:%.+]] = "mhlo.compare"([[VAL9]], [[VAL10]]) {comparison_direction = "LT"}
+ %4 = "mhlo.compare"(%carg2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ // CHECK: "mhlo.return"([[VAL11]])
+ "tf.Yield"(%4) : (tensor<i1>) -> ()
+ }, {
+ // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
+ ^body(%barg0: tensor<i32>, %barg1: tensor<i32>, %barg2: tensor<i32>):
+ // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
+ // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
+ // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
+ // CHECK: [[VAL10:%.+]] = mhlo.constant dense<1>
+ %5 = mhlo.constant dense<1> : tensor<i32>
+ // CHECK: [[VAL11:%.+]] = mhlo.add [[VAL9]], [[VAL10]]
+ %6 = mhlo.add %barg2, %5 : tensor<i32>
+ // CHECK: [[VAL12:%.+]] = mhlo.add [[VAL7]], [[VAL10]]
+ %7 = mhlo.add %barg0, %5 : tensor<i32>
+ // CHECK: [[VAL13:%.+]] = "mhlo.tuple"([[VAL12]], [[VAL8]], [[VAL11]])
+ // CHECK: "mhlo.return"([[VAL13]])
+ "tf.Yield"(%7, %barg1, %6) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
+ }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
+ // CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+ // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
+ // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32}
+ // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32}
+ // CHECK: return [[VAL6]]
+ return %2#2 : tensor<i32>
+}
+
+
+// CHECK-LABEL: func @whileRegionImplicitInputs
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>)
+func @whileRegionImplicitInputs(%arg0: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
+ %0 = mhlo.constant dense<0> : tensor<i32>
+ // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
+ %1 = mhlo.constant dense<-1> : tensor<i32>
+ // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG0]], [[VAL0]], [[VAL1]])
+ // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
+ %2 = "tf.WhileRegion"(%arg0) ( {
+ // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
+ ^cond(%carg0: tensor<i32>):
+ // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
+ // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
+ // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32}
+ // CHECK: [[VAL8:%.+]] = "mhlo.compare"([[VAL5]], [[VAL6]]) {comparison_direction = "LT"}
+ %3 = "mhlo.compare"(%carg0, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ // CHECK: "mhlo.return"([[VAL8]])
+ "tf.Yield"(%3) : (tensor<i1>) -> ()
+ }, {
+ // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
+ ^body(%barg0: tensor<i32>):
+ // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
+ // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
+ // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
+ // CHECK: [[VAL8:%.+]] = mhlo.add [[VAL5]], [[VAL7]]
+ %3 = mhlo.add %barg0, %1 : tensor<i32>
+ // CHECK: [[VAL9:%.+]] = mhlo.add [[VAL5]], [[VAL8]]
+ %4 = mhlo.add %barg0, %3 : tensor<i32>
+ // CHECK: [[VAL10:%.+]] = "mhlo.tuple"([[VAL9]], [[VAL6]], [[VAL7]])
+ // CHECK: "mhlo.return"([[VAL10]])
+ "tf.Yield"(%4) : (tensor<i32>) -> ()
+ }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
+ // CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+ // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
+ // CHECK: return [[VAL4]]
+ return %2 : tensor<i32>
+}
+
+
+// CHECK-LABEL: func @whileRegionMultipleImplicitInputs
+func @whileRegionMultipleImplicitInputs() {
+ // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
+ %0 = mhlo.constant dense<0> : tensor<i32>
+ // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
+ %1 = mhlo.constant dense<-1> : tensor<i32>
+ // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]])
+ // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
+ "tf.WhileRegion"() ( {
+ // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>>):
+ // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
+ // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
+ // CHECK: [[VAL6:%.+]] = "mhlo.compare"([[VAL4]], [[VAL5]]) {comparison_direction = "LT"}
+ %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ // CHECK: "mhlo.return"([[VAL6]])
+ "tf.Yield"(%2) : (tensor<i1>) -> ()
+ }, {
+ // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>>):
+ // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
+ // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
+ // CHECK: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]]
+ %2 = mhlo.add %0, %1 : tensor<i32>
+ // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL4]], [[VAL5]])
+ // CHECK: "mhlo.return"([[VAL7]])
+ "tf.Yield"() : () -> ()
+ }) {is_stateless = true, parallel_iterations = 10 : i64} : () -> ()
+ // CHECK: }) : (tuple<tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>>
+ // CHECK: return
+ return
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
index 567a135..c02675f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
@@ -26,6 +26,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/iterator_range.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@@ -35,12 +36,14 @@
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
+#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -64,7 +67,7 @@
namespace {
-void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) {
+void Detuple(Value tuple, ValueRange replace, OpBuilder* builder) {
// De-tuple the results of the xla hlo if result.
for (auto result_it : llvm::enumerate(replace)) {
auto get_tuple_value = builder->create<mhlo::GetTupleElementOp>(
@@ -179,6 +182,122 @@
Detuple(while_op.getResult(), op.getResults(), &builder);
op.erase();
}
+
+// Replaces all block arguments of a block with a single block arg of Tuple
+// type `tuple_type`. Single block arguments are removed and remapped to
+// get_tuple_element(tuple_arg, index).
+void ReplaceBlockArgs(Block* block, Type tuple_type, OpBuilder* builder) {
+ auto tuple_arg = block->addArgument(tuple_type);
+ Detuple(tuple_arg, block->getArguments().drop_back(1), builder);
+ for (int i = block->getNumArguments() - 2; i >= 0; --i)
+ block->eraseArgument(i);
+}
+
+// Finds and replaces implicitly captured value uses with tuple block argument.
+// get_tuple_element's are created to extract specific values. Values from
+// get_tuple_element's are returned in the order of `implicit_inputs`.
+llvm::SmallVector<Value, 4> ReplaceImplicitInputs(
+ Block* block, int offset, ArrayRef<Value> implicit_inputs,
+ OpBuilder* builder) {
+ llvm::SmallVector<Value, 4> implicit_input_elements;
+ implicit_input_elements.reserve(implicit_inputs.size());
+
+ Region* region = block->getParent();
+ assert(block->getNumArguments() == 1);
+
+ BlockArgument tuple_arg = block->getArgument(0);
+ for (auto& implicit_input : llvm::enumerate(implicit_inputs)) {
+ Value implicit_input_value = implicit_input.value();
+ auto get_tuple_element = builder->create<mhlo::GetTupleElementOp>(
+ implicit_input_value.getLoc(), tuple_arg,
+ implicit_input.index() + offset);
+ implicit_input_elements.emplace_back(get_tuple_element.getResult());
+ for (auto& use :
+ llvm::make_early_inc_range(implicit_input_value.getUses())) {
+ if (!region->isAncestor(use.getOwner()->getParentRegion())) continue;
+ use.set(get_tuple_element.getResult());
+ }
+ }
+
+ return implicit_input_elements;
+}
+
+// Replaces block terminator (tf.Yield) with `mhlo.return`. Additional results
+// can be returned if `extra_results` is not empty. If `tuple_return` is
+// set, a tuple of the return values will be set as the terminator operand.
+void ReplaceTerminator(Block* block, ArrayRef<Value> extra_results,
+ OpBuilder* builder, bool tuple_return = true) {
+ Operation* terminator = block->getTerminator();
+ assert(isa<TF::YieldOp>(terminator));
+ Location loc = terminator->getLoc();
+
+ builder->setInsertionPoint(terminator);
+ auto results = llvm::to_vector<4>(terminator->getOperands());
+ results.append(extra_results.begin(), extra_results.end());
+ if (tuple_return) {
+ auto tuple_results = builder->create<mhlo::TupleOp>(loc, results);
+ builder->create<mhlo::ReturnOp>(loc, tuple_results.getResult());
+ } else {
+ builder->create<mhlo::ReturnOp>(loc, results);
+ }
+
+ terminator->erase();
+}
+
+void LowerWhileRegion(TF::WhileRegionOp op) {
+ Location loc = op.getLoc();
+ OpBuilder builder(op);
+
+ // XLA prefers tuple arguments for control flow due to XLA not supporting
+ // multiple return values.
+ SmallVector<Value, 3> inputs(op.input());
+ const int inputs_size = inputs.size();
+ llvm::SetVector<Value> implicit_inputs;
+ getUsedValuesDefinedAbove(op.getOperation()->getRegions(), implicit_inputs);
+ inputs.append(implicit_inputs.begin(), implicit_inputs.end());
+
+ builder.setInsertionPoint(op);
+ Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
+
+ // Create the new while op with tuple inputs. Implicit inputs are also
+ // returned.
+ auto while_result_types = llvm::to_vector<4>(op.getResultTypes());
+ while_result_types.reserve(while_result_types.size() +
+ implicit_inputs.size());
+ for (const auto& implicit_input : implicit_inputs)
+ while_result_types.emplace_back(implicit_input.getType());
+ auto while_op = builder.create<mhlo::WhileOp>(
+ loc, builder.getTupleType(while_result_types), tuple_input);
+
+ // Rewrite cond and associated block arguments and terminator.
+ Region& cond = while_op.cond();
+ cond.takeBody(op.cond());
+ Block& cond_block = cond.front();
+ builder.setInsertionPointToStart(&cond_block);
+ ReplaceBlockArgs(&cond_block, tuple_input.getType(), &builder);
+ ReplaceImplicitInputs(&cond_block, inputs_size, implicit_inputs.getArrayRef(),
+ &builder);
+ // Cond always returns a single result of bool type.
+ ReplaceTerminator(&cond_block, /*extra_results=*/{}, &builder,
+ /*tuple_return=*/false);
+
+ // Rewrite body and associated block arguments and terminator.
+ Region& body = while_op.body();
+ body.takeBody(op.body());
+ Block& body_block = body.front();
+ builder.setInsertionPointToStart(&body_block);
+ ReplaceBlockArgs(&body_block, tuple_input.getType(), &builder);
+ // Capture implicit inputs that were added as a tuple block arguments. These
+ // are to be returned by the body in addition to explicit inputs.
+ auto implicit_input_elements = ReplaceImplicitInputs(
+ &body_block, inputs_size, implicit_inputs.getArrayRef(), &builder);
+ ReplaceTerminator(&body_block, implicit_input_elements, &builder);
+
+ // De-tuple the results of the xla hlo while.
+ builder.setInsertionPoint(op);
+ Detuple(while_op.getResult(), op.getResults(), &builder);
+ op.erase();
+}
} // namespace
void LegalizeTFControlFlow::runOnOperation() {
@@ -189,6 +308,10 @@
LowerWhile(while_op, module);
return;
}
+ if (auto while_region_op = dyn_cast<TF::WhileRegionOp>(op)) {
+ LowerWhileRegion(while_region_op);
+ return;
+ }
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
LowerIf(if_op, module);
return;