Add BroadcastInDimOp emission for HLO broadcast to the LHLO/HLO dialect emitters.
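
HLO broadcast instructions are now handled by a dedicated HandleBroadcast in both
dialect emitters: the HLO emitter builds an xla_hlo.broadcast_in_dim on tensors, and
the LHLO emitter builds an xla_lhlo.broadcast_in_dim on the operand and output
buffers, with the broadcast dimensions attached as a dense i64 attribute. The
binary-opcode switch cases in both emitters are reordered alphabetically, and a
codegen test covering the LHLO path is added.

As a rough sketch of what the new test checks (not verbatim emitter output; the
concrete memref types below are assumed from the test's HLO shapes f32[10] and
f32[10,5], since the CHECK lines only use type placeholders):

    // Hypothetical printed form of the lowering verified by LhloGenTest.Broadcast.
    func @broadcast(%arg0: memref<10xf32>, %arg1: memref<10x5xf32>) {
      "xla_lhlo.broadcast_in_dim"(%arg0, %arg1)
          {broadcast_dimensions = dense<0> : tensor<1xi64>, name = "broadcast"}
          : (memref<10xf32>, memref<10x5xf32>) -> ()
      return
    }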
PiperOrigin-RevId: 277052910
Change-Id: If85b4cbbdbbad29685c4273dce891b1ffb1ee052
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
index 87377b7..a10b16d 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
@@ -31,13 +31,11 @@
using ::mlir::ArrayRef;
using ::mlir::Attribute;
-using ::mlir::Builder;
using ::mlir::Identifier;
using ::mlir::Location;
using ::mlir::NamedAttribute;
using ::mlir::OpBuilder;
using ::mlir::RankedTensorType;
-using ::mlir::ShapedType;
using ::mlir::Type;
using ::mlir::Value;
@@ -50,22 +48,22 @@
switch (opcode) {
case HloOpcode::kAdd:
return {func_builder.create<hlo::AddOp>(loc, rets, args, attrs)};
- case HloOpcode::kMultiply:
- return {func_builder.create<hlo::MulOp>(loc, rets, args, attrs)};
- case HloOpcode::kSubtract:
- return {func_builder.create<hlo::SubOp>(loc, rets, args, attrs)};
- case HloOpcode::kDivide:
- return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
case HloOpcode::kAnd:
return {func_builder.create<hlo::AndOp>(loc, rets, args, attrs)};
- case HloOpcode::kMinimum:
- return {func_builder.create<hlo::MinOp>(loc, rets, args, attrs)};
- case HloOpcode::kMaximum:
- return {func_builder.create<hlo::MaxOp>(loc, rets, args, attrs)};
+ case HloOpcode::kDivide:
+ return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
case HloOpcode::kExp:
return {func_builder.create<hlo::ExpOp>(loc, rets, args, attrs)};
+ case HloOpcode::kMaximum:
+ return {func_builder.create<hlo::MaxOp>(loc, rets, args, attrs)};
+ case HloOpcode::kMinimum:
+ return {func_builder.create<hlo::MinOp>(loc, rets, args, attrs)};
+ case HloOpcode::kMultiply:
+ return {func_builder.create<hlo::MulOp>(loc, rets, args, attrs)};
case HloOpcode::kSelect:
return {func_builder.create<hlo::SelectOp>(loc, rets, args, attrs)};
+ case HloOpcode::kSubtract:
+ return {func_builder.create<hlo::SubOp>(loc, rets, args, attrs)};
default:
return tensorflow::errors::Internal(absl::StrCat(
"HLO Opcode ", HloOpcodeString(opcode), " is not supported."));
@@ -103,6 +101,21 @@
return Status::OK();
}
+Status HloDialectEmitter::HandleBroadcast(HloInstruction* broadcast) {
+ mlir::DenseIntElementsAttr broadcast_dim =
+ CreateDenseIntElementsAttrFromVector(broadcast->dimensions(), builder_);
+ TF_ASSIGN_OR_RETURN(Type res_type, ConvertTensorShapeToType<RankedTensorType>(
+ broadcast->shape(), builder_));
+
+ auto broadcast_op = builder_.create<hlo::BroadcastInDimOp>(
+ getLocation(broadcast), llvm::makeArrayRef(res_type),
+ instruction_to_values_[broadcast->operand(0)], broadcast_dim);
+ broadcast_op.setAttr("name", builder_.getStringAttr(broadcast->name()));
+
+ instruction_to_values_[broadcast] = broadcast_op;
+ return Status::OK();
+}
+
Status HloDialectEmitter::HandleParameter(HloInstruction* param) {
auto argValue = arguments_[param->parameter_number()];
instruction_to_values_[param] = argValue;
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
index 97f7d53..d2bcf84 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
+++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
@@ -52,10 +52,11 @@
StatusOr<mlir::Value*> EmitComputation(const HloComputation& computation);
Status DefaultAction(HloInstruction* instr) override;
+ Status HandleBroadcast(HloInstruction* broadcast) override;
+ Status HandleCompare(HloInstruction* compare) override;
Status HandleConstant(HloInstruction* constant) override;
Status HandleParameter(HloInstruction* param) override;
Status HandleReduce(HloInstruction* reduce) override;
- Status HandleCompare(HloInstruction* compare) override;
private:
mlir::Location getLocation(const HloInstruction* instr) const;
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
index 785522e..d2d2c98 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
@@ -66,29 +66,29 @@
case HloOpcode::kAdd:
func_builder.create<lhlo::AddOp>(loc, rets, args, attrs);
break;
- case HloOpcode::kMultiply:
- func_builder.create<lhlo::MulOp>(loc, rets, args, attrs);
- break;
- case HloOpcode::kSubtract:
- func_builder.create<lhlo::SubOp>(loc, rets, args, attrs);
+ case HloOpcode::kAnd:
+ func_builder.create<lhlo::AndOp>(loc, rets, args, attrs);
break;
case HloOpcode::kDivide:
func_builder.create<lhlo::DivOp>(loc, rets, args, attrs);
break;
- case HloOpcode::kAnd:
- func_builder.create<lhlo::AndOp>(loc, rets, args, attrs);
- break;
- case HloOpcode::kMinimum:
- func_builder.create<lhlo::MinOp>(loc, rets, args, attrs);
+ case HloOpcode::kExp:
+ func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs);
break;
case HloOpcode::kMaximum:
func_builder.create<lhlo::MaxOp>(loc, rets, args, attrs);
break;
- case HloOpcode::kExp:
- func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs);
+ case HloOpcode::kMinimum:
+ func_builder.create<lhlo::MinOp>(loc, rets, args, attrs);
+ break;
+ case HloOpcode::kMultiply:
+ func_builder.create<lhlo::MulOp>(loc, rets, args, attrs);
break;
case HloOpcode::kSelect:
- func_builder.create<::mlir::xla_lhlo::SelectOp>(loc, rets, args, attrs);
+ func_builder.create<lhlo::SelectOp>(loc, rets, args, attrs);
+ break;
+ case HloOpcode::kSubtract:
+ func_builder.create<lhlo::SubOp>(loc, rets, args, attrs);
break;
default:
return tensorflow::errors::Internal(absl::StrCat(
@@ -179,6 +179,19 @@
return Status::OK();
}
+Status LhloDialectEmitter::HandleBroadcast(HloInstruction* broadcast) {
+ mlir::DenseIntElementsAttr broadcast_dim =
+ CreateDenseIntElementsAttrFromVector(broadcast->dimensions(), builder_);
+
+ TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*broadcast));
+ OpBuilder func_builder(function.getBody());
+ auto broadcast_op = func_builder.create<lhlo::BroadcastInDimOp>(
+ getLocation(broadcast), function.getArgument(0), function.getArgument(1),
+ broadcast_dim);
+ broadcast_op.setAttr("name", builder_.getStringAttr(broadcast->name()));
+ return Status::OK();
+}
+
Status LhloDialectEmitter::HandleFusion(HloInstruction* fusion) {
TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*fusion));
OpBuilder func_builder(function.getBody());
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h
index e79fea9..148cf91 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h
+++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h
@@ -53,11 +53,11 @@
// Default action which emits code for most operations. Operations which are
// special in some way are handled explicitly in HandleFoo methods.
Status DefaultAction(HloInstruction* instr) override;
-
- Status HandleFusion(HloInstruction* fusion) override;
- Status HandleCustomCall(HloInstruction* custom_call) override;
- Status HandleParameter(HloInstruction* parameter) override;
+ Status HandleBroadcast(HloInstruction* broadcast) override;
Status HandleCompare(HloInstruction* compare) override;
+ Status HandleCustomCall(HloInstruction* custom_call) override;
+ Status HandleFusion(HloInstruction* fusion) override;
+ Status HandleParameter(HloInstruction* parameter) override;
Status FinishVisit(HloInstruction* root) override;
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
index bdd691d..bf1c241 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
@@ -290,5 +290,22 @@
)");
}
+TEST_F(LhloGenTest, Broadcast) {
+ CompileAndVerifyIr(R"(
+HloModule Broadcast
+
+ENTRY %Broadcast (x: f32[10]) -> f32[10, 5] {
+ %x = f32[10]{0} parameter(0)
+ ROOT %broadcast = f32[10, 5]{1,0} broadcast(f32[10]{0} %x), dimensions={0}
+})",
+ R"(
+;CHECK: func @broadcast(%[[IN:.*]]: [[IN_T:.*]], %[[OUT:.*]]: [[OUT_T:.*]]) {
+;CHECK: "xla_lhlo.broadcast_in_dim"(%[[IN]], %[[OUT]])
+;CHECK: {broadcast_dimensions = dense<0> : tensor<1xi64>, name = "broadcast"}
+;CHECK: : ([[IN_T]], [[OUT_T]]) -> ()
+;CHECK: }
+)");
+}
+
} // namespace mlir_gpu
} // namespace xla