Add broadcast lowering (BroadcastInDimOp) to the LHLO/HLO dialect emitters.

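Both emitters now override HandleBroadcast. The HLO emitter lowers
HloOpcode::kBroadcast to hlo::BroadcastInDimOp, converting the instruction's
dimensions() into a dense i64 broadcast_dimensions attribute; the LHLO emitter
emits an lhlo::BroadcastInDimOp inside the generated function, reading from
the operand buffer argument and writing to the output buffer argument. The
binary-op switch cases in both emitters are also reordered alphabetically, and
unused using-declarations are dropped.

As a rough sketch of the lowering (the concrete memref types below are
illustrative only; the new test matches them generically via FileCheck
captures), the HLO

  %broadcast = f32[10,5]{1,0} broadcast(f32[10]{0} %x), dimensions={0}

becomes approximately

  func @broadcast(%in: memref<10xf32>, %out: memref<10x5xf32>) {
    "xla_lhlo.broadcast_in_dim"(%in, %out)
        {broadcast_dimensions = dense<0> : tensor<1xi64>, name = "broadcast"}
        : (memref<10xf32>, memref<10x5xf32>) -> ()
  }
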
PiperOrigin-RevId: 277052910
Change-Id: If85b4cbbdbbad29685c4273dce891b1ffb1ee052
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
index 87377b7..a10b16d 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
@@ -31,13 +31,11 @@
 
 using ::mlir::ArrayRef;
 using ::mlir::Attribute;
-using ::mlir::Builder;
 using ::mlir::Identifier;
 using ::mlir::Location;
 using ::mlir::NamedAttribute;
 using ::mlir::OpBuilder;
 using ::mlir::RankedTensorType;
-using ::mlir::ShapedType;
 using ::mlir::Type;
 using ::mlir::Value;
 
@@ -50,22 +48,22 @@
   switch (opcode) {
     case HloOpcode::kAdd:
       return {func_builder.create<hlo::AddOp>(loc, rets, args, attrs)};
-    case HloOpcode::kMultiply:
-      return {func_builder.create<hlo::MulOp>(loc, rets, args, attrs)};
-    case HloOpcode::kSubtract:
-      return {func_builder.create<hlo::SubOp>(loc, rets, args, attrs)};
-    case HloOpcode::kDivide:
-      return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
     case HloOpcode::kAnd:
       return {func_builder.create<hlo::AndOp>(loc, rets, args, attrs)};
-    case HloOpcode::kMinimum:
-      return {func_builder.create<hlo::MinOp>(loc, rets, args, attrs)};
-    case HloOpcode::kMaximum:
-      return {func_builder.create<hlo::MaxOp>(loc, rets, args, attrs)};
+    case HloOpcode::kDivide:
+      return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
     case HloOpcode::kExp:
       return {func_builder.create<hlo::ExpOp>(loc, rets, args, attrs)};
+    case HloOpcode::kMaximum:
+      return {func_builder.create<hlo::MaxOp>(loc, rets, args, attrs)};
+    case HloOpcode::kMinimum:
+      return {func_builder.create<hlo::MinOp>(loc, rets, args, attrs)};
+    case HloOpcode::kMultiply:
+      return {func_builder.create<hlo::MulOp>(loc, rets, args, attrs)};
     case HloOpcode::kSelect:
       return {func_builder.create<hlo::SelectOp>(loc, rets, args, attrs)};
+    case HloOpcode::kSubtract:
+      return {func_builder.create<hlo::SubOp>(loc, rets, args, attrs)};
     default:
       return tensorflow::errors::Internal(absl::StrCat(
           "HLO Opcode ", HloOpcodeString(opcode), " is not supported."));
@@ -103,6 +101,21 @@
   return Status::OK();
 }
 
+Status HloDialectEmitter::HandleBroadcast(HloInstruction* broadcast) {
+  mlir::DenseIntElementsAttr broadcast_dim =
+      CreateDenseIntElementsAttrFromVector(broadcast->dimensions(), builder_);
+  TF_ASSIGN_OR_RETURN(Type res_type, ConvertTensorShapeToType<RankedTensorType>(
+                                         broadcast->shape(), builder_));
+
+  auto broadcast_op = builder_.create<hlo::BroadcastInDimOp>(
+      getLocation(broadcast), llvm::makeArrayRef(res_type),
+      instruction_to_values_[broadcast->operand(0)], broadcast_dim);
+  broadcast_op.setAttr("name", builder_.getStringAttr(broadcast->name()));
+
+  instruction_to_values_[broadcast] = broadcast_op;
+  return Status::OK();
+}
+
 Status HloDialectEmitter::HandleParameter(HloInstruction* param) {
   auto argValue = arguments_[param->parameter_number()];
   instruction_to_values_[param] = argValue;
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
index 97f7d53..d2bcf84 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
+++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
@@ -52,10 +52,11 @@
   StatusOr<mlir::Value*> EmitComputation(const HloComputation& computation);
 
   Status DefaultAction(HloInstruction* instr) override;
+  Status HandleBroadcast(HloInstruction* broadcast) override;
+  Status HandleCompare(HloInstruction* compare) override;
   Status HandleConstant(HloInstruction* constant) override;
   Status HandleParameter(HloInstruction* param) override;
   Status HandleReduce(HloInstruction* reduce) override;
-  Status HandleCompare(HloInstruction* compare) override;
 
  private:
   mlir::Location getLocation(const HloInstruction* instr) const;
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
index 785522e..d2d2c98 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
@@ -66,29 +66,29 @@
     case HloOpcode::kAdd:
       func_builder.create<lhlo::AddOp>(loc, rets, args, attrs);
       break;
-    case HloOpcode::kMultiply:
-      func_builder.create<lhlo::MulOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kSubtract:
-      func_builder.create<lhlo::SubOp>(loc, rets, args, attrs);
+    case HloOpcode::kAnd:
+      func_builder.create<lhlo::AndOp>(loc, rets, args, attrs);
       break;
     case HloOpcode::kDivide:
       func_builder.create<lhlo::DivOp>(loc, rets, args, attrs);
       break;
-    case HloOpcode::kAnd:
-      func_builder.create<lhlo::AndOp>(loc, rets, args, attrs);
-      break;
-    case HloOpcode::kMinimum:
-      func_builder.create<lhlo::MinOp>(loc, rets, args, attrs);
+    case HloOpcode::kExp:
+      func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs);
       break;
     case HloOpcode::kMaximum:
       func_builder.create<lhlo::MaxOp>(loc, rets, args, attrs);
       break;
-    case HloOpcode::kExp:
-      func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs);
+    case HloOpcode::kMinimum:
+      func_builder.create<lhlo::MinOp>(loc, rets, args, attrs);
+      break;
+    case HloOpcode::kMultiply:
+      func_builder.create<lhlo::MulOp>(loc, rets, args, attrs);
       break;
     case HloOpcode::kSelect:
-      func_builder.create<::mlir::xla_lhlo::SelectOp>(loc, rets, args, attrs);
+      func_builder.create<lhlo::SelectOp>(loc, rets, args, attrs);
+      break;
+    case HloOpcode::kSubtract:
+      func_builder.create<lhlo::SubOp>(loc, rets, args, attrs);
       break;
     default:
       return tensorflow::errors::Internal(absl::StrCat(
@@ -179,6 +179,19 @@
   return Status::OK();
 }
 
+Status LhloDialectEmitter::HandleBroadcast(HloInstruction* broadcast) {
+  mlir::DenseIntElementsAttr broadcast_dim =
+      CreateDenseIntElementsAttrFromVector(broadcast->dimensions(), builder_);
+
+  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*broadcast));
+  OpBuilder func_builder(function.getBody());
+  auto broadcast_op = func_builder.create<lhlo::BroadcastInDimOp>(
+      getLocation(broadcast), function.getArgument(0), function.getArgument(1),
+      broadcast_dim);
+  broadcast_op.setAttr("name", builder_.getStringAttr(broadcast->name()));
+  return Status::OK();
+}
+
 Status LhloDialectEmitter::HandleFusion(HloInstruction* fusion) {
   TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*fusion));
   OpBuilder func_builder(function.getBody());
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h
index e79fea9..148cf91 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h
+++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h
@@ -53,11 +53,11 @@
   // Default action which emits code for most operations. Operations which are
   // special in some way are handled explicitly in HandleFoo methods.
   Status DefaultAction(HloInstruction* instr) override;
-
-  Status HandleFusion(HloInstruction* fusion) override;
-  Status HandleCustomCall(HloInstruction* custom_call) override;
-  Status HandleParameter(HloInstruction* parameter) override;
+  Status HandleBroadcast(HloInstruction* broadcast) override;
   Status HandleCompare(HloInstruction* compare) override;
+  Status HandleCustomCall(HloInstruction* custom_call) override;
+  Status HandleFusion(HloInstruction* fusion) override;
+  Status HandleParameter(HloInstruction* parameter) override;
 
   Status FinishVisit(HloInstruction* root) override;
 
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
index bdd691d..bf1c241 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
@@ -290,5 +290,22 @@
       )");
 }
 
+TEST_F(LhloGenTest, Broadcast) {
+  CompileAndVerifyIr(R"(
+HloModule Broadcast
+
+ENTRY %Broadcast (x: f32[10]) -> f32[10, 5] {
+  %x = f32[10]{0} parameter(0)
+  ROOT %broadcast = f32[10, 5]{1,0} broadcast(f32[10]{0} %x), dimensions={0}
+})",
+                     R"(
+;CHECK: func @broadcast(%[[IN:.*]]: [[IN_T:.*]], %[[OUT:.*]]: [[OUT_T:.*]]) {
+;CHECK:   "xla_lhlo.broadcast_in_dim"(%[[IN]], %[[OUT]])
+;CHECK:   {broadcast_dimensions = dense<0> : tensor<1xi64>, name = "broadcast"}
+;CHECK:   : ([[IN_T]], [[OUT_T]]) -> ()
+;CHECK: }
+)");
+}
+
 }  // namespace mlir_gpu
 }  // namespace xla