Support complex types when converting HLO multiply op.

We can lower it to the MulOp in the complex dialect.

PiperOrigin-RevId: 375675079
Change-Id: I045fefe0d3f500d9378ab417e579b56c7129c7c1
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
index a4169eb..0a3c240 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
@@ -65,6 +65,7 @@
   using FOp = ::mlir::MulFOp;
   using IOp = ::mlir::MulIOp;
   using UOp = ::mlir::MulIOp;
+  using COp = ::mlir::complex::MulOp;
 };
 template <>
 struct LhloToScalarOp<lmhlo::RemOp> {
@@ -632,6 +633,19 @@
 }
 
 template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::MulOp>(Location loc,
+                                                  ArrayRef<Type> result_types,
+                                                  ArrayRef<Type> arg_types,
+                                                  ArrayRef<Value> args,
+                                                  OpBuilder* b) {
+  return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::MulOp>,
+                                 isUnsignedIntegerType, ScalarUOp<lmhlo::MulOp>,
+                                 isFloatType, ScalarFOp<lmhlo::MulOp>,
+                                 isComplexType, ScalarCOp<lmhlo::MulOp>>{}(
+      loc, result_types, arg_types, args, b);
+}
+
+template <>
 inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
                                                     ArrayRef<Type> result_types,
                                                     ArrayRef<Type> arg_types,
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
index 79543e6..08c158e 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
@@ -65,6 +65,19 @@
 
 // -----
 
+// CHECK-LABEL: func @complex_mul
+func @complex_mul(%lhs: tensor<2x2xcomplex<f32>>,
+                  %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
+  // CHECK: linalg.generic
+  // CHECK: complex.mul
+  %0 = "mhlo.multiply"(%lhs, %rhs)
+          : (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>)
+          -> tensor<2x2xcomplex<f32>>
+  return %0 : tensor<2x2xcomplex<f32>>
+}
+
+// -----
+
 // CHECK-LABEL: func @float_remainder
 func @float_remainder(%lhs: tensor<2x2xf32>,
                       %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
index dc96182..d980782 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
@@ -242,6 +242,20 @@
 
 // -----
 
+// CHECK-LABEL: func @complex_multiply
+func @complex_multiply(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
+                       %result: memref<2xcomplex<f64>>) {
+  "lmhlo.multiply"(%lhs, %rhs, %result)
+      : (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xcomplex<f64>>) -> ()
+  return
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: complex<f64>):
+// CHECK-NEXT:   %[[RESULT:.*]] = complex.mul %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : complex<f64>
+
+// -----
+
 // CHECK-LABEL: func @select
 func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
              %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {