Roll-forward with fix:
[XLA/GPU] Migrate nested reduce emitter to take LMHLO.
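
Reduces that are not handled by EmitReductionFromOrToContiguousDimensions
(in particular variadic/multi-output reduces) are now lowered to lmhlo.reduce
and emitted through DefaultActionForMlir as a loop fusion, instead of falling
back to IrEmitter::HandleReduce. ProcessFusionForConversion also collects the
fusion's output shapes/layouts so the scratch HLO module gets the correct
result layouts.

For orientation only, the LMHLO form consumed by the new path looks roughly
like the sketch below (a single-operand reduce with made-up names and shapes;
the real multi-operand case is the ops.mlir test added in this change):

    "lmhlo.reduce"(%input, %init, %out) ( {
    ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
      %sum = mhlo.add %lhs, %rhs : tensor<f32>
      "mhlo.return"(%sum) : (tensor<f32>) -> ()
    }) {dimensions = dense<1> : tensor<1xi64>}
       : (memref<1x10xf32>, memref<f32>, memref<1xf32>) -> ()
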
PiperOrigin-RevId: 343582798
Change-Id: Ia468154b1442818110e8bf1188678c8eaf851f69
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index 0ee44f3..ff052b5 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -197,10 +197,11 @@
//===----------------------------------------------------------------------===//
// TODO(b/139813999): specify required function signature in a type-safe way.
-def LHLO_ReduceOp: LHLO_Op<"reduce", [
- SameVariadicOperandSize,
- SingleBlockImplicitTerminator<"TerminatorOp">
- ]>, BASE_HLO_ReduceOp {
+//
+// The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are
+// moving towards mhlo.ReturnOp, but some code that needs cleanup still assumes lmhlo.TerminatorOp.
+// TODO(timshen): cleanup lmhlo.TerminatorOp.
+def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
index e7312e2..cd72707 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
@@ -374,3 +374,23 @@
return %result : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo.reduce"({{.*}}) ( {
+// CHECK: ^bb0(%[[VAL1:.*]]: tensor<f32>, %[[VAL2:.*]]: tensor<i32>, %[[VAL3:.*]]: tensor<f32>, %[[VAL4:.*]]: tensor<i32>): // no predecessors
+// CHECK: %[[VAL5:.*]] = mhlo.maximum %[[VAL1]], %[[VAL3]] : tensor<f32>
+// CHECK: %[[VAL6:.*]] = mhlo.maximum %[[VAL2]], %[[VAL4]] : tensor<i32>
+// CHECK: %[[VAL7:.*]] = "mhlo.tuple"(%[[VAL5]], %[[VAL6]]) : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
+// CHECK: "mhlo.return"(%[[VAL7]]) : (tuple<tensor<f32>, tensor<i32>>) -> ()
+// CHECK: })
+func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
+ %result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
+ ^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
+ %fmax = "mhlo.maximum"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %imax = "mhlo.maximum"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ "mhlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
+ }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
+ return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index 0d2a113..b46d413 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -277,6 +277,8 @@
return EmitCustomCallOp(instr);
case HloOpcode::kConstant:
return EmitConstant(instr);
+ case HloOpcode::kReduce:
+ return EmitReduceOp(instr);
default:
llvm::errs() << instr->ToString();
return tensorflow::errors::Internal(
@@ -561,6 +563,19 @@
return get_global_memref;
}
+StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp(
+ HloInstruction* instr) {
+ TF_ASSIGN_OR_RETURN(auto reduce_op,
+ CreateOpWithoutAttrs<lmhlo::ReduceOp>(instr));
+ auto* reduce = ::xla::Cast<::xla::HloReduceInstruction>(instr);
+ std::vector<int64_t> dimensions(reduce->dimensions().begin(),
+ reduce->dimensions().end());
+ reduce_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions));
+ TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
+ *instr->called_computations()[0], &reduce_op.body(), &builder_));
+ return reduce_op;
+}
+
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
const ::xla::ShapeIndex& shape_index) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index 2d4f243..054aa63 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -58,6 +58,7 @@
::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::CustomCallOp> EmitCustomCallOp(
::xla::HloInstruction* instr);
+ ::xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(::xla::HloInstruction* instr);
::xla::StatusOr<mlir::GetGlobalMemrefOp> EmitConstant(
::xla::HloInstruction* instr) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index e156c57..6b7a0f0 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -572,13 +572,24 @@
// because we don't have a fully functioning LMHLO graph yet.
mlir::Location loc = input.op->getLoc();
- mlir::lmhlo::FusionOp fusion = nullptr;
+ mlir::lmhlo::FusionOp fusion =
+ mlir::OpBuilder(input.op).create<mlir::lmhlo::FusionOp>(
+ loc, llvm::ArrayRef<mlir::NamedAttribute>());
Shape output_shape;
+ mlir::OpBuilder b(&fusion.region());
+
+ const auto load_memrefs = [loc, &b](mlir::ValueRange range) {
+ std::vector<mlir::Value> operands;
+ for (mlir::Value memref : range) {
+ auto load = b.create<mlir::TensorLoadOp>(loc, memref);
+ HloFunctionImporter::SetLayoutForMlir(load,
+ TypeToShape(memref.getType()));
+ operands.push_back(load);
+ }
+ return operands;
+ };
+
if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(input.op)) {
- fusion = mlir::OpBuilder(copy).create<mlir::lmhlo::FusionOp>(
- loc, llvm::ArrayRef<mlir::NamedAttribute>());
- copy.getOperation()->moveBefore(&fusion.region().front().back());
- mlir::OpBuilder b(copy);
auto operand = b.create<mlir::TensorLoadOp>(loc, copy.operand());
HloFunctionImporter::SetLayoutForMlir(
operand, TypeToShape(copy.operand().getType()));
@@ -586,15 +597,41 @@
output_shape = TypeToShape(copy.output().getType());
HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape);
b.create<mlir::TensorStoreOp>(loc, fused_copy, copy.output());
- copy.getOperation()->erase();
+ } else if (auto reduce = mlir::dyn_cast<mlir::lmhlo::ReduceOp>(input.op)) {
+ std::vector<mlir::Value> operands = load_memrefs(reduce.operands());
+ std::vector<mlir::Value> init_values = load_memrefs(reduce.init_values());
+ auto fused_reduce = b.create<mlir::mhlo::ReduceOp>(
+ loc, operands, init_values, reduce.dimensions());
+ fused_reduce.body().takeBody(reduce.body());
+ CHECK_EQ(fused_reduce.getNumResults(), reduce.out().size());
+ std::vector<Shape> output_shapes;
+ for (int i = 0; i < reduce.out().size(); i++) {
+ b.create<mlir::TensorStoreOp>(loc, fused_reduce.getResult(i),
+ reduce.out()[i]);
+ auto shape = TypeToShape(reduce.out()[i].getType());
+ if (i == 0) {
+ HloFunctionImporter::SetLayoutForMlir(fused_reduce, shape);
+ }
+ output_shapes.push_back(shape);
+ }
+ if (output_shapes.size() == 1) {
+ output_shape = output_shapes[0];
+ } else {
+ output_shape = ShapeUtil::MakeTupleShape(output_shapes);
+ }
} else {
input.op->dump();
LOG(FATAL) << "Unimplemented default action for mlir op";
}
+ input.op->erase();
input.op = fusion;
- auto ret = EmitLoopFusionFromMlir(
- input, output_shape,
- ComputeMaxUnrollFactor(output_shape, hlo_module_config_));
+ int unroll_factor = 1;
+ // TODO(timshen): Port MayPreventVectorization as we add more ops into this
+ // function.
+ if (output_shape.IsArray()) {
+ unroll_factor = ComputeMaxUnrollFactor(output_shape, hlo_module_config_);
+ }
+ auto ret = EmitLoopFusionFromMlir(input, output_shape, unroll_factor);
return ret;
}
@@ -911,7 +948,8 @@
// This function won't be needed once ElementalIrEmitter migrates to take MHLO
// instead.
static Status ProcessFusionForConversion(mlir::Region* region,
- std::vector<Shape>* operand_shapes) {
+ std::vector<Shape>* operand_shapes,
+ std::vector<Shape>* output_shapes) {
std::vector<mlir::TensorLoadOp> loads;
std::vector<mlir::TensorStoreOp> stores;
@@ -931,8 +969,7 @@
auto arg = region->addArgument(load.getType());
load.replaceAllUsesWith(arg);
Shape shape = TypeToShape(load.getType());
- auto attr = mlir::GetLayoutFromMlirHlo(load);
- if (attr) {
+ if (auto attr = mlir::GetLayoutFromMlirHlo(load)) {
std::vector<int64> minor_to_major;
absl::c_transform(
attr, std::back_inserter(minor_to_major),
@@ -948,6 +985,16 @@
std::vector<mlir::Value> returned_values;
for (auto store : stores) {
+ Shape shape = TypeToShape(store.memref().getType());
+ if (auto attr = mlir::GetLayoutFromMlirHlo(store)) {
+ std::vector<int64> minor_to_major;
+ absl::c_transform(
+ attr, std::back_inserter(minor_to_major),
+ std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
+ *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+ }
+ output_shapes->push_back(shape);
+
returned_values.push_back(store.tensor());
store.erase();
}
@@ -1272,12 +1319,14 @@
}
Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
+ TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(reduce));
+
if (IsReductionFromOrToContiguousDimensions(*reduce) &&
reduce->shape().IsArray()) {
return EmitReductionFromOrToContiguousDimensions(reduce, {reduce});
}
- return IrEmitter::HandleReduce(reduce);
+ return DefaultActionForMlir(input);
}
Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
@@ -1880,9 +1929,10 @@
bool is_fusion) {
std::unique_ptr<HloModule>& module = scratch_nested_computations_[region];
if (module == nullptr) {
- std::vector<Shape> operand_shapes;
+ std::vector<Shape> operand_shapes, output_shapes;
if (is_fusion) {
- TF_RETURN_IF_ERROR(ProcessFusionForConversion(region, &operand_shapes));
+ TF_RETURN_IF_ERROR(
+ ProcessFusionForConversion(region, &operand_shapes, &output_shapes));
}
xla::XlaComputation xla_computation;
@@ -1896,6 +1946,62 @@
module, HloModule::CreateFromProto(xla_computation.proto(),
HloModuleConfig(program_shape)));
+ if (is_fusion) {
+ HloComputation* fused_computation = module->entry_computation();
+ CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters());
+ for (int i = 0; i < fused_computation->num_parameters(); i++) {
+ *fused_computation->parameter_instruction(i)
+ ->mutable_shape()
+ ->mutable_layout() = operand_shapes[i].layout();
+ }
+ HloInstruction* root = fused_computation->root_instruction();
+ // Manually fold Tuple(GTE(a, 0), GTE(a, 1), GTE(a, 2), ...) to a.
+      // FusedIrEmitter doesn't take GTE ops because we aim to eliminate tuples
+ // as much as possible.
+ if (root->opcode() == HloOpcode::kTuple) {
+ [&] {
+ HloInstruction* real_root = nullptr;
+ int expected_tuple_index = 0;
+ for (HloInstruction* operand : root->operands()) {
+ if (operand->opcode() != HloOpcode::kGetTupleElement) {
+ return;
+ }
+ if (real_root == nullptr) {
+ real_root = operand->mutable_operand(0);
+ } else if (real_root != operand->operand(0)) {
+ return;
+ }
+ if (expected_tuple_index != operand->tuple_index()) {
+ return;
+ }
+ expected_tuple_index++;
+ }
+ fused_computation->set_root_instruction(real_root);
+ std::vector<HloInstruction*> to_be_removed;
+ to_be_removed.push_back(root);
+ for (HloInstruction* operand : root->operands()) {
+ to_be_removed.push_back(operand);
+ }
+ for (auto instr : to_be_removed) {
+ TF_CHECK_OK(fused_computation->RemoveInstruction(instr));
+ }
+
+ root = real_root;
+ }();
+ }
+
+ if (output_shapes.size() > 1) {
+ CHECK(root->shape().IsTuple());
+ CHECK_EQ(root->shape().tuple_shapes_size(), output_shapes.size());
+
+ for (int i = 0; i < output_shapes.size(); i++) {
+ *root->mutable_shape()->mutable_tuple_shapes(i) = output_shapes.at(i);
+ }
+ } else {
+ CHECK_EQ(1, output_shapes.size());
+ *root->mutable_shape() = output_shapes[0];
+ }
+ }
// Post-process the generated computation:
// * Sanitize constant names, so that they can be used as LLVM global
// symbols.
@@ -1905,22 +2011,13 @@
if (instr->opcode() == HloOpcode::kConstant) {
instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(*instr));
}
- if (instr->shape().IsTuple()) {
- TF_ASSIGN_OR_RETURN(*instr->mutable_shape(),
- ShapeInference::InferVariadicOpShape(
- instr->opcode(), instr->operands()));
+ if (instr->shape().IsTuple() &&
+ computation == module->entry_computation() &&
+ instr != computation->root_instruction()) {
+ return InternalError("Non-root tuple types are not handled.");
}
}
}
- if (is_fusion) {
- HloComputation* fused_computation = module->entry_computation();
- CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters());
- for (int i = 0; i < fused_computation->num_parameters(); i++) {
- *fused_computation->parameter_instruction(i)
- ->mutable_shape()
- ->mutable_layout() = operand_shapes[i].layout();
- }
- }
}
return module->entry_computation();
}
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo b/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo
index baf3ccf..8e9ee58 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduce_nested.hlo
@@ -13,15 +13,15 @@
// CHECK: %[[VAL_9:.*]] = alloca float, align 4
// CHECK: %[[VAL_10:.*]] = alloca float, align 4
// CHECK: %[[VAL_11:.*]] = getelementptr inbounds i8, i8* %[[VAL_12:.*]], i64 0
-// CHECK: %[[VAL_13:.*]] = bitcast i8* %[[VAL_11]] to [2 x i8*]*
+// CHECK: %[[VAL_13:.*]] = bitcast i8* %[[VAL_11]] to [100 x [200 x [300 x float]]]*
// CHECK: %[[VAL_14:.*]] = getelementptr inbounds i8, i8* %[[VAL_15:.*]], i64 0
-// CHECK: %[[VAL_16:.*]] = bitcast i8* %[[VAL_14]] to [200 x float]*
+// CHECK: %[[VAL_16:.*]] = bitcast i8* %[[VAL_14]] to [100 x [200 x [300 x float]]]*
// CHECK: %[[VAL_17:.*]] = getelementptr inbounds i8, i8* %[[VAL_18:.*]], i64 0
// CHECK: %[[VAL_19:.*]] = bitcast i8* %[[VAL_17]] to [200 x float]*
// CHECK: %[[VAL_20:.*]] = getelementptr inbounds i8, i8* %[[VAL_21:.*]], i64 0
-// CHECK: %[[VAL_22:.*]] = bitcast i8* %[[VAL_20]] to [100 x [200 x [300 x float]]]*
+// CHECK: %[[VAL_22:.*]] = bitcast i8* %[[VAL_20]] to [200 x float]*
// CHECK: %[[VAL_23:.*]] = getelementptr inbounds i8, i8* %[[VAL_24:.*]], i64 0
-// CHECK: %[[VAL_25:.*]] = bitcast i8* %[[VAL_23]] to [100 x [200 x [300 x float]]]*
+// CHECK: %[[VAL_25:.*]] = bitcast i8* %[[VAL_23]] to [2 x i8*]*
// CHECK: %[[VAL_26:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
// CHECK: %[[VAL_27:.*]] = icmp eq i32 0, %[[VAL_26]]
// CHECK: %[[VAL_28:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
@@ -41,11 +41,11 @@
// CHECK: d.in_bounds-after: ; preds = %[[VAL_43:.*]], %[[VAL_32]]
// CHECK: ret void
// CHECK: emit_mof_tuple-true: ; preds = %[[VAL_33]]
-// CHECK: %[[VAL_44:.*]] = bitcast [200 x float]* %[[VAL_16]] to i8*
-// CHECK: %[[VAL_45:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_13]], i64 0, i64 0
+// CHECK: %[[VAL_44:.*]] = bitcast [200 x float]* %[[VAL_19]] to i8*
+// CHECK: %[[VAL_45:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_25]], i64 0, i64 0
// CHECK: store i8* %[[VAL_44]], i8** %[[VAL_45]], align 8
-// CHECK: %[[VAL_46:.*]] = bitcast [200 x float]* %[[VAL_19]] to i8*
-// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_13]], i64 0, i64 1
+// CHECK: %[[VAL_46:.*]] = bitcast [200 x float]* %[[VAL_22]] to i8*
+// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_25]], i64 0, i64 1
// CHECK: store i8* %[[VAL_46]], i8** %[[VAL_47]], align 8
// CHECK: br label %[[VAL_32]]
// CHECK: d.in_bounds-true: ; preds = %[[VAL_32]]
@@ -55,23 +55,23 @@
// CHECK: store float %[[VAL_49]], float* %[[VAL_9]], align 4
// CHECK: store i32 0, i32* %[[VAL_8]], align 4
// CHECK: br label %[[VAL_50:.*]]
-// CHECK: d.inner.loop_header.reduction_dim.0: ; preds = %[[VAL_51:.*]], %[[VAL_41]]
+// CHECK: reduce.13.inner.loop_header.reduction_dim.0: ; preds = %[[VAL_51:.*]], %[[VAL_41]]
// CHECK: %[[VAL_52:.*]] = load i32, i32* %[[VAL_8]], align 4
// CHECK: %[[VAL_53:.*]] = icmp uge i32 %[[VAL_52]], 100
// CHECK: br i1 %[[VAL_53]], label %[[VAL_43]], label %[[VAL_54:.*]]
-// CHECK: d.inner.loop_body.reduction_dim.0: ; preds = %[[VAL_50]]
+// CHECK: reduce.13.inner.loop_body.reduction_dim.0: ; preds = %[[VAL_50]]
// CHECK: store i32 0, i32* %[[VAL_7]], align 4
// CHECK: br label %[[VAL_55:.*]]
-// CHECK: d.inner.loop_header.reduction_dim.2: ; preds = %[[VAL_56:.*]], %[[VAL_54]]
+// CHECK: reduce.13.inner.loop_header.reduction_dim.2: ; preds = %[[VAL_56:.*]], %[[VAL_54]]
// CHECK: %[[VAL_57:.*]] = load i32, i32* %[[VAL_7]], align 4
// CHECK: %[[VAL_58:.*]] = icmp uge i32 %[[VAL_57]], 300
// CHECK: br i1 %[[VAL_58]], label %[[VAL_51]], label %[[VAL_56]]
-// CHECK: d.inner.loop_body.reduction_dim.2: ; preds = %[[VAL_55]]
+// CHECK: reduce.13.inner.loop_body.reduction_dim.2: ; preds = %[[VAL_55]]
// CHECK: %[[VAL_59:.*]] = load float, float* %[[VAL_10]], align 4
// CHECK: %[[VAL_60:.*]] = load float, float* %[[VAL_9]], align 4
-// CHECK: %[[VAL_61:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_22]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]]
+// CHECK: %[[VAL_61:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_13]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]]
// CHECK: %[[VAL_62:.*]] = load float, float* %[[VAL_61]], align 4, !invariant.load !4
-// CHECK: %[[VAL_63:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_25]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]]
+// CHECK: %[[VAL_63:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_16]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]]
// CHECK: %[[VAL_64:.*]] = load float, float* %[[VAL_63]], align 4, !invariant.load !4
// CHECK: store float %[[VAL_59]], float* %[[VAL_5]], align 4
// CHECK: store float %[[VAL_60]], float* %[[VAL_4]], align 4
@@ -83,7 +83,7 @@
// CHECK: %[[VAL_67:.*]] = bitcast float* %[[VAL_1]] to i8*
// CHECK: %[[VAL_68:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_6]], i64 0, i64 1
// CHECK: store i8* %[[VAL_67]], i8** %[[VAL_68]], align 8
-// CHECK: call void @Add(float* %[[VAL_5]], float* %[[VAL_4]], float* %[[VAL_3]], float* %[[VAL_2]], [2 x i8*]* %[[VAL_6]])
+// CHECK: call void @region_1_5(float* %[[VAL_5]], float* %[[VAL_4]], float* %[[VAL_3]], float* %[[VAL_2]], [2 x i8*]* %[[VAL_6]])
// CHECK: %[[VAL_69:.*]] = load float, float* %[[VAL_0]], align 4
// CHECK: %[[VAL_70:.*]] = load float, float* %[[VAL_1]], align 4
// CHECK: store float %[[VAL_69]], float* %[[VAL_10]], align 4
@@ -91,21 +91,21 @@
// CHECK: %[[VAL_71:.*]] = add nuw nsw i32 %[[VAL_57]], 1
// CHECK: store i32 %[[VAL_71]], i32* %[[VAL_7]], align 4
// CHECK: br label %[[VAL_55]]
-// CHECK: d.inner.loop_exit.reduction_dim.2: ; preds = %[[VAL_55]]
+// CHECK: reduce.13.inner.loop_exit.reduction_dim.2: ; preds = %[[VAL_55]]
// CHECK: %[[VAL_72:.*]] = add nuw nsw i32 %[[VAL_52]], 1
// CHECK: store i32 %[[VAL_72]], i32* %[[VAL_8]], align 4
// CHECK: br label %[[VAL_50]]
-// CHECK: d.inner.loop_exit.reduction_dim.0: ; preds = %[[VAL_50]]
+// CHECK: reduce.13.inner.loop_exit.reduction_dim.0: ; preds = %[[VAL_50]]
// CHECK: %[[VAL_73:.*]] = load float, float* %[[VAL_10]], align 4
// CHECK: %[[VAL_74:.*]] = insertvalue { float, float } undef, float %[[VAL_73]], 0
// CHECK: %[[VAL_75:.*]] = load float, float* %[[VAL_9]], align 4
// CHECK: %[[VAL_76:.*]] = insertvalue { float, float } %[[VAL_74]], float %[[VAL_75]], 1
// CHECK: %[[VAL_77:.*]] = extractvalue { float, float } %[[VAL_76]], 0
-// CHECK: %[[VAL_78:.*]] = bitcast [200 x float]* %[[VAL_16]] to float*
+// CHECK: %[[VAL_78:.*]] = bitcast [200 x float]* %[[VAL_19]] to float*
// CHECK: %[[VAL_79:.*]] = getelementptr inbounds float, float* %[[VAL_78]], i32 %[[VAL_37]]
// CHECK: store float %[[VAL_77]], float* %[[VAL_79]], align 4
// CHECK: %[[VAL_80:.*]] = extractvalue { float, float } %[[VAL_76]], 1
-// CHECK: %[[VAL_81:.*]] = bitcast [200 x float]* %[[VAL_19]] to float*
+// CHECK: %[[VAL_81:.*]] = bitcast [200 x float]* %[[VAL_22]] to float*
// CHECK: %[[VAL_82:.*]] = getelementptr inbounds float, float* %[[VAL_81]], i32 %[[VAL_37]]
// CHECK: store float %[[VAL_80]], float* %[[VAL_82]], align 4
// CHECK: br label %[[VAL_42]]
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
index 1457fa5..e92cd54 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
@@ -345,7 +345,9 @@
CreateDenseIntElementsAttrFromVector(instr->dimensions(), builder_);
auto reduce_op = builder.create<lhlo::ReduceOp>(loc, inputs, init_values,
results, dimensions_attr);
- reduce_op.ensureTerminator(reduce_op.body(), builder, getLocation(instr));
+ builder.createBlock(&reduce_op.body());
+ OpBuilder::atBlockEnd(&reduce_op.body().front())
+ .create<lhlo::TerminatorOp>(getLocation(instr));
return SpliceHloComputation(OpBuilder{&reduce_op.body()}, loc,
*instr->to_apply(), emission_context_);
}