Roll-forward with fix:

[XLA/GPU] Migrate fused Slice to take LMHLO.

PiperOrigin-RevId: 355491691
Change-Id: I45147b974f4b5e1cbf886945c1078a98e088beba
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 0e3bd29..7adad73 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -369,29 +369,29 @@
   return reduction_dimensions.dimensions[1] >= kWarpSize;
 }
 
-bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
+bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
                           bool verify_no_strides) {
-  if (!unnested_hlo.IsInputFusion()) {
+  auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
+  if (!fusion) {
     return false;
   }
 
-  auto is_non_strided = [](const std::vector<int64>& strides) -> bool {
-    return absl::c_all_of(strides, [](int stride) { return stride == 1; });
+  auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
+    return absl::c_all_of(
+        strides, [](const llvm::APInt& stride) { return stride == 1; });
   };
 
-  const HloInstruction* root = unnested_hlo.fused_expression_root();
-  if (root->opcode() == HloOpcode::kSlice) {
-    return !verify_no_strides || is_non_strided(root->slice_strides());
+  for (mlir::Value value : fusion.getFusionResults()) {
+    auto slice =
+        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
+    if (!slice) {
+      return false;
+    }
+    if (verify_no_strides && !is_non_strided(slice.strides())) {
+      return false;
+    }
   }
-
-  if (root->opcode() != HloOpcode::kTuple) {
-    return false;
-  }
-
-  return absl::c_all_of(root->operands(), [&](const HloInstruction* instr) {
-    return instr->opcode() == HloOpcode::kSlice &&
-           (!verify_no_strides || is_non_strided(instr->slice_strides()));
-  });
+  return true;
 }
 
 ReductionDimensions GetReductionKindAndContiguousComponents(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index bda4887..a7b4432 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -168,8 +168,8 @@
 // Returns whether unnested_hlo is an input fusion whose root is either a slice
 // or a tuple of slices. If verify_no_strides is true, returns false unless all
 // ROOT slices have no strides.
-bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
-                          bool verify_no_strides = false);
+bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
+                          bool verify_no_strides);
 
 struct ReductionDimensions {
   // Indicates whether the reduction is a row reduction or a column reduction.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 82c8db1..6859768 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -490,34 +490,36 @@
 // slices are the same and the slices are non-strided. Otherwise, returns
 // FailedPrecondition.
 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
-    const HloInstruction& fusion) {
+    mlir::lmhlo::FusionOp fusion) {
   if (!IsInputFusibleSlices(fusion, /*verify_no_strides=*/true)) {
     return FailedPrecondition(
         "Unsupported root for slice input fusion. "
         "Only non-strided slices are supported.");
   }
 
-  const HloInstruction& root = *fusion.fused_expression_root();
-  if (root.opcode() == HloOpcode::kSlice) {
-    return root.operands()[0]->shape();
-  }
-
-  CHECK_EQ(root.opcode(), HloOpcode::kTuple);
-  const Shape& first_slice_operand_shape =
-      root.operands()[0]->operands()[0]->shape();
-  for (size_t i = 1; i < root.operands().size(); ++i) {
-    const HloInstruction* slice = root.operands()[i];
-    const Shape& operand_shape = slice->operands()[0]->shape();
-    if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
-                                             operand_shape)) {
-      return FailedPrecondition(
-          "Fused slices do not have the same input shape, fused computation = "
-          "%s.",
-          root.parent()->name());
+  absl::optional<Shape> first_slice_operand_shape;
+  for (mlir::Value result : fusion.getFusionResults()) {
+    auto slice =
+        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(result.getDefiningOp());
+    if (!slice) {
+      return FailedPrecondition("Expected a slice op");
+    }
+    if (first_slice_operand_shape.has_value()) {
+      Shape operand_shape = TypeToShape(slice.operand().getType());
+      if (!ShapeUtil::EqualIgnoringElementType(*first_slice_operand_shape,
+                                               operand_shape)) {
+        return FailedPrecondition(
+            "Fused slices do not have the same input shape, instruction is %s",
+            MlirToString(fusion));
+      }
+    } else {
+      first_slice_operand_shape = TypeToShape(slice.operand().getType());
     }
   }
-
-  return first_slice_operand_shape;
+  if (!first_slice_operand_shape.has_value()) {
+    return InvalidArgument("Fusion has no roots");
+  }
+  return *first_slice_operand_shape;
 }
 
 }  // namespace
@@ -1842,8 +1844,8 @@
       // In the case of root tuple, it can be either reduce or slice input
       // fusion.
       case HloOpcode::kTuple: {
-        if (IsInputFusibleSlices(*fusion)) {
-          return EmitInputFusibleNonStridedSlices(fusion);
+        if (IsInputFusibleSlices(mlir_input.op, /*verify_no_strides=*/false)) {
+          return EmitInputFusibleNonStridedSlices(mlir_input);
         }
 
         CHECK_GE(mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op)
@@ -1864,7 +1866,7 @@
         return EmitReductionFromOrToContiguousDimensions(mlir_input);
       }
       case HloOpcode::kSlice: {
-        return EmitInputFusibleNonStridedSlices(fusion);
+        return EmitInputFusibleNonStridedSlices(mlir_input);
       }
       default:
         LOG(FATAL) << "Bad opcode for input fusion: "
@@ -5636,11 +5638,16 @@
 //   Write to output of slice1
 // }
 //
-void IrEmitterUnnested::EmitElementForInputFusibleSlices(
-    HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index) {
-  VLOG(10) << "Emitting slice input fusion for " << unnested_hlo->ToString();
+Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
+    mlir::lmhlo::FusionOp fusion, absl::Span<const llvm_ir::IrArray> ir_arrays,
+    const llvm_ir::IrArray::Index& index) {
+  VLOG(10) << "Emitting slice input fusion for " << MlirToString(fusion);
 
-  HloInstruction* slice_or_tuple = unnested_hlo->fused_expression_root();
+  TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
+                      GetOrCreateSubComputationFromRegion(&fusion.region(),
+                                                          /*is_fusion=*/true));
+
+  HloInstruction* slice_or_tuple = fused_computation->root_instruction();
   auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
     if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
       return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
@@ -5654,7 +5661,13 @@
   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
                                      GetNestedComputer());
   FusedIrEmitter fused_emitter(&elem_emitter);
-  BindFusionArguments(unnested_hlo, &fused_emitter);
+  for (int i = 0; i < fused_computation->num_parameters(); i++) {
+    fused_emitter.BindGenerator(
+        fused_computation->parameter_instruction(i),
+        [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
+          return ir_arrays[i].EmitReadArrayElement(index, &b_);
+        });
+  }
   for (const HloInstruction* slice : slice_instructions) {
     auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
     input_ir_values.push_back(input_generator(index).ValueOrDie());
@@ -5689,11 +5702,8 @@
             Sub(src_multidim[dim],
                 index.GetConstantWithIndexType(slice->slice_starts(dim)));
       }
-      ShapeIndex shape_index = (slice_or_tuple->opcode() == HloOpcode::kSlice)
-                                   ? ShapeIndex()
-                                   : ShapeIndex({i});
       llvm_ir::IrArray src_ir_array =
-          GetIrArray(*unnested_hlo, *unnested_hlo, shape_index);
+          ir_arrays[fused_computation->num_parameters() + i];
       IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
                                      index.GetType());
       src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
@@ -5702,16 +5712,23 @@
 
     ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
   }
+  return Status::OK();
 }
 
 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
-    HloInstruction* unnested_hlo) {
+    MlirEmitterInput mlir_input) {
+  auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);
+
   constexpr int unroll_factor = 1;
-  std::unique_ptr<KernelThunk> kernel_thunk =
-      BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/true);
+
+  std::vector<llvm_ir::IrArray> ir_arrays;
+  TF_ASSIGN_OR_RETURN(
+      auto kernel_thunk,
+      BuildKernelThunkForMlir(fusion, mlir_input.thunk_info,
+                              mlir_input.extra_slice, &ir_arrays));
 
   TF_ASSIGN_OR_RETURN(Shape element_shape,
-                      GetConsistentInputShapeForRootSlices(*unnested_hlo));
+                      GetConsistentInputShapeForRootSlices(fusion));
   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
       element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
@@ -5720,13 +5737,12 @@
   Status emit_status =
       ParallelLoopEmitter(
           [&](const llvm_ir::IrArray::Index index) -> Status {
-            EmitElementForInputFusibleSlices(unnested_hlo, index);
-            return Status::OK();
+            return EmitElementForInputFusibleSlices(fusion, ir_arrays, index);
           },
           element_shape, launch_dimensions, &b_)
-          .EmitLoop(IrName(unnested_hlo),
-                    GetIndexTypeForKernel(
-                        unnested_hlo, launch_dimensions.launch_bound(), &b_));
+          .EmitLoop(IrName(mlir::GetNameFromLoc(fusion.getLoc())),
+                    GetIndexTypeForKernelFromMlir(
+                        fusion, launch_dimensions.launch_bound(), &b_));
 
   thunk_sequence_.emplace_back(std::move(kernel_thunk));
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 0480736..6cfed8d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -446,11 +446,12 @@
   // different. On the other hand, the input ranges of slices can be
   // overlapping. Further generalization/specialization when the needs are seen
   // in the future.
-  Status EmitInputFusibleNonStridedSlices(HloInstruction* unnested_hlo);
+  Status EmitInputFusibleNonStridedSlices(MlirEmitterInput mlir_input);
 
-  void EmitElementForInputFusibleSlices(
-      HloInstruction* unnested_hlo,
-      const llvm_ir::IrArray::Index& slice_input_index);
+  Status EmitElementForInputFusibleSlices(
+      mlir::lmhlo::FusionOp fusion,
+      absl::Span<const llvm_ir::IrArray> ir_arrays,
+      const llvm_ir::IrArray::Index& index);
 
   // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
   // the process. Scatter indices are taken from `scatter_indices_gen`, updates
diff --git a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo
index aeb9176..5964764 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo
@@ -2,21 +2,21 @@
 
 // CHECK-LABEL: entry:
 // CHECK:         %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
-// CHECK:         %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x i8*]*
+// CHECK:         %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [1024 x half]*
 // CHECK:         %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
 // CHECK:         %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1024 x half]*
 // CHECK:         %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
 // CHECK:         %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [1023 x half]*
 // CHECK:         %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
-// CHECK:         %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [0 x half]*
+// CHECK:         %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [1023 x half]*
 // CHECK:         %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0
 // CHECK:         %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [1024 x half]*
 // CHECK:         %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0
-// CHECK:         %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1024 x half]*
+// CHECK:         %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1023 x half]*
 // CHECK:         %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0
-// CHECK:         %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [1023 x half]*
+// CHECK:         %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [0 x half]*
 // CHECK:         %[[VAL_21:.*]] = getelementptr inbounds i8, i8* %[[VAL_22:.*]], i64 0
-// CHECK:         %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [1023 x half]*
+// CHECK:         %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [3 x i8*]*
 // CHECK:         %[[VAL_24:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
 // CHECK:         %[[VAL_25:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
 // CHECK:         %[[VAL_26:.*]] = mul nuw nsw i32 %[[VAL_24]], 1024
@@ -34,18 +34,18 @@
 // CHECK:       concat_index_from_operand_id0:                    ; preds = %[[VAL_31]]
 // CHECK:         %[[VAL_38:.*]] = phi i32 [ 0, %[[VAL_31]] ]
 // CHECK:         %[[VAL_39:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_38]]
-// CHECK:         %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_39]]
+// CHECK:         %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_2]], i32 0, i32 %[[VAL_39]]
 // CHECK:         %[[VAL_41:.*]] = load half, half* %[[VAL_40]], align 2, !invariant.load !4
-// CHECK:         %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_39]]
+// CHECK:         %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_39]]
 // CHECK:         %[[VAL_43:.*]] = load half, half* %[[VAL_42]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_44:.*]] = fmul half %[[VAL_41]], %[[VAL_43]]
 // CHECK:         br label %[[VAL_45:.*]]
 // CHECK:       concat_index_from_operand_id1:                    ; preds = %[[VAL_37]]
 // CHECK:         %[[VAL_46:.*]] = phi i32 [ 1024, %[[VAL_37]] ]
 // CHECK:         %[[VAL_47:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_46]]
-// CHECK:         %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_47]]
+// CHECK:         %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_47]]
 // CHECK:         %[[VAL_49:.*]] = load half, half* %[[VAL_48]], align 2, !invariant.load !4
-// CHECK:         %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_47]]
+// CHECK:         %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_47]]
 // CHECK:         %[[VAL_51:.*]] = load half, half* %[[VAL_50]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_52:.*]] = fadd half %[[VAL_49]], %[[VAL_51]]
 // CHECK:         br label %[[VAL_45]]
@@ -54,7 +54,7 @@
 // CHECK:         br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]]
 // CHECK:       concat_index_not_from_operand1:                   ; preds = %[[VAL_37]]
 // CHECK:         unreachable
-// CHECK:       concat.1.merge:                                   ; preds = %[[VAL_54]], %[[VAL_36]]
+// CHECK:       concatenate.7.merge:                              ; preds = %[[VAL_54]], %[[VAL_36]]
 // CHECK:         %[[VAL_56:.*]] = phi half [ %[[VAL_44]], %[[VAL_36]] ], [ %[[VAL_52]], %[[VAL_54]] ]
 // CHECK:         %[[VAL_57:.*]] = icmp sge i32 %[[VAL_29]], 0
 // CHECK:         %[[VAL_58:.*]] = icmp slt i32 %[[VAL_29]], 1024
@@ -74,17 +74,17 @@
 // CHECK:         br label %[[VAL_32]]
 // CHECK:       slice0-true:                                      ; preds = %[[VAL_45]]
 // CHECK:         %[[VAL_71:.*]] = sub i32 %[[VAL_29]], 0
-// CHECK:         %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_71]]
+// CHECK:         %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_71]]
 // CHECK:         store half %[[VAL_56]], half* %[[VAL_72]], align 2
 // CHECK:         br label %[[VAL_61]]
 // CHECK:       slice1-true:                                      ; preds = %[[VAL_61]]
 // CHECK:         %[[VAL_73:.*]] = sub i32 %[[VAL_29]], 1024
-// CHECK:         %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_73]]
+// CHECK:         %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_73]]
 // CHECK:         store half %[[VAL_56]], half* %[[VAL_74]], align 2
 // CHECK:         br label %[[VAL_66]]
 // CHECK:       slice2-true:                                      ; preds = %[[VAL_66]]
 // CHECK:         %[[VAL_75:.*]] = sub i32 %[[VAL_29]], 2047
-// CHECK:         %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_75]]
+// CHECK:         %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_75]]
 // CHECK:         store half %[[VAL_56]], half* %[[VAL_76]], align 2
 // CHECK:         br label %[[VAL_33]]