Roll-forward with fix:
[XLA/GPU] Migrate fused Slice to take LMHLO.
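
For context, the emitter path touched here handles input fusions whose root is
a slice or a tuple of slices. A minimal illustrative HLO module of that shape
(hypothetical; names and shapes are made up and not taken from the updated
test) looks like:

  // Illustrative only: an input fusion whose root is a tuple of
  // non-strided slices over a shared intermediate value.
  fused_computation {
    p0 = f16[1024] parameter(0)
    p1 = f16[1024] parameter(1)
    mul = f16[1024] multiply(p0, p1)
    slice0 = f16[512] slice(mul), slice={[0:512]}
    slice1 = f16[512] slice(mul), slice={[512:1024]}
    ROOT out = (f16[512], f16[512]) tuple(slice0, slice1)
  }

  ENTRY main {
    a = f16[1024] parameter(0)
    b = f16[1024] parameter(1)
    ROOT wrapped_fusion = (f16[512], f16[512]) fusion(a, b), kind=kInput,
        calls=fused_computation
  }

With this change, IsInputFusibleSlices and the slice emitters consume the
lmhlo.fusion op (and the mhlo.slice ops inside its region) instead of the
HloInstruction fusion root.
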
PiperOrigin-RevId: 355491691
Change-Id: I45147b974f4b5e1cbf886945c1078a98e088beba
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 0e3bd29..7adad73 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -369,29 +369,29 @@
return reduction_dimensions.dimensions[1] >= kWarpSize;
}
-bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
+bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
bool verify_no_strides) {
- if (!unnested_hlo.IsInputFusion()) {
+ auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
+ if (!fusion) {
return false;
}
- auto is_non_strided = [](const std::vector<int64>& strides) -> bool {
- return absl::c_all_of(strides, [](int stride) { return stride == 1; });
+ auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
+ return absl::c_all_of(
+ strides, [](const llvm::APInt& stride) { return stride == 1; });
};
- const HloInstruction* root = unnested_hlo.fused_expression_root();
- if (root->opcode() == HloOpcode::kSlice) {
- return !verify_no_strides || is_non_strided(root->slice_strides());
+ for (mlir::Value value : fusion.getFusionResults()) {
+ auto slice =
+ mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
+ if (!slice) {
+ return false;
+ }
+ if (verify_no_strides && !is_non_strided(slice.strides())) {
+ return false;
+ }
}
-
- if (root->opcode() != HloOpcode::kTuple) {
- return false;
- }
-
- return absl::c_all_of(root->operands(), [&](const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kSlice &&
- (!verify_no_strides || is_non_strided(instr->slice_strides()));
- });
+ return true;
}
ReductionDimensions GetReductionKindAndContiguousComponents(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index bda4887..a7b4432 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -168,8 +168,8 @@
// Returns whether unnested_hlo is an input fusion whose root is either a slice
// or a tuple of slices. If verify_no_strides is true, returns false unless all
// ROOT slices have no strides.
-bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
- bool verify_no_strides = false);
+bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
+ bool verify_no_strides);
struct ReductionDimensions {
// Indicates whether the reduction is a row reduction or a column reduction.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 82c8db1..6859768 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -490,34 +490,36 @@
// slices are the same and the slices are non-strided. Otherwise, returns
// FailedPrecondition.
StatusOr<Shape> GetConsistentInputShapeForRootSlices(
- const HloInstruction& fusion) {
+ mlir::lmhlo::FusionOp fusion) {
if (!IsInputFusibleSlices(fusion, /*verify_no_strides=*/true)) {
return FailedPrecondition(
"Unsupported root for slice input fusion. "
"Only non-strided slices are supported.");
}
- const HloInstruction& root = *fusion.fused_expression_root();
- if (root.opcode() == HloOpcode::kSlice) {
- return root.operands()[0]->shape();
- }
-
- CHECK_EQ(root.opcode(), HloOpcode::kTuple);
- const Shape& first_slice_operand_shape =
- root.operands()[0]->operands()[0]->shape();
- for (size_t i = 1; i < root.operands().size(); ++i) {
- const HloInstruction* slice = root.operands()[i];
- const Shape& operand_shape = slice->operands()[0]->shape();
- if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
- operand_shape)) {
- return FailedPrecondition(
- "Fused slices do not have the same input shape, fused computation = "
- "%s.",
- root.parent()->name());
+ absl::optional<Shape> first_slice_operand_shape;
+ for (mlir::Value result : fusion.getFusionResults()) {
+ auto slice =
+ mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(result.getDefiningOp());
+ if (!slice) {
+ return FailedPrecondition("Expected a slice op");
+ }
+ if (first_slice_operand_shape.has_value()) {
+ Shape operand_shape = TypeToShape(slice.operand().getType());
+ if (!ShapeUtil::EqualIgnoringElementType(*first_slice_operand_shape,
+ operand_shape)) {
+ return FailedPrecondition(
+ "Fused slices do not have the same input shape, instruction is %s",
+ MlirToString(fusion));
+ }
+ } else {
+ first_slice_operand_shape = TypeToShape(slice.operand().getType());
}
}
-
- return first_slice_operand_shape;
+ if (!first_slice_operand_shape.has_value()) {
+ return InvalidArgument("Fusion has no roots");
+ }
+ return *first_slice_operand_shape;
}
} // namespace
@@ -1842,8 +1844,8 @@
// In the case of root tuple, it can be either reduce or slice input
// fusion.
case HloOpcode::kTuple: {
- if (IsInputFusibleSlices(*fusion)) {
- return EmitInputFusibleNonStridedSlices(fusion);
+ if (IsInputFusibleSlices(mlir_input.op, /*verify_no_strides=*/false)) {
+ return EmitInputFusibleNonStridedSlices(mlir_input);
}
CHECK_GE(mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op)
@@ -1864,7 +1866,7 @@
return EmitReductionFromOrToContiguousDimensions(mlir_input);
}
case HloOpcode::kSlice: {
- return EmitInputFusibleNonStridedSlices(fusion);
+ return EmitInputFusibleNonStridedSlices(mlir_input);
}
default:
LOG(FATAL) << "Bad opcode for input fusion: "
@@ -5636,11 +5638,16 @@
// Write to output of slice1
// }
//
-void IrEmitterUnnested::EmitElementForInputFusibleSlices(
- HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index) {
- VLOG(10) << "Emitting slice input fusion for " << unnested_hlo->ToString();
+Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
+ mlir::lmhlo::FusionOp fusion, absl::Span<const llvm_ir::IrArray> ir_arrays,
+ const llvm_ir::IrArray::Index& index) {
+ VLOG(10) << "Emitting slice input fusion for " << MlirToString(fusion);
- HloInstruction* slice_or_tuple = unnested_hlo->fused_expression_root();
+ TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
+ GetOrCreateSubComputationFromRegion(&fusion.region(),
+ /*is_fusion=*/true));
+
+ HloInstruction* slice_or_tuple = fused_computation->root_instruction();
auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
@@ -5654,7 +5661,13 @@
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
FusedIrEmitter fused_emitter(&elem_emitter);
- BindFusionArguments(unnested_hlo, &fused_emitter);
+ for (int i = 0; i < fused_computation->num_parameters(); i++) {
+ fused_emitter.BindGenerator(
+ fused_computation->parameter_instruction(i),
+ [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
+ return ir_arrays[i].EmitReadArrayElement(index, &b_);
+ });
+ }
for (const HloInstruction* slice : slice_instructions) {
auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
input_ir_values.push_back(input_generator(index).ValueOrDie());
@@ -5689,11 +5702,8 @@
Sub(src_multidim[dim],
index.GetConstantWithIndexType(slice->slice_starts(dim)));
}
- ShapeIndex shape_index = (slice_or_tuple->opcode() == HloOpcode::kSlice)
- ? ShapeIndex()
- : ShapeIndex({i});
llvm_ir::IrArray src_ir_array =
- GetIrArray(*unnested_hlo, *unnested_hlo, shape_index);
+ ir_arrays[fused_computation->num_parameters() + i];
IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
index.GetType());
src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
@@ -5702,16 +5712,23 @@
ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
}
+ return Status::OK();
}
Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
- HloInstruction* unnested_hlo) {
+ MlirEmitterInput mlir_input) {
+ auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);
+
constexpr int unroll_factor = 1;
- std::unique_ptr<KernelThunk> kernel_thunk =
- BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/true);
+
+ std::vector<llvm_ir::IrArray> ir_arrays;
+ TF_ASSIGN_OR_RETURN(
+ auto kernel_thunk,
+ BuildKernelThunkForMlir(fusion, mlir_input.thunk_info,
+ mlir_input.extra_slice, &ir_arrays));
TF_ASSIGN_OR_RETURN(Shape element_shape,
- GetConsistentInputShapeForRootSlices(*unnested_hlo));
+ GetConsistentInputShapeForRootSlices(fusion));
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
@@ -5720,13 +5737,12 @@
Status emit_status =
ParallelLoopEmitter(
[&](const llvm_ir::IrArray::Index index) -> Status {
- EmitElementForInputFusibleSlices(unnested_hlo, index);
- return Status::OK();
+ return EmitElementForInputFusibleSlices(fusion, ir_arrays, index);
},
element_shape, launch_dimensions, &b_)
- .EmitLoop(IrName(unnested_hlo),
- GetIndexTypeForKernel(
- unnested_hlo, launch_dimensions.launch_bound(), &b_));
+ .EmitLoop(IrName(mlir::GetNameFromLoc(fusion.getLoc())),
+ GetIndexTypeForKernelFromMlir(
+ fusion, launch_dimensions.launch_bound(), &b_));
thunk_sequence_.emplace_back(std::move(kernel_thunk));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 0480736..6cfed8d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -446,11 +446,12 @@
// different. On the other hand, the input ranges of slices can be
// overlapping. Further generalization/specialization when the needs are seen
// in the future.
- Status EmitInputFusibleNonStridedSlices(HloInstruction* unnested_hlo);
+ Status EmitInputFusibleNonStridedSlices(MlirEmitterInput mlir_input);
- void EmitElementForInputFusibleSlices(
- HloInstruction* unnested_hlo,
- const llvm_ir::IrArray::Index& slice_input_index);
+ Status EmitElementForInputFusibleSlices(
+ mlir::lmhlo::FusionOp fusion,
+ absl::Span<const llvm_ir::IrArray> ir_arrays,
+ const llvm_ir::IrArray::Index& index);
// Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
// the process. Scatter indices are taken from `scatter_indices_gen`, updates
diff --git a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo
index aeb9176..5964764 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo
@@ -2,21 +2,21 @@
// CHECK-LABEL: entry:
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
-// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x i8*]*
+// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [1024 x half]*
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1024 x half]*
// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [1023 x half]*
// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
-// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [0 x half]*
+// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [1023 x half]*
// CHECK: %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0
// CHECK: %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [1024 x half]*
// CHECK: %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0
-// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1024 x half]*
+// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1023 x half]*
// CHECK: %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0
-// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [1023 x half]*
+// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [0 x half]*
// CHECK: %[[VAL_21:.*]] = getelementptr inbounds i8, i8* %[[VAL_22:.*]], i64 0
-// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [1023 x half]*
+// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [3 x i8*]*
// CHECK: %[[VAL_24:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_25:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_26:.*]] = mul nuw nsw i32 %[[VAL_24]], 1024
@@ -34,18 +34,18 @@
// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_31]]
// CHECK: %[[VAL_38:.*]] = phi i32 [ 0, %[[VAL_31]] ]
// CHECK: %[[VAL_39:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_38]]
-// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_39]]
+// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_2]], i32 0, i32 %[[VAL_39]]
// CHECK: %[[VAL_41:.*]] = load half, half* %[[VAL_40]], align 2, !invariant.load !4
-// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_39]]
+// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_39]]
// CHECK: %[[VAL_43:.*]] = load half, half* %[[VAL_42]], align 2, !invariant.load !4
// CHECK: %[[VAL_44:.*]] = fmul half %[[VAL_41]], %[[VAL_43]]
// CHECK: br label %[[VAL_45:.*]]
// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_37]]
// CHECK: %[[VAL_46:.*]] = phi i32 [ 1024, %[[VAL_37]] ]
// CHECK: %[[VAL_47:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_46]]
-// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_47]]
+// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_47]]
// CHECK: %[[VAL_49:.*]] = load half, half* %[[VAL_48]], align 2, !invariant.load !4
-// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_47]]
+// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_47]]
// CHECK: %[[VAL_51:.*]] = load half, half* %[[VAL_50]], align 2, !invariant.load !4
// CHECK: %[[VAL_52:.*]] = fadd half %[[VAL_49]], %[[VAL_51]]
// CHECK: br label %[[VAL_45]]
@@ -54,7 +54,7 @@
// CHECK: br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]]
// CHECK: concat_index_not_from_operand1: ; preds = %[[VAL_37]]
// CHECK: unreachable
-// CHECK: concat.1.merge: ; preds = %[[VAL_54]], %[[VAL_36]]
+// CHECK: concatenate.7.merge: ; preds = %[[VAL_54]], %[[VAL_36]]
// CHECK: %[[VAL_56:.*]] = phi half [ %[[VAL_44]], %[[VAL_36]] ], [ %[[VAL_52]], %[[VAL_54]] ]
// CHECK: %[[VAL_57:.*]] = icmp sge i32 %[[VAL_29]], 0
// CHECK: %[[VAL_58:.*]] = icmp slt i32 %[[VAL_29]], 1024
@@ -74,17 +74,17 @@
// CHECK: br label %[[VAL_32]]
// CHECK: slice0-true: ; preds = %[[VAL_45]]
// CHECK: %[[VAL_71:.*]] = sub i32 %[[VAL_29]], 0
-// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_71]]
+// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_71]]
// CHECK: store half %[[VAL_56]], half* %[[VAL_72]], align 2
// CHECK: br label %[[VAL_61]]
// CHECK: slice1-true: ; preds = %[[VAL_61]]
// CHECK: %[[VAL_73:.*]] = sub i32 %[[VAL_29]], 1024
-// CHECK: %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_73]]
+// CHECK: %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_73]]
// CHECK: store half %[[VAL_56]], half* %[[VAL_74]], align 2
// CHECK: br label %[[VAL_66]]
// CHECK: slice2-true: ; preds = %[[VAL_66]]
// CHECK: %[[VAL_75:.*]] = sub i32 %[[VAL_29]], 2047
-// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_75]]
+// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_75]]
// CHECK: store half %[[VAL_56]], half* %[[VAL_76]], align 2
// CHECK: br label %[[VAL_33]]