Add HLO ops and TF lowering to MLIR for some element-wise unary operations
PiperOrigin-RevId: 269835024
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index a16e5db..6135848 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -112,6 +112,8 @@
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AbsOp;
+def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CeilOp;
+
def HLO_ConvertOp : HLO_UnaryElementwiseOp<
"convert", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ConvertOp {
let hasFolder = 1;
@@ -121,6 +123,8 @@
let hasCustomHLOConverter = 1;
}
+def HLO_CosOp: HLO_UnaryElementwiseOp<"cos", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CosOp;
+
def HLO_ExpOp: HLO_UnaryElementwiseOp<"exp", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ExpOp;
def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_FloorOp;
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
index 6623c21..c95e279 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -61,6 +61,17 @@
}];
}
+class BASE_HLO_CeilOp {
+ string summary = "Ceil operator";
+
+ string description = [{
+ Returns `Ceil(operand)` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }];
+}
+
class BASE_HLO_ConvertOp {
string summary = "Convert operator";
@@ -72,6 +83,17 @@
}];
}
+class BASE_HLO_CosOp {
+ string summary = "Cos operator";
+
+ string description = [{
+ Returns `Cos(operand)` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }];
+}
+
class BASE_HLO_ExpOp {
string summary = "Exponential operator";
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
index 597e5b3..d792271 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
@@ -88,8 +88,12 @@
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
+def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil">, BASE_HLO_CeilOp;
+
def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert">, BASE_HLO_ConvertOp;
+def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cos">, BASE_HLO_CosOp;
+
def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exp">, BASE_HLO_ExpOp;
def LHLO_NegOp: LHLO_UnaryElementwiseOp<"neg">, BASE_HLO_NegOp;
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 42fb73f..a76bb95 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -438,6 +438,55 @@
// Unary op legalizations.
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @abs
+func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @ceil
+func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @cos
+func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: "xla_hlo.cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @exp
+func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @floor
+func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @neg
+func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: tanh
+func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
// CHECK-LABEL: reshape
func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<1x1xf32> {
// CHECK: %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x1xf32>
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index f289350..ce83ed8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -181,6 +181,21 @@
// Unary op patterns.
//===----------------------------------------------------------------------===//
+foreach Mapping = [
+ [TF_AbsOp, HLO_AbsOp],
+ [TF_CeilOp, HLO_CeilOp],
+ [TF_CosOp, HLO_CosOp],
+ [TF_ExpOp, HLO_ExpOp],
+ [TF_FloorOp, HLO_FloorOp],
+ [TF_LogOp, HLO_LogOp],
+ [TF_NegOp, HLO_NegOp],
+ [TF_TanhOp, HLO_TanhOp],
+ ] in {
+ def : Pat<(Mapping[0] AnyTensor:$input),
+ (Mapping[1] $input)>;
+}
+
+
foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp] in {
def : Pat<(TfOp:$res AnyStaticShapeTensor:$arg, $ignored),
(HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>;