[mhlo:scf] Use tensors from the inner loops when swapping

Read from and insert into the scf.for iter_args rather than the tensors defined
outside the loop. Otherwise the second swap will not see the effects of the
first swap, which is what makes bubblesort do the bubbling.

PiperOrigin-RevId: 450909872
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
index f1329ae..f19f95d 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc
@@ -251,12 +251,12 @@
     auto& src_block = op.comparator().front();
     auto scf_for = b.create<scf::ForOp>(
         loc, zero, ub, one, args,
-        [&](OpBuilder& b, Location loc, Value iv, ValueRange) {
+        [&](OpBuilder& b, Location loc, Value iv, ValueRange args) {
           // Extract and create tensors with relevant values to merge with the
           // expected inputs to the original compare region of the mhlo.sort op.
           SmallVector<Value> indices(ivs);
           Value ivPlusOne = b.create<arith::AddIOp>(loc, iv, one);
-          for (const auto& idx_and_output : llvm::enumerate(results)) {
+          for (const auto& idx_and_output : llvm::enumerate(args)) {
             indices[sort_dim] = iv;
             sort_args.push_back(b.create<tensor::FromElementsOp>(
                 loc, src_block.getArgumentTypes()[2 * idx_and_output.index()],
@@ -314,12 +314,14 @@
           Value ivPlusOne =
               b.create<arith::AddIOp>(loc, scf_for.getInductionVar(), one);
           SmallVector<Value> swapped_results;
-          for (int i = 0, e = results.size(); i < e; ++i) {
-            Value v1 = sort_args[i * 2];
-            Value v2 = sort_args[i * 2 + 1];
+          for (const auto& idx_and_output :
+               llvm::enumerate(scf_for.getRegionIterArgs())) {
+            Value v1 = sort_args[idx_and_output.index() * 2];
+            Value v2 = sort_args[idx_and_output.index() * 2 + 1];
             indices[sort_dim] = scf_for.getInductionVar();
             Value after_first_insert = b.create<tensor::InsertOp>(
-                loc, b.create<tensor::ExtractOp>(loc, v2), results[i], indices);
+                loc, b.create<tensor::ExtractOp>(loc, v2),
+                idx_and_output.value(), indices);
             indices[sort_dim] = ivPlusOne;
             swapped_results.push_back(b.create<tensor::InsertOp>(
                 loc, b.create<tensor::ExtractOp>(loc, v1), after_first_insert,
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir
index 39e00a5..7a8e7c0 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/legalize-control-flow.mlir
@@ -275,13 +275,13 @@
 // CHECK:             %[[VAL_14:.*]]:2 = scf.for %[[VAL_15:.*]] = %[[C0_2]] to %[[VAL_13]] step %[[C1_1]] iter_args(%[[VAL_16:.*]] = %[[VAL_8]], %[[VAL_17:.*]] = %[[VAL_9]]) -> (tensor<2xi32>, tensor<2xi32>) {
 // CHECK:               %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[C1_1]] : index
 // Extract each value twice because we are comparing both directions and haven't run CSE yet
-// CHECK:               %[[VAL_19:.*]] = tensor.extract %[[VAL_8]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
+// CHECK:               %[[VAL_19:.*]] = tensor.extract %[[VAL_16]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
 // CHECK:               %[[VAL_20:.*]] = tensor.from_elements %[[VAL_19]] : tensor<i32>
-// CHECK:               %[[VAL_21:.*]] = tensor.extract %[[VAL_8]]{{\[}}%[[VAL_18]]] : tensor<2xi32>
+// CHECK:               %[[VAL_21:.*]] = tensor.extract %[[VAL_16]]{{\[}}%[[VAL_18]]] : tensor<2xi32>
 // CHECK:               %[[VAL_22:.*]] = tensor.from_elements %[[VAL_21]] : tensor<i32>
-// CHECK:               %[[VAL_23:.*]] = tensor.extract %[[VAL_9]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
+// CHECK:               %[[VAL_23:.*]] = tensor.extract %[[VAL_17]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
 // CHECK:               %[[VAL_24:.*]] = tensor.from_elements %[[VAL_23]] : tensor<i32>
-// CHECK:               %[[VAL_25:.*]] = tensor.extract %[[VAL_9]]{{\[}}%[[VAL_18]]] : tensor<2xi32>
+// CHECK:               %[[VAL_25:.*]] = tensor.extract %[[VAL_17]]{{\[}}%[[VAL_18]]] : tensor<2xi32>
 // CHECK:               %[[VAL_26:.*]] = tensor.from_elements %[[VAL_25]] : tensor<i32>
 // CHECK:               %[[VAL_27:.*]] = tensor.extract %[[VAL_22]][] : tensor<i32>
 // CHECK:               %[[VAL_28:.*]] = tensor.extract %[[VAL_20]][] : tensor<i32>
@@ -302,12 +302,12 @@
 // CHECK:                 %[[VAL_41:.*]] = arith.addi %[[VAL_15]], %[[C1_1]] : index
 // Swap first pair of values
 // CHECK:                 %[[VAL_42:.*]] = tensor.extract %[[VAL_22]][] : tensor<i32>
-// CHECK:                 %[[VAL_43:.*]] = tensor.insert %[[VAL_42]] into %[[VAL_8]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
+// CHECK:                 %[[VAL_43:.*]] = tensor.insert %[[VAL_42]] into %[[VAL_16]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
 // CHECK:                 %[[VAL_44:.*]] = tensor.extract %[[VAL_20]][] : tensor<i32>
 // CHECK:                 %[[VAL_45:.*]] = tensor.insert %[[VAL_44]] into %[[VAL_43]]{{\[}}%[[VAL_41]]] : tensor<2xi32>
 // Swap second pair of values
 // CHECK:                 %[[VAL_46:.*]] = tensor.extract %[[VAL_26]][] : tensor<i32>
-// CHECK:                 %[[VAL_47:.*]] = tensor.insert %[[VAL_46]] into %[[VAL_9]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
+// CHECK:                 %[[VAL_47:.*]] = tensor.insert %[[VAL_46]] into %[[VAL_17]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
 // CHECK:                 %[[VAL_48:.*]] = tensor.extract %[[VAL_24]][] : tensor<i32>
 // CHECK:                 %[[VAL_49:.*]] = tensor.insert %[[VAL_48]] into %[[VAL_47]]{{\[}}%[[VAL_41]]] : tensor<2xi32>
 // CHECK:                 scf.yield %[[VAL_45]], %[[VAL_49]] : tensor<2xi32>, tensor<2xi32>
@@ -353,13 +353,13 @@
 // CHECK:             %[[VAL_14:.*]]:2 = scf.for %[[VAL_15:.*]] = %[[C0_2]] to %[[VAL_13]] step %[[C1_1]] iter_args(%[[VAL_16:.*]] = %[[VAL_8]], %[[VAL_17:.*]] = %[[VAL_9]]) -> (tensor<?xi32>, tensor<?xi32>) {
 // CHECK:               %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[C1_1]] : index
 // Extract each value twice because we are comparing both directions and haven't run CSE yet
-// CHECK:               %[[VAL_19:.*]] = tensor.extract %[[VAL_8]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
+// CHECK:               %[[VAL_19:.*]] = tensor.extract %[[VAL_16]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
 // CHECK:               %[[VAL_20:.*]] = tensor.from_elements %[[VAL_19]] : tensor<i32>
-// CHECK:               %[[VAL_21:.*]] = tensor.extract %[[VAL_8]]{{\[}}%[[VAL_18]]] : tensor<?xi32>
+// CHECK:               %[[VAL_21:.*]] = tensor.extract %[[VAL_16]]{{\[}}%[[VAL_18]]] : tensor<?xi32>
 // CHECK:               %[[VAL_22:.*]] = tensor.from_elements %[[VAL_21]] : tensor<i32>
-// CHECK:               %[[VAL_23:.*]] = tensor.extract %[[VAL_9]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
+// CHECK:               %[[VAL_23:.*]] = tensor.extract %[[VAL_17]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
 // CHECK:               %[[VAL_24:.*]] = tensor.from_elements %[[VAL_23]] : tensor<i32>
-// CHECK:               %[[VAL_25:.*]] = tensor.extract %[[VAL_9]]{{\[}}%[[VAL_18]]] : tensor<?xi32>
+// CHECK:               %[[VAL_25:.*]] = tensor.extract %[[VAL_17]]{{\[}}%[[VAL_18]]] : tensor<?xi32>
 // CHECK:               %[[VAL_26:.*]] = tensor.from_elements %[[VAL_25]] : tensor<i32>
 // CHECK:               %[[VAL_27:.*]] = tensor.extract %[[VAL_22]][] : tensor<i32>
 // CHECK:               %[[VAL_28:.*]] = tensor.extract %[[VAL_20]][] : tensor<i32>
@@ -380,12 +380,12 @@
 // CHECK:                 %[[VAL_41:.*]] = arith.addi %[[VAL_15]], %[[C1_1]] : index
 // Swap first pair of values
 // CHECK:                 %[[VAL_42:.*]] = tensor.extract %[[VAL_22]][] : tensor<i32>
-// CHECK:                 %[[VAL_43:.*]] = tensor.insert %[[VAL_42]] into %[[VAL_8]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
+// CHECK:                 %[[VAL_43:.*]] = tensor.insert %[[VAL_42]] into %[[VAL_16]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
 // CHECK:                 %[[VAL_44:.*]] = tensor.extract %[[VAL_20]][] : tensor<i32>
 // CHECK:                 %[[VAL_45:.*]] = tensor.insert %[[VAL_44]] into %[[VAL_43]]{{\[}}%[[VAL_41]]] : tensor<?xi32>
 // Swap second pair of values
 // CHECK:                 %[[VAL_46:.*]] = tensor.extract %[[VAL_26]][] : tensor<i32>
-// CHECK:                 %[[VAL_47:.*]] = tensor.insert %[[VAL_46]] into %[[VAL_9]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
+// CHECK:                 %[[VAL_47:.*]] = tensor.insert %[[VAL_46]] into %[[VAL_17]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
 // CHECK:                 %[[VAL_48:.*]] = tensor.extract %[[VAL_24]][] : tensor<i32>
 // CHECK:                 %[[VAL_49:.*]] = tensor.insert %[[VAL_48]] into %[[VAL_47]]{{\[}}%[[VAL_41]]] : tensor<?xi32>
 // CHECK:                 scf.yield %[[VAL_45]], %[[VAL_49]] : tensor<?xi32>, tensor<?xi32>