[MLIR] Add XLA HLO -> LMHLO conversion for all elementwise ops.
PiperOrigin-RevId: 345557248
Change-Id: I5832bb00cb735489f6115c19007b68a49b434a0a
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index 47b100a..3f302f1 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -85,6 +85,8 @@
def LHLO_BitcastConvertOp:
LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;
+def LHLO_CbrtOp: LHLO_UnaryElementwiseOp<"cbrt", LHLO_FpBuffer>, BASE_HLO_CbrtOp;
+
def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;
def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index d122790..00f2f4b 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -139,6 +139,7 @@
":mlir_hlo_to_hlo",
":translate_cl_options",
"//tensorflow/compiler/mlir/hlo",
+ "//tensorflow/compiler/mlir/hlo:hlo_ops_base_enums",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
"//tensorflow/compiler/xla:debug_options_flags",
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 762fb8a..23fab36 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -537,7 +537,8 @@
}
case HloOpcode::kAllReduce: {
auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
- attributes.push_back(ConvertReplicaGroups(all_reduce->replica_groups()));
+ attributes.push_back(
+ ConvertReplicaGroups(all_reduce->replica_groups(), *builder_));
attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
loc, result_type, operands, attributes);
@@ -932,7 +933,7 @@
}
mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
- const std::vector<ReplicaGroup>& replica_groups) {
+ const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder) {
int64_t num_groups = replica_groups.size();
int64_t group_size =
num_groups == 0 ? 0 : replica_groups[0].replica_ids_size();
@@ -944,9 +945,9 @@
attr[flat_index++] = group.replica_ids(i);
}
auto type = mlir::RankedTensorType::get({num_groups, group_size},
- builder_->getIntegerType(64));
- return builder_->getNamedAttr("replica_groups",
- DenseIntElementsAttr::get(type, attr));
+ builder.getIntegerType(64));
+ return builder.getNamedAttr("replica_groups",
+ DenseIntElementsAttr::get(type, attr));
}
mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
index a0f6e6c..d849b83 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
@@ -64,6 +64,12 @@
static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape);
+ // Converts `replica_groups` into a NamedAttribute named "replica_groups"
+ // whose value is a dense i64 tensor of shape {num_groups, group_size},
+ // created through `builder`. The group size is read from the first group
+ // (0 when there are no groups); presumably every group has that same size --
+ // TODO confirm against callers.
+ //
+ // Declared static so it can be called without an importer instance.
+ //
+ // TODO(timshen): move this to attribute_importer.h.
+ static mlir::NamedAttribute ConvertReplicaGroups(
+ const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder);
+
private:
HloFunctionImporter(mlir::ModuleOp module,
std::unordered_map<const xla::HloComputation*,
@@ -136,10 +142,6 @@
// padding low and padding high for each of the spatial dimensions.
mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
- // Converts replica groups to attribute
- mlir::NamedAttribute ConvertReplicaGroups(
- const std::vector<ReplicaGroup>& replica_groups);
-
// Converts channel id to attribute
mlir::NamedAttribute ConvertChannelHandle(
absl::optional<tensorflow::int64> channel_id);
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index ce42ccf..0884230 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -169,3 +169,52 @@
backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}"
}
+// -----
+
+HloModule GemmBias
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo_gpu.gemm_bias"
+// CHECK-SAME: algorithm = 0 : i64
+// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
+// CHECK-SAME: alpha_real = 1.000000e+00 : f64
+// CHECK-SAME: batch_size = 1 : i64
+// CHECK-SAME: beta = 1.000000e+00 : f64
+// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>
+// CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64>
+// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>
+// CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64>
+// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>)
+ENTRY main {
+ %A = f32[1,1]{1,0} parameter(0)
+ %B = f32[1,4]{1,0} parameter(1)
+ %C = f32[1,4]{1,0} parameter(2)
+ ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{0,1} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C),
+ custom_call_target="__cublas$gemm",
+ backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}"
+}
+
+// -----
+
+HloModule AllReduce
+
+// Test all-reduce
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+// CHECK-LABEL: func @test_all_reduce
+// CHECK-SAME: ([[INPUT:%.*]]: memref<8xf32>
+%test_all_reduce {
+ input = f32[8] parameter(0)
+ // CHECK: "lmhlo.all_reduce"([[INPUT]], {{.*}})
+ // CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
+ // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
+ // CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
+ // CHECK: }) {
+ // CHECK-SAME: channel_id = {handle = 1 : i64, type = 0 : i64}
+ // CHECK-SAME: replica_groups = dense<{{\[\[}}0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
+ ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, to_apply=add
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
index cd72707..e39a654 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
@@ -45,6 +45,34 @@
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.atan2
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+ %res = "mhlo.atan2"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value: tensor<2x2xf32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.bitcast_convert
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.bitcast_convert"(%value) : (tensor<2x2xf32>) -> tensor<2x2xi32>
+ return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
@@ -57,6 +85,63 @@
// -----
// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.cbrt
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.cbrt"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 2
+// CHECK-SAME: %[[ARG3:.*]]: memref<16xi8>
+func @main(%pred: tensor<2x2xf32>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.clamp
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]]
+// CHECK-NEXT: return
+ %0 = "mhlo.clamp"(%pred, %lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.count_leading_zeros
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.count_leading_zeros"(%value) : (tensor<2x2xi32>) -> tensor<2x2xi32>
+ return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
+func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xi1> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
+// CHECK: lmhlo.compare
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+ %res = "mhlo.compare"(%value0, %value1) {comparison_direction="GT"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
+ return %res : tensor<2x2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
@@ -72,6 +157,19 @@
// -----
// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
+func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf16> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<2x2xf16>
+// CHECK: lmhlo.convert
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.convert"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf16>
+ return %res : tensor<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>> {
@@ -118,6 +216,45 @@
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.exponential_minus_one
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.exponential_minus_one"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.floor
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.floor"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<4xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xi1> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
+// CHECK: lmhlo.is_finite
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.is_finite"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xi1>
+ return %res : tensor<2x2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.log
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
@@ -128,6 +265,39 @@
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.log_plus_one
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.log_plus_one"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.map
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK: return
+ %res = "mhlo.map"(%value0, %value1) ({
+ ^bb0(%a: tensor<f32>, %b: tensor<f32>):
+ %c = "mhlo.add"(%a, %b) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %ret = "mhlo.add"(%a, %c) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ "mhlo.return"(%ret) : (tensor<f32>) -> ()
+ }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
@@ -185,6 +355,90 @@
// -----
// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<4xi8>
+func @main(%value0: tensor<2x2xi1>) -> tensor<2x2xi1> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
+// CHECK: lmhlo.not
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.not"(%value0) : (tensor<2x2xi1>) -> tensor<2x2xi1>
+ return %res : tensor<2x2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.not
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.not"(%value0) : (tensor<2x2xi32>) -> tensor<2x2xi32>
+ return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
+func @main(%value0: tensor<2x2xi1>, %value1: tensor<2x2xi1>) -> tensor<2x2xi1> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
+// CHECK: lmhlo.or
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+ %res = "mhlo.or"(%value0, %value1) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
+ return %res : tensor<2x2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.or
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+ %res = "mhlo.or"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+ return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.popcnt
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.popcnt"(%value0) : (tensor<2x2xi32>) -> tensor<2x2xi32>
+ return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.power
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+ %res = "mhlo.power"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
@@ -211,6 +465,19 @@
// -----
// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.reduce_precision
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.reduce_precision"(%value0) {exponent_bits=5 : i32, mantissa_bits=12 : i32}: (tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
@@ -230,6 +497,19 @@
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.round_nearest_afz
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.round_nearest_afz"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.rsqrt
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
@@ -255,6 +535,51 @@
// -----
// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.shift_left
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+ %res = "mhlo.shift_left"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+ return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.shift_right_arithmetic
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+ %res = "mhlo.shift_right_arithmetic"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+ return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.shift_right_logical
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+ %res = "mhlo.shift_right_logical"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+ return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
@@ -272,6 +597,19 @@
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.sine
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+ %res = "mhlo.sine"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.sqrt
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
@@ -309,6 +647,36 @@
// -----
// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
+func @main(%value0: tensor<2x2xi1>, %value1: tensor<2x2xi1>) -> tensor<2x2xi1> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
+// CHECK: lmhlo.xor
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+ %res = "mhlo.xor"(%value0, %value1) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
+ return %res : tensor<2x2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.xor
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+ %res = "mhlo.xor"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+ return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<5x5xi32>
// CHECK-SAME: %[[ARG1:.*]]: memref<5x5xf32>
// CHECK-SAME: %[[ARG2:.*]]: memref<100xi8> {lmhlo.alloc = 0
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index bf11d52..f4ad6f0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -39,6 +39,7 @@
#include "mlir/Pass/PassOptions.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
@@ -229,12 +230,28 @@
return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr);
case HloOpcode::kAdd:
return CreateOpWithoutAttrs<lmhlo::AddOp>(instr);
+ case HloOpcode::kAllReduce:
+ return EmitAllReduceOp(instr);
case HloOpcode::kAnd:
return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
+ case HloOpcode::kAtan2:
+ return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr);
+ case HloOpcode::kBitcastConvert:
+ return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr);
case HloOpcode::kCeil:
return CreateOpWithoutAttrs<lmhlo::CeilOp>(instr);
+ case HloOpcode::kCbrt:
+ return CreateOpWithoutAttrs<lmhlo::CbrtOp>(instr);
+ case HloOpcode::kClamp:
+ return CreateOpWithoutAttrs<lmhlo::ClampOp>(instr);
+ case HloOpcode::kClz:
+ return CreateOpWithoutAttrs<lmhlo::ClzOp>(instr);
+ case HloOpcode::kCompare:
+ return EmitCompareOp(instr);
case HloOpcode::kComplex:
return CreateOpWithoutAttrs<lmhlo::ComplexOp>(instr);
+ case HloOpcode::kConvert:
+ return CreateOpWithoutAttrs<lmhlo::ConvertOp>(instr);
case HloOpcode::kCopy:
return CreateOpWithoutAttrs<lmhlo::CopyOp>(instr);
case HloOpcode::kCos:
@@ -243,10 +260,20 @@
return CreateOpWithoutAttrs<lmhlo::DivOp>(instr);
case HloOpcode::kExp:
return CreateOpWithoutAttrs<lmhlo::ExpOp>(instr);
+ case HloOpcode::kExpm1:
+ return CreateOpWithoutAttrs<lmhlo::Expm1Op>(instr);
+ case HloOpcode::kFloor:
+ return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr);
case HloOpcode::kImag:
return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr);
+ case HloOpcode::kIsFinite:
+ return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr);
case HloOpcode::kLog:
return CreateOpWithoutAttrs<lmhlo::LogOp>(instr);
+ case HloOpcode::kLog1p:
+ return CreateOpWithoutAttrs<lmhlo::Log1pOp>(instr);
+ case HloOpcode::kMap:
+ return EmitMapOp(instr);
case HloOpcode::kMaximum:
return CreateOpWithoutAttrs<lmhlo::MaxOp>(instr);
case HloOpcode::kMinimum:
@@ -255,22 +282,44 @@
return CreateOpWithoutAttrs<lmhlo::MulOp>(instr);
case HloOpcode::kNegate:
return CreateOpWithoutAttrs<lmhlo::NegOp>(instr);
+ case HloOpcode::kNot:
+ return CreateOpWithoutAttrs<lmhlo::NotOp>(instr);
+ case HloOpcode::kOr:
+ return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
+ case HloOpcode::kPopulationCount:
+ return CreateOpWithoutAttrs<lmhlo::PopulationCountOp>(instr);
+ case HloOpcode::kPower:
+ return CreateOpWithoutAttrs<lmhlo::PowOp>(instr);
case HloOpcode::kReal:
return CreateOpWithoutAttrs<lmhlo::RealOp>(instr);
+ case HloOpcode::kReducePrecision:
+ return EmitReducePrecisionOp(instr);
case HloOpcode::kRemainder:
return CreateOpWithoutAttrs<lmhlo::RemOp>(instr);
+ case HloOpcode::kRoundNearestAfz:
+ return CreateOpWithoutAttrs<lmhlo::RoundOp>(instr);
case HloOpcode::kRsqrt:
return CreateOpWithoutAttrs<lmhlo::RsqrtOp>(instr);
case HloOpcode::kSelect:
return CreateOpWithoutAttrs<lmhlo::SelectOp>(instr);
+ case HloOpcode::kShiftLeft:
+ return CreateOpWithoutAttrs<lmhlo::ShiftLeftOp>(instr);
+ case HloOpcode::kShiftRightLogical:
+ return CreateOpWithoutAttrs<lmhlo::ShiftRightLogicalOp>(instr);
+ case HloOpcode::kShiftRightArithmetic:
+ return CreateOpWithoutAttrs<lmhlo::ShiftRightArithmeticOp>(instr);
case HloOpcode::kSign:
return CreateOpWithoutAttrs<lmhlo::SignOp>(instr);
+ case HloOpcode::kSin:
+ return CreateOpWithoutAttrs<lmhlo::SinOp>(instr);
case HloOpcode::kSqrt:
return CreateOpWithoutAttrs<lmhlo::SqrtOp>(instr);
case HloOpcode::kSubtract:
return CreateOpWithoutAttrs<lmhlo::SubOp>(instr);
case HloOpcode::kTanh:
return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr);
+ case HloOpcode::kXor:
+ return CreateOpWithoutAttrs<lmhlo::XorOp>(instr);
case HloOpcode::kSort:
return EmitSortOp(instr);
case HloOpcode::kFusion:
@@ -642,6 +691,92 @@
return reduce_op;
}
+// Emits an lmhlo.map for `instr`: creates the op without attributes, sets its
+// `dimensions` attribute from the HLO map instruction, and imports the first
+// called computation of `instr` as the op's `computation` region.
+// Returns the created op, or the first error encountered along the way.
+StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp(HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto lhlo_map, CreateOpWithoutAttrs<lmhlo::MapOp>(instr));
+  auto* hlo_map = ::xla::Cast<::xla::HloMapInstruction>(instr);
+
+  // Copy the mapped dimensions into an int64_t vector for the attribute helper.
+  const auto& hlo_dims = hlo_map->dimensions();
+  std::vector<int64_t> dims(hlo_dims.begin(), hlo_dims.end());
+  lhlo_map.dimensionsAttr(GetI64DenseElementsAttr(dims));
+
+  // The body of the map is the first (index 0) called computation.
+  auto* mapped_computation = instr->called_computations()[0];
+  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
+      *mapped_computation, &lhlo_map.computation(), &builder_));
+  return lhlo_map;
+}
+
+// Emits an lmhlo.compare for `instr`. The op is first created without
+// attributes and then annotated with the string forms of the comparison
+// direction and the comparison type taken from the HLO compare instruction.
+// Returns the created op, or the error produced while creating it.
+StatusOr<lmhlo::CompareOp> LhloDialectEmitter::EmitCompareOp(
+    HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto compare_op,
+                      CreateOpWithoutAttrs<lmhlo::CompareOp>(instr));
+
+  auto* compare = ::xla::Cast<::xla::HloCompareInstruction>(instr);
+
+  // Translate the XLA direction enum to its MHLO counterpart. The switch has
+  // no `default` label on purpose, so the compiler can flag enumerators that
+  // are added later. The trailing llvm_unreachable keeps control from flowing
+  // off the end of this value-returning lambda (undefined behavior) should an
+  // out-of-range value ever be seen.
+  auto direction = [&]() {
+    switch (compare->direction()) {
+      case xla::ComparisonDirection::kEq:
+        return mhlo::ComparisonDirection::EQ;
+      case xla::ComparisonDirection::kNe:
+        return mhlo::ComparisonDirection::NE;
+      case xla::ComparisonDirection::kGe:
+        return mhlo::ComparisonDirection::GE;
+      case xla::ComparisonDirection::kGt:
+        return mhlo::ComparisonDirection::GT;
+      case xla::ComparisonDirection::kLe:
+        return mhlo::ComparisonDirection::LE;
+      case xla::ComparisonDirection::kLt:
+        return mhlo::ComparisonDirection::LT;
+    }
+    llvm_unreachable("unknown xla::ComparisonDirection");
+  }();
+  compare_op.comparison_directionAttr(
+      builder_.getStringAttr(stringifyComparisonDirection(direction)));
+
+  // Same translation for the comparison type, with the same guard after the
+  // covered switch.
+  auto compare_type = [&]() {
+    switch (compare->type()) {
+      case xla::Comparison::Type::kFloat:
+        return mhlo::ComparisonType::FLOAT;
+      case xla::Comparison::Type::kFloatTotalOrder:
+        return mhlo::ComparisonType::TOTALORDER;
+      case xla::Comparison::Type::kSigned:
+        return mhlo::ComparisonType::SIGNED;
+      case xla::Comparison::Type::kUnsigned:
+        return mhlo::ComparisonType::UNSIGNED;
+    }
+    llvm_unreachable("unknown xla::Comparison::Type");
+  }();
+  compare_op.compare_typeAttr(
+      builder_.getStringAttr(stringifyComparisonType(compare_type)));
+  return compare_op;
+}
+
+// Emits an lmhlo.reduce_precision for `instr`, copying the exponent and
+// mantissa bit counts of the HLO instruction onto the op as i32 attributes.
+// Returns the created op, or the error produced while creating it.
+StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
+    HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto lhlo_op,
+                      CreateOpWithoutAttrs<lmhlo::ReducePrecisionOp>(instr));
+  auto* hlo_op = ::xla::Cast<::xla::HloReducePrecisionInstruction>(instr);
+
+  auto exponent_bits_attr =
+      builder_.getI32IntegerAttr(hlo_op->exponent_bits());
+  lhlo_op.exponent_bitsAttr(exponent_bits_attr);
+
+  auto mantissa_bits_attr =
+      builder_.getI32IntegerAttr(hlo_op->mantissa_bits());
+  lhlo_op.mantissa_bitsAttr(mantissa_bits_attr);
+
+  return lhlo_op;
+}
+
+// Emits an lmhlo.all_reduce for `instr`. After creating the op without
+// attributes it sets, in order: the replica groups, the constrain_layout
+// flag, the channel handle, and the use_global_device_ids flag. Finally the
+// first called computation of `instr` is imported as the op's `computation`
+// region. Returns the created op, or the first error encountered.
+StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
+    HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto lhlo_all_reduce,
+                      CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
+  auto* hlo_all_reduce = ::xla::Cast<::xla::HloAllReduceInstruction>(instr);
+
+  // Reuse the HLO importer's helper; it yields a (name, attribute) pair.
+  auto named_replica_groups =
+      ::xla::HloFunctionImporter::ConvertReplicaGroups(
+          hlo_all_reduce->replica_groups(), builder_);
+  lhlo_all_reduce.setAttr(named_replica_groups.first,
+                          named_replica_groups.second);
+
+  auto constrain_layout_attr =
+      builder_.getBoolAttr(hlo_all_reduce->constrain_layout());
+  lhlo_all_reduce.constrain_layoutAttr(constrain_layout_attr);
+
+  // An absent channel id is emitted as handle 0; the handle type is always 0.
+  auto handle_attr =
+      builder_.getI64IntegerAttr(hlo_all_reduce->channel_id().value_or(0));
+  auto handle_type_attr = builder_.getI64IntegerAttr(0);
+  lhlo_all_reduce.channel_idAttr(mlir::mhlo::ChannelHandle::get(
+      handle_attr, handle_type_attr, builder_.getContext()));
+
+  auto global_ids_attr =
+      builder_.getBoolAttr(hlo_all_reduce->use_global_device_ids());
+  lhlo_all_reduce.use_global_device_idsAttr(global_ids_attr);
+
+  // The reduction body is the first (index 0) called computation.
+  auto* reduction = instr->called_computations()[0];
+  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
+      *reduction, &lhlo_all_reduce.computation(), &builder_));
+  return lhlo_all_reduce;
+}
+
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
const ::xla::ShapeIndex& shape_index) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index 8214451..6c7bdd8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -72,6 +72,16 @@
::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
const ::xla::HloInstruction* instr);
+ // Emits an lmhlo.compare for `instr` and sets its comparison_direction and
+ // compare_type string attributes from the HLO compare instruction.
+ ::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr);
+
+ // Emits an lmhlo.map for `instr`, sets its dimensions attribute, and imports
+ // the first called computation of `instr` as the op's region.
+ ::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
+
+ // Emits an lmhlo.reduce_precision for `instr` and sets its exponent_bits and
+ // mantissa_bits attributes (i32) from the HLO instruction.
+ ::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
+ ::xla::HloInstruction* instr);
+
+ // Emits an lmhlo.all_reduce for `instr`, sets its replica_groups,
+ // constrain_layout, channel_id and use_global_device_ids attributes, and
+ // imports the first called computation of `instr` as the op's region.
+ ::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
+ ::xla::HloInstruction* instr);
+
::xla::Status CreateOperands(::xla::HloInstruction* instr,
SmallVectorImpl<Value>& operands,
size_t& num_arguments, size_t& num_results);