Support complex types when converting HLO multiply op.
We can lower it to the MulOp in the complex dialect.
PiperOrigin-RevId: 375675079
Change-Id: I045fefe0d3f500d9378ab417e579b56c7129c7c1
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
index a4169eb..0a3c240 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
@@ -65,6 +65,7 @@
using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp;
using UOp = ::mlir::MulIOp;
+ using COp = ::mlir::complex::MulOp;
};
template <>
struct LhloToScalarOp<lmhlo::RemOp> {
@@ -632,6 +633,19 @@
}
template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::MulOp>(Location loc,
+ ArrayRef<Type> result_types,
+ ArrayRef<Type> arg_types,
+ ArrayRef<Value> args,
+ OpBuilder* b) {
+ return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::MulOp>,
+ isUnsignedIntegerType, ScalarUOp<lmhlo::MulOp>,
+ isFloatType, ScalarFOp<lmhlo::MulOp>,
+ isComplexType, ScalarCOp<lmhlo::MulOp>>{}(
+ loc, result_types, arg_types, args, b);
+}
+
+template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Type> arg_types,
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
index 79543e6..08c158e 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
@@ -65,6 +65,19 @@
// -----
+// CHECK-LABEL: func @complex_mul
+func @complex_mul(%lhs: tensor<2x2xcomplex<f32>>,
+ %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
+ // CHECK: linalg.generic
+ // CHECK: complex.mul
+ %0 = "mhlo.multiply"(%lhs, %rhs)
+ : (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>)
+ -> tensor<2x2xcomplex<f32>>
+ return %0 : tensor<2x2xcomplex<f32>>
+}
+
+// -----
+
// CHECK-LABEL: func @float_remainder
func @float_remainder(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
index dc96182..d980782 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
@@ -242,6 +242,20 @@
// -----
+// CHECK-LABEL: func @complex_multiply
+func @complex_multiply(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
+ %result: memref<2xcomplex<f64>>) {
+ "lmhlo.multiply"(%lhs, %rhs, %result)
+ : (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xcomplex<f64>>) -> ()
+ return
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: complex<f64>):
+// CHECK-NEXT: %[[RESULT:.*]] = complex.mul %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
+// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f64>
+
+// -----
+
// CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {