Let ConstantLike support complex constants.
PiperOrigin-RevId: 466643683
diff --git a/tensorflow/compiler/xla/mlir_hlo/BUILD b/tensorflow/compiler/xla/mlir_hlo/BUILD
index d9d37f6..fb2fc30 100644
--- a/tensorflow/compiler/xla/mlir_hlo/BUILD
+++ b/tensorflow/compiler/xla/mlir_hlo/BUILD
@@ -481,6 +481,7 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:FuncDialect",
@@ -1480,6 +1481,7 @@
":map_chlo_to_hlo_op",
":mlir_hlo",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
index e687ff7..5110df7 100644
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
@@ -19,6 +19,7 @@
#include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/utils/hlo_utils.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
@@ -72,22 +73,11 @@
static Value getConstantLike(OpBuilder& b, Location loc, T constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
- if (auto complexTy = ty.dyn_cast<ComplexType>()) {
- // TODO(b/190374484): This code will only work for static shapes.
- // The proper way to support these constants is through chlo.constant_like
- // which then legalizes to code which works well for both static and dynamic
- // shapes of val.
- // The problem with that approach for complex numbers is that constant_like
- // doesn't work for complex numbers - it carries constants via attributes,
- // and there's no built-in attribute that carries complex numbers.
- return b.create<mhlo::ConstantOp>(
- loc,
- hlo::getSplat(&b, val.getType().cast<RankedTensorType>(), constant));
- }
-
auto getAttr = [&]() -> Attribute {
if (ty.isa<IntegerType>()) return b.getIntegerAttr(ty, constant);
if (ty.isa<FloatType>()) return b.getFloatAttr(ty, constant);
+ if (auto complexTy = ty.dyn_cast<ComplexType>())
+ return complex::NumberAttr::get(complexTy, constant, 0);
llvm_unreachable("unhandled element type");
};
return b.create<ConstantLikeOp>(loc, getAttr(), val);
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt
index 1e2aaf2..c0056a9 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt
@@ -33,6 +33,7 @@
LINK_LIBS PUBLIC
MhloDialect
+ MLIRComplexDialect
MLIRIR
MLIRMhloUtils
)
@@ -54,6 +55,7 @@
)
target_link_libraries(MhloDialect
PUBLIC
+ MLIRComplexDialect
MLIRIR
MLIRMhloUtils
MLIRQuantDialect
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
index 79e6d8f..9b9551d 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
@@ -19,6 +19,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/utils/broadcast_utils.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -398,6 +399,8 @@
auto opType = operand().getType().cast<ShapedType>();
if (!opType.hasStaticShape()) return {};
auto type = RankedTensorType::get(opType.getShape(), value().getType());
+ if (auto complexAttr = value().dyn_cast<complex::NumberAttr>())
+ return DenseElementsAttr::get(type, complexAttr.getValue());
return DenseElementsAttr::get(type, value());
}
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index 23e9b06..2946d23 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -48,6 +48,7 @@
#include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -717,14 +718,18 @@
Type type;
if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
type = elemAttr.getType();
- } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
- value.isa<IntegerAttr>()) {
+ } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
// All XLA types must be tensor types. In the build() method, we want to
// provide more flexibility by allowing attributes of scalar types. But we
// need to wrap it up with ElementsAttr to construct valid XLA constants.
type =
RankedTensorType::get(/*shape=*/{}, value.cast<TypedAttr>().getType());
value = DenseElementsAttr::get(type.cast<TensorType>(), value);
+ } else if (auto complexAttr = value.dyn_cast<complex::NumberAttr>()) {
+ type = RankedTensorType::get(/*shape=*/{},
+ complexAttr.cast<TypedAttr>().getType());
+ value =
+ DenseElementsAttr::get(type.cast<TensorType>(), complexAttr.getValue());
}
// TODO: support other XLA specific types.
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
index 6d64326..7d4555c 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
@@ -217,6 +217,7 @@
LINK_LIBS PUBLIC
ChloDialect
HloToLinalgUtils
+ MLIRComplexDialect
MLIRIR
MLIRPass
MLIRRewrite
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index f8818bd..3b00d26 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -30,6 +30,7 @@
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir-hlo/utils/hlo_utils.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
@@ -58,8 +59,11 @@
// Lower to MHLO constant if statically shaped.
if (resultTy.hasStaticShape()) {
- rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(
- op, DenseElementsAttr::get(resultTy, op.value()));
+ auto complexAttr = op.value().dyn_cast<complex::NumberAttr>();
+ auto attr = complexAttr
+ ? DenseElementsAttr::get(resultTy, complexAttr.getValue())
+ : DenseElementsAttr::get(resultTy, op.value());
+ rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, attr);
return success();
}
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/CMakeLists.txt
index aaa15fa..e9d72a2 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/CMakeLists.txt
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/CMakeLists.txt
@@ -35,6 +35,7 @@
Core
LINK_LIBS PUBLIC
+ ChloDialect
MLIRGPUOps
MLIRHLOAnalysis
MLIRIR
diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir
index a49be53..397ca38 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir
+++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir
@@ -95,6 +95,31 @@
// -----
+// CHECK-LABEL: func.func @asin_complex_f64_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>>
+// CHECK: %[[TWO:.*]] = mhlo.constant dense<(2.000000e+00,0.000000e+00)>
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]]
+// CHECK: %[[TWO_BROADCASTED:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TWO]], %[[SHAPE]])
+// CHECK: %[[ONE:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)>
+// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG0]]
+// CHECK: %[[ONE_BROADCASTED:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ONE]], %[[SHAPE2]])
+// CHECK: %[[ONE2:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)>
+// CHECK: %[[SHAPE3:.*]] = shape.shape_of %[[ARG0]]
+// CHECK: %[[ONE_BROADCASTED2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ONE2]], %[[SHAPE3]])
+// CHECK: %[[SQUARE:.*]] = mhlo.multiply %[[ARG0]], %[[ARG0]]
+// CHECK: %[[SUB:.*]] = mhlo.subtract %[[ONE_BROADCASTED2]], %[[SQUARE]]
+// CHECK: %[[SQRT:.*]] = mhlo.sqrt %[[SUB]]
+// CHECK: %[[ADD:.*]] = mhlo.add %[[ONE_BROADCASTED]], %[[SQRT]]
+// CHECK: %[[ATAN2:.*]] = mhlo.atan2 %[[ARG0]], %[[ADD]]
+// CHECK: %[[MUL:.*]] = mhlo.multiply %[[TWO_BROADCASTED]], %[[ATAN2]]
+// CHECK: return %[[MUL]]
+func.func @asin_complex_f64_dynamic(%arg : tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>> {
+ %result = "chlo.asin"(%arg) : (tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>>
+ func.return %result : tensor<?xcomplex<f64>>
+}
+
+// -----
+
// CHECK-LABEL: @asinh_bf16
// CHECK-SAME: %[[ARG:.*]]: tensor<bf16>
func.func @asinh_bf16(%arg : tensor<bf16>) -> tensor<bf16> {