[MLIR] Add XLA HLO -> LMHLO conversion for all elementwise ops.
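
Besides the plain elementwise ops, this change adds emitters for compare,
map, reduce-precision and all-reduce, which carry extra attributes or
regions, and turns HloFunctionImporter::ConvertReplicaGroups into a static
helper that takes a builder so the LHLO emitter can reuse it.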

PiperOrigin-RevId: 345557248
Change-Id: I5832bb00cb735489f6115c19007b68a49b434a0a
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index 47b100a..3f302f1 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -85,6 +85,8 @@
 def LHLO_BitcastConvertOp:
     LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;
 
+def LHLO_CbrtOp: LHLO_UnaryElementwiseOp<"cbrt", LHLO_FpBuffer>, BASE_HLO_CbrtOp;
+
 def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;
 
 def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index d122790..00f2f4b 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -139,6 +139,7 @@
         ":mlir_hlo_to_hlo",
         ":translate_cl_options",
         "//tensorflow/compiler/mlir/hlo",
+        "//tensorflow/compiler/mlir/hlo:hlo_ops_base_enums",
         "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
         "//tensorflow/compiler/xla:debug_options_flags",
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 762fb8a..23fab36 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -537,7 +537,8 @@
     }
     case HloOpcode::kAllReduce: {
       auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
-      attributes.push_back(ConvertReplicaGroups(all_reduce->replica_groups()));
+      attributes.push_back(
+          ConvertReplicaGroups(all_reduce->replica_groups(), *builder_));
       attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
       auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
           loc, result_type, operands, attributes);
@@ -932,7 +933,7 @@
 }
 
 mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
-    const std::vector<ReplicaGroup>& replica_groups) {
+    const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder) {
   int64_t num_groups = replica_groups.size();
   int64_t group_size =
       num_groups == 0 ? 0 : replica_groups[0].replica_ids_size();
@@ -944,9 +945,9 @@
       attr[flat_index++] = group.replica_ids(i);
   }
   auto type = mlir::RankedTensorType::get({num_groups, group_size},
-                                          builder_->getIntegerType(64));
-  return builder_->getNamedAttr("replica_groups",
-                                DenseIntElementsAttr::get(type, attr));
+                                          builder.getIntegerType(64));
+  return builder.getNamedAttr("replica_groups",
+                              DenseIntElementsAttr::get(type, attr));
 }
 
 mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
index a0f6e6c..d849b83 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
@@ -64,6 +64,12 @@
 
   static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape);
 
+  // Converts replica groups to an attribute.
+  //
+  // TODO(timshen): move this to attribute_importer.h.
+  static mlir::NamedAttribute ConvertReplicaGroups(
+      const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder);
+
  private:
   HloFunctionImporter(mlir::ModuleOp module,
                       std::unordered_map<const xla::HloComputation*,
@@ -136,10 +142,6 @@
   // padding low and padding high for each of the spatial dimensions.
   mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
 
-  // Converts replica groups to attribute
-  mlir::NamedAttribute ConvertReplicaGroups(
-      const std::vector<ReplicaGroup>& replica_groups);
-
   // Converts channel id to attribute
   mlir::NamedAttribute ConvertChannelHandle(
       absl::optional<tensorflow::int64> channel_id);
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index ce42ccf..0884230 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -169,3 +169,52 @@
   backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}"
 }
 
+// -----
+
+HloModule GemmBias
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo_gpu.gemm_bias"
+// CHECK-SAME: algorithm = 0 : i64
+// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
+// CHECK-SAME: alpha_real = 1.000000e+00 : f64
+// CHECK-SAME: batch_size = 1 : i64
+// CHECK-SAME: beta = 1.000000e+00 : f64
+// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>
+// CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64>
+// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>
+// CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64>
+// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>)
+ENTRY main {
+  %A = f32[1,1]{1,0} parameter(0)
+  %B = f32[1,4]{1,0} parameter(1)
+  %C = f32[1,4]{1,0} parameter(2)
+  ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{1,0} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C),
+                                  custom_call_target="__cublas$gemm",
+  backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}"
+}
+
+// -----
+
+HloModule AllReduce
+
+// Test all-reduce
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+// CHECK-LABEL: func @test_all_reduce
+// CHECK-SAME:  ([[INPUT:%.*]]: memref<8xf32>
+%test_all_reduce {
+  input = f32[8] parameter(0)
+  // CHECK:  "lmhlo.all_reduce"([[INPUT]], {{.*}})
+  // CHECK:  ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
+  // CHECK:    [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
+  // CHECK:    "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
+  // CHECK:  }) {
+  // CHECK-SAME:  channel_id = {handle = 1 : i64, type = 0 : i64}
+  // CHECK-SAME:  replica_groups = dense<{{\[\[}}0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
+  ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, to_apply=add
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
index cd72707..e39a654 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir
@@ -45,6 +45,34 @@
 
 // CHECK-LABEL: func @main
 // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.atan2
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "mhlo.atan2"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value: tensor<2x2xf32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.bitcast_convert
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.bitcast_convert"(%value) : (tensor<2x2xf32>) -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
 func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
 // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
@@ -57,6 +85,63 @@
 // -----
 
 // CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.cbrt
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.cbrt"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 2
+// CHECK-SAME: %[[ARG3:.*]]: memref<16xi8>
+func @main(%min: tensor<2x2xf32>, %operand: tensor<2x2xf32>, %max: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.clamp
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]]
+// CHECK-NEXT: return
+  %0 = "mhlo.clamp"(%pred, %lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.count_leading_zeros
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.count_leading_zeros"(%value) : (tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
+func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xi1> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
+// CHECK: lmhlo.compare
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "mhlo.compare"(%value0, %value1) {comparison_direction="GT"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
+  return %res : tensor<2x2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
 // CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
 // CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
@@ -72,6 +157,19 @@
 // -----
 
 // CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
+func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf16> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<2x2xf16>
+// CHECK: lmhlo.convert
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.convert"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf16>
+  return %res : tensor<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
 // CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
 func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>> {
@@ -118,6 +216,45 @@
 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
 func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
 // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.exponential_minus_one
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.exponential_minus_one"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.floor
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.floor"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<4xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xi1> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
+// CHECK: lmhlo.is_finite
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.is_finite"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xi1>
+  return %res : tensor<2x2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
 // CHECK: lmhlo.log
 // CHECK-SAME: %[[ARG0]], %[[VIEW]]
   %res = "mhlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
@@ -128,6 +265,39 @@
 
 // CHECK-LABEL: func @main
 // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.log_plus_one
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.log_plus_one"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.map
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK: return
+  %res = "mhlo.map"(%value0, %value1) ({
+  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
+    %c = "mhlo.add"(%a, %b) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %ret = "mhlo.add"(%a, %c) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    "mhlo.return"(%ret) : (tensor<f32>) -> ()
+  }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
 // CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
 func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
@@ -185,6 +355,90 @@
 // -----
 
 // CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<4xi8>
+func @main(%value0: tensor<2x2xi1>) -> tensor<2x2xi1> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
+// CHECK: lmhlo.not
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.not"(%value0) : (tensor<2x2xi1>) -> tensor<2x2xi1>
+  return %res : tensor<2x2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.not
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.not"(%value0) : (tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
+func @main(%value0: tensor<2x2xi1>, %value1: tensor<2x2xi1>) -> tensor<2x2xi1> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
+// CHECK: lmhlo.or
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "mhlo.or"(%value0, %value1) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
+  return %res : tensor<2x2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.or
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "mhlo.or"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.popcnt
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.popcnt"(%value0) : (tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.power
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "mhlo.power"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
 // CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
 // CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
 func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
@@ -211,6 +465,19 @@
 // -----
 
 // CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.reduce_precision
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.reduce_precision"(%value0) {exponent_bits=5 : i32, mantissa_bits=12 : i32}: (tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
 // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
 // CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
@@ -230,6 +497,19 @@
 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
 func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
 // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.round_nearest_afz
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.round_nearest_afz"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
 // CHECK: lmhlo.rsqrt
 // CHECK-SAME: %[[ARG0]], %[[VIEW]]
   %res = "mhlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
@@ -255,6 +535,51 @@
 // -----
 
 // CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.shift_left
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "mhlo.shift_left"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.shift_right_arithmetic
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "mhlo.shift_right_arithmetic"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.shift_right_logical
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "mhlo.shift_right_logical"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
 // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
 func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
@@ -272,6 +597,19 @@
 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
 func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
 // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lmhlo.sine
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "mhlo.sine"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
 // CHECK: lmhlo.sqrt
 // CHECK-SAME: %[[ARG0]], %[[VIEW]]
   %res = "mhlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
@@ -309,6 +647,36 @@
 // -----
 
 // CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
+func @main(%value0: tensor<2x2xi1>, %value1: tensor<2x2xi1>) -> tensor<2x2xi1> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
+// CHECK: lmhlo.xor
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "mhlo.xor"(%value0, %value1) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
+  return %res : tensor<2x2xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lmhlo.xor
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "mhlo.xor"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
 // CHECK-SAME: %[[ARG0:.*]]: memref<5x5xi32>
 // CHECK-SAME: %[[ARG1:.*]]: memref<5x5xf32>
 // CHECK-SAME: %[[ARG2:.*]]: memref<100xi8> {lmhlo.alloc = 0
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index bf11d52..f4ad6f0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -39,6 +39,7 @@
 #include "mlir/Pass/PassOptions.h"  // from @llvm-project
 #include "mlir/Translation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
@@ -229,12 +230,28 @@
       return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr);
     case HloOpcode::kAdd:
       return CreateOpWithoutAttrs<lmhlo::AddOp>(instr);
+    case HloOpcode::kAllReduce:
+      return EmitAllReduceOp(instr);
     case HloOpcode::kAnd:
       return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
+    case HloOpcode::kAtan2:
+      return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr);
+    case HloOpcode::kBitcastConvert:
+      return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr);
     case HloOpcode::kCeil:
       return CreateOpWithoutAttrs<lmhlo::CeilOp>(instr);
+    case HloOpcode::kCbrt:
+      return CreateOpWithoutAttrs<lmhlo::CbrtOp>(instr);
+    case HloOpcode::kClamp:
+      return CreateOpWithoutAttrs<lmhlo::ClampOp>(instr);
+    case HloOpcode::kClz:
+      return CreateOpWithoutAttrs<lmhlo::ClzOp>(instr);
+    case HloOpcode::kCompare:
+      return EmitCompareOp(instr);
     case HloOpcode::kComplex:
       return CreateOpWithoutAttrs<lmhlo::ComplexOp>(instr);
+    case HloOpcode::kConvert:
+      return CreateOpWithoutAttrs<lmhlo::ConvertOp>(instr);
     case HloOpcode::kCopy:
       return CreateOpWithoutAttrs<lmhlo::CopyOp>(instr);
     case HloOpcode::kCos:
@@ -243,10 +260,20 @@
       return CreateOpWithoutAttrs<lmhlo::DivOp>(instr);
     case HloOpcode::kExp:
       return CreateOpWithoutAttrs<lmhlo::ExpOp>(instr);
+    case HloOpcode::kExpm1:
+      return CreateOpWithoutAttrs<lmhlo::Expm1Op>(instr);
+    case HloOpcode::kFloor:
+      return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr);
     case HloOpcode::kImag:
       return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr);
+    case HloOpcode::kIsFinite:
+      return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr);
     case HloOpcode::kLog:
       return CreateOpWithoutAttrs<lmhlo::LogOp>(instr);
+    case HloOpcode::kLog1p:
+      return CreateOpWithoutAttrs<lmhlo::Log1pOp>(instr);
+    case HloOpcode::kMap:
+      return EmitMapOp(instr);
     case HloOpcode::kMaximum:
       return CreateOpWithoutAttrs<lmhlo::MaxOp>(instr);
     case HloOpcode::kMinimum:
@@ -255,22 +282,44 @@
       return CreateOpWithoutAttrs<lmhlo::MulOp>(instr);
     case HloOpcode::kNegate:
       return CreateOpWithoutAttrs<lmhlo::NegOp>(instr);
+    case HloOpcode::kNot:
+      return CreateOpWithoutAttrs<lmhlo::NotOp>(instr);
+    case HloOpcode::kOr:
+      return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
+    case HloOpcode::kPopulationCount:
+      return CreateOpWithoutAttrs<lmhlo::PopulationCountOp>(instr);
+    case HloOpcode::kPower:
+      return CreateOpWithoutAttrs<lmhlo::PowOp>(instr);
     case HloOpcode::kReal:
       return CreateOpWithoutAttrs<lmhlo::RealOp>(instr);
+    case HloOpcode::kReducePrecision:
+      return EmitReducePrecisionOp(instr);
     case HloOpcode::kRemainder:
       return CreateOpWithoutAttrs<lmhlo::RemOp>(instr);
+    case HloOpcode::kRoundNearestAfz:
+      return CreateOpWithoutAttrs<lmhlo::RoundOp>(instr);
     case HloOpcode::kRsqrt:
       return CreateOpWithoutAttrs<lmhlo::RsqrtOp>(instr);
     case HloOpcode::kSelect:
       return CreateOpWithoutAttrs<lmhlo::SelectOp>(instr);
+    case HloOpcode::kShiftLeft:
+      return CreateOpWithoutAttrs<lmhlo::ShiftLeftOp>(instr);
+    case HloOpcode::kShiftRightLogical:
+      return CreateOpWithoutAttrs<lmhlo::ShiftRightLogicalOp>(instr);
+    case HloOpcode::kShiftRightArithmetic:
+      return CreateOpWithoutAttrs<lmhlo::ShiftRightArithmeticOp>(instr);
     case HloOpcode::kSign:
       return CreateOpWithoutAttrs<lmhlo::SignOp>(instr);
+    case HloOpcode::kSin:
+      return CreateOpWithoutAttrs<lmhlo::SinOp>(instr);
     case HloOpcode::kSqrt:
       return CreateOpWithoutAttrs<lmhlo::SqrtOp>(instr);
     case HloOpcode::kSubtract:
       return CreateOpWithoutAttrs<lmhlo::SubOp>(instr);
     case HloOpcode::kTanh:
       return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr);
+    case HloOpcode::kXor:
+      return CreateOpWithoutAttrs<lmhlo::XorOp>(instr);
     case HloOpcode::kSort:
       return EmitSortOp(instr);
     case HloOpcode::kFusion:
@@ -642,6 +691,102 @@
   return reduce_op;
 }
 
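+// Lowers an HLO map instruction to lmhlo.map, copying the map dimensions
+// into the `dimensions` attribute and importing the mapped computation as
+// the op's region.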
+StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp(HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto map_op, CreateOpWithoutAttrs<lmhlo::MapOp>(instr));
+  auto* map = ::xla::Cast<::xla::HloMapInstruction>(instr);
+  std::vector<int64_t> dimensions(map->dimensions().begin(),
+                                  map->dimensions().end());
+  map_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions));
+  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
+      *instr->called_computations()[0], &map_op.computation(), &builder_));
+  return map_op;
+}
+
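+// Lowers an HLO compare instruction to lmhlo.compare, translating the XLA
+// comparison direction and comparison type into MHLO string attributes.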
+StatusOr<lmhlo::CompareOp> LhloDialectEmitter::EmitCompareOp(
+    HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto compare_op,
+                      CreateOpWithoutAttrs<lmhlo::CompareOp>(instr));
+
+  auto* compare = ::xla::Cast<::xla::HloCompareInstruction>(instr);
+  auto direction = [&]() {
+    switch (compare->direction()) {
+      case xla::ComparisonDirection::kEq:
+        return mhlo::ComparisonDirection::EQ;
+      case xla::ComparisonDirection::kNe:
+        return mhlo::ComparisonDirection::NE;
+      case xla::ComparisonDirection::kGe:
+        return mhlo::ComparisonDirection::GE;
+      case xla::ComparisonDirection::kGt:
+        return mhlo::ComparisonDirection::GT;
+      case xla::ComparisonDirection::kLe:
+        return mhlo::ComparisonDirection::LE;
+      case xla::ComparisonDirection::kLt:
+        return mhlo::ComparisonDirection::LT;
+    }
+  }();
+  compare_op.comparison_directionAttr(
+      builder_.getStringAttr(stringifyComparisonDirection(direction)));
+  auto compare_type = [&]() {
+    switch (compare->type()) {
+      case xla::Comparison::Type::kFloat:
+        return mhlo::ComparisonType::FLOAT;
+      case xla::Comparison::Type::kFloatTotalOrder:
+        return mhlo::ComparisonType::TOTALORDER;
+      case xla::Comparison::Type::kSigned:
+        return mhlo::ComparisonType::SIGNED;
+      case xla::Comparison::Type::kUnsigned:
+        return mhlo::ComparisonType::UNSIGNED;
+    }
+  }();
+  compare_op.compare_typeAttr(
+      builder_.getStringAttr(stringifyComparisonType(compare_type)));
+  return compare_op;
+}
+
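+// Lowers an HLO reduce-precision instruction to lmhlo.reduce_precision,
+// carrying over the exponent and mantissa bit counts as i32 attributes.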
+StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
+    HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto reduce_precision_op,
+                      CreateOpWithoutAttrs<lmhlo::ReducePrecisionOp>(instr));
+  auto* reduce_precision =
+      ::xla::Cast<::xla::HloReducePrecisionInstruction>(instr);
+  reduce_precision_op.exponent_bitsAttr(
+      builder_.getI32IntegerAttr(reduce_precision->exponent_bits()));
+  reduce_precision_op.mantissa_bitsAttr(
+      builder_.getI32IntegerAttr(reduce_precision->mantissa_bits()));
+  return reduce_precision_op;
+}
+
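+// Lowers an HLO all-reduce instruction to lmhlo.all_reduce, converting
+// replica groups, layout constraint, channel id and use_global_device_ids
+// to attributes and importing the reduction computation as the op's region.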
+StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
+    HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto all_reduce_op,
+                      CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
+  auto* all_reduce = ::xla::Cast<::xla::HloAllReduceInstruction>(instr);
+  auto replica_groups_attr = ::xla::HloFunctionImporter::ConvertReplicaGroups(
+      all_reduce->replica_groups(), builder_);
+  all_reduce_op.setAttr(replica_groups_attr.first, replica_groups_attr.second);
+  all_reduce_op.constrain_layoutAttr(
+      builder_.getBoolAttr(all_reduce->constrain_layout()));
+  all_reduce_op.channel_idAttr(mlir::mhlo::ChannelHandle::get(
+      builder_.getI64IntegerAttr(all_reduce->channel_id().value_or(0)),
+      builder_.getI64IntegerAttr(0), builder_.getContext()));
+  all_reduce_op.use_global_device_idsAttr(
+      builder_.getBoolAttr(all_reduce->use_global_device_ids()));
+  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
+      *instr->called_computations()[0], &all_reduce_op.computation(),
+      &builder_));
+  return all_reduce_op;
+}
+
 StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
     const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
     const ::xla::ShapeIndex& shape_index) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index 8214451..6c7bdd8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -72,6 +72,18 @@
   ::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
       const ::xla::HloInstruction* instr);
 
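+  // Emitters for ops that carry attributes or regions and therefore need
+  // more than the generic CreateOpWithoutAttrs lowering.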
+  ::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr);
+
+  ::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
+
+  ::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
+      ::xla::HloInstruction* instr);
+
+  ::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
+      ::xla::HloInstruction* instr);
+
   ::xla::Status CreateOperands(::xla::HloInstruction* instr,
                                SmallVectorImpl<Value>& operands,
                                size_t& num_arguments, size_t& num_results);