Add xla_hlo.log operator and import support
PiperOrigin-RevId: 265571163
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index f2a40f1..e2d08c8 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -358,6 +358,7 @@
NoAttributeCase(kConvert, ConvertOp);
NoAttributeCase(kDivide, DivOp);
NoAttributeCase(kExp, ExpOp);
+ NoAttributeCase(kLog, LogOp);
NoAttributeCase(kMaximum, MaxOp);
NoAttributeCase(kMinimum, MinOp);
NoAttributeCase(kMultiply, MulOp);
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index ff0b6da..389f51a 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -119,6 +119,8 @@
def HLO_ExpOp: HLO_UnaryElementwiseOp<"exp", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ExpOp;
+def HLO_LogOp: HLO_UnaryElementwiseOp<"log", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_LogOp;
+
def HLO_NegOp: HLO_UnaryElementwiseOp<"neg", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_NegOp;
def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_SignOp;
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
index 5ea2de8..267d9ba 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -83,6 +83,17 @@
}];
}
+class BASE_HLO_LogOp {
+ string summary = "Logarithm operator";
+
+ string description = [{
+ Returns `log(operand)` element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }];
+}
+
class BASE_HLO_NegOp {
string summary = "Negation operator";
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index 6ee398d..af0a61e 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -322,6 +322,14 @@
// -----
+func @log_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> {
+ // expected-error@+1 {{'xla_hlo.log' op requires the same type for all operands and results}}
+ %0 = "xla_hlo.log"(%arg0) : (tensor<1xf32>) -> tensor<1xi32>
+ return %0: tensor<1xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @reshape_same_shape
func @reshape_same_shape(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/log.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/log.hlotxt
new file mode 100644
index 0000000..616ad0c
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/log.hlotxt
@@ -0,0 +1,12 @@
+// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
+
+HloModule foo
+
+// CHECK-LABEL: func @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
+ENTRY %foo (arg0.1: f32[16]) -> f32[16] {
+ %arg0.1 = f32[16] parameter(0)
+
+ // CHECK-NEXT: %0 = "xla_hlo.log"(%arg0) {name = "log.2"} : (tensor<16xf32>) -> tensor<16xf32>
+ // CHECK-NEXT: return %0 : tensor<16xf32>
+ ROOT %log.2 = f32[16] log(f32[16] %arg0.1)
+}