Let ConstantLike support complex constants.

PiperOrigin-RevId: 466643683
diff --git a/tensorflow/compiler/xla/mlir_hlo/BUILD b/tensorflow/compiler/xla/mlir_hlo/BUILD
index d9d37f6..fb2fc30 100644
--- a/tensorflow/compiler/xla/mlir_hlo/BUILD
+++ b/tensorflow/compiler/xla/mlir_hlo/BUILD
@@ -481,6 +481,7 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithmeticDialect",
+        "@llvm-project//mlir:ComplexDialect",
         "@llvm-project//mlir:ControlFlowInterfaces",
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:FuncDialect",
@@ -1480,6 +1481,7 @@
         ":map_chlo_to_hlo_op",
         ":mlir_hlo",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ComplexDialect",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:SCFDialect",
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
index e687ff7..5110df7 100644
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
@@ -19,6 +19,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/utils/hlo_utils.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -72,22 +73,11 @@
 static Value getConstantLike(OpBuilder& b, Location loc, T constant,
                              Value val) {
   Type ty = getElementTypeOrSelf(val.getType());
-  if (auto complexTy = ty.dyn_cast<ComplexType>()) {
-    // TODO(b/190374484): This code will only work for static shapes.
-    // The proper way to support these constants is through chlo.constant_like
-    // which then legalizes to code which works well for both static and dynamic
-    // shapes of val.
-    // The problem with that approach for complex numbers is that constant_like
-    // doesn't work for complex numbers - it carries constants via attributes,
-    // and there's no built-in attribute that carries complex numbers.
-    return b.create<mhlo::ConstantOp>(
-        loc,
-        hlo::getSplat(&b, val.getType().cast<RankedTensorType>(), constant));
-  }
-
   auto getAttr = [&]() -> Attribute {
     if (ty.isa<IntegerType>()) return b.getIntegerAttr(ty, constant);
     if (ty.isa<FloatType>()) return b.getFloatAttr(ty, constant);
+    if (auto complexTy = ty.dyn_cast<ComplexType>())
+      return complex::NumberAttr::get(complexTy, constant, 0);
     llvm_unreachable("unhandled element type");
   };
   return b.create<ConstantLikeOp>(loc, getAttr(), val);
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt
index 1e2aaf2..c0056a9 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt
@@ -33,6 +33,7 @@
 
   LINK_LIBS PUBLIC
   MhloDialect
+  MLIRComplexDialect
   MLIRIR
   MLIRMhloUtils
 )
@@ -54,6 +55,7 @@
 )
 target_link_libraries(MhloDialect
   PUBLIC
+  MLIRComplexDialect
   MLIRIR
   MLIRMhloUtils
   MLIRQuantDialect
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
index 79e6d8f..9b9551d 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
@@ -19,6 +19,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir-hlo/utils/broadcast_utils.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -398,6 +399,8 @@
   auto opType = operand().getType().cast<ShapedType>();
   if (!opType.hasStaticShape()) return {};
   auto type = RankedTensorType::get(opType.getShape(), value().getType());
+  if (auto complexAttr = value().dyn_cast<complex::NumberAttr>())
+    return DenseElementsAttr::get(type, complexAttr.getValue());
   return DenseElementsAttr::get(type, value());
 }
 
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index 23e9b06..2946d23 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -48,6 +48,7 @@
 #include "mlir-hlo/utils/convert_op_folder.h"
 #include "mlir-hlo/utils/hlo_utils.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -717,14 +718,18 @@
   Type type;
   if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
     type = elemAttr.getType();
-  } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
-             value.isa<IntegerAttr>()) {
+  } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
     // All XLA types must be tensor types. In the build() method, we want to
     // provide more flexibility by allowing attributes of scalar types. But we
     // need to wrap it up with ElementsAttr to construct valid XLA constants.
     type =
         RankedTensorType::get(/*shape=*/{}, value.cast<TypedAttr>().getType());
     value = DenseElementsAttr::get(type.cast<TensorType>(), value);
+  } else if (auto complexAttr = value.dyn_cast<complex::NumberAttr>()) {
+    type = RankedTensorType::get(/*shape=*/{},
+                                 complexAttr.cast<TypedAttr>().getType());
+    value =
+        DenseElementsAttr::get(type.cast<TensorType>(), complexAttr.getValue());
   }
 
   // TODO: support other XLA specific types.
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
index 6d64326..7d4555c 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
@@ -217,6 +217,7 @@
   LINK_LIBS PUBLIC
   ChloDialect
   HloToLinalgUtils
+  MLIRComplexDialect
   MLIRIR
   MLIRPass
   MLIRRewrite
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index f8818bd..3b00d26 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -30,6 +30,7 @@
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 #include "mlir-hlo/utils/broadcast_utils.h"
 #include "mlir-hlo/utils/hlo_utils.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
@@ -58,8 +59,11 @@
 
     // Lower to MHLO constant if statically shaped.
     if (resultTy.hasStaticShape()) {
-      rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(
-          op, DenseElementsAttr::get(resultTy, op.value()));
+      auto complexAttr = op.value().dyn_cast<complex::NumberAttr>();
+      auto attr = complexAttr
+                      ? DenseElementsAttr::get(resultTy, complexAttr.getValue())
+                      : DenseElementsAttr::get(resultTy, op.value());
+      rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, attr);
       return success();
     }
 
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/CMakeLists.txt
index aaa15fa..e9d72a2 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/CMakeLists.txt
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/CMakeLists.txt
@@ -35,6 +35,7 @@
   Core
 
   LINK_LIBS PUBLIC
+  ChloDialect
   MLIRGPUOps
   MLIRHLOAnalysis
   MLIRIR
diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir
index a49be53..397ca38 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir
+++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir
@@ -95,6 +95,31 @@
 
 // -----
 
+// CHECK-LABEL:  func.func @asin_complex_f64_dynamic(
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>>
+// CHECK:  %[[TWO:.*]] = mhlo.constant dense<(2.000000e+00,0.000000e+00)>
+// CHECK:  %[[SHAPE:.*]] = shape.shape_of %[[ARG0]]
+// CHECK:  %[[TWO_BROADCASTED:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TWO]], %[[SHAPE]])
+// CHECK:  %[[ONE:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)>
+// CHECK:  %[[SHAPE2:.*]] = shape.shape_of %[[ARG0]]
+// CHECK:  %[[ONE_BROADCASTED:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ONE]], %[[SHAPE2]])
+// CHECK:  %[[ONE2:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)>
+// CHECK:  %[[SHAPE3:.*]] = shape.shape_of %[[ARG0]]
+// CHECK:  %[[ONE_BROADCASTED2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ONE2]], %[[SHAPE3]])
+// CHECK:  %[[SQUARE:.*]] = mhlo.multiply %[[ARG0]], %[[ARG0]]
+// CHECK:  %[[SUB:.*]] = mhlo.subtract %[[ONE_BROADCASTED2]], %[[SQUARE]]
+// CHECK:  %[[SQRT:.*]] = mhlo.sqrt %[[SUB]]
+// CHECK:  %[[ADD:.*]] = mhlo.add %[[ONE_BROADCASTED]], %[[SQRT]]
+// CHECK:  %[[ATAN2:.*]] = mhlo.atan2 %[[ARG0]], %[[ADD]]
+// CHECK:  %[[MUL:.*]] = mhlo.multiply %[[TWO_BROADCASTED]], %[[ATAN2]]
+// CHECK:  return %[[MUL]]
+func.func @asin_complex_f64_dynamic(%arg : tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>> {
+  %result = "chlo.asin"(%arg) : (tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>>
+  func.return %result : tensor<?xcomplex<f64>>
+}
+
+// -----
+
 // CHECK-LABEL: @asinh_bf16
 // CHECK-SAME: %[[ARG:.*]]: tensor<bf16>
 func.func @asinh_bf16(%arg : tensor<bf16>) -> tensor<bf16> {