[TFLite/MLIR] Adds constant folders for integer casts.
PiperOrigin-RevId: 341658771
Change-Id: Ia35978215c05e68c66e85e1d1ff8c1531e6099d8
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 7b99e5f..215812a 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -1973,6 +1973,43 @@
}
//===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 1);
+ // For now, only supports cast between integer types.
+ auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+ if (!elements_attr) {
+ return nullptr;
+ }
+
+ auto result_element_type =
+ getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
+ auto operand_element_type = input()
+ .getType()
+ .cast<ShapedType>()
+ .getElementType()
+ .dyn_cast<IntegerType>();
+ // Returns nullptr if either result/operand element type is not integer.
+ if (!result_element_type || !operand_element_type) {
+ return nullptr;
+ }
+
+ const bool is_input_unsigned = operand_element_type.isUnsigned();
+ const int output_bitwidth = result_element_type.getWidth();
+  // The integer cast op has the same semantics as a C integer cast. Depending
+  // on the operand type's signedness, we determine whether or not sign
+  // extension is needed.
+ auto cast = [&](APInt value) {
+ return is_input_unsigned ? value.zextOrTrunc(output_bitwidth)
+ : value.sextOrTrunc(output_bitwidth);
+ };
+
+ return elements_attr.mapValues(result_element_type, cast);
+}
+
+//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 84a46c3..5f1d9ea 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -3443,6 +3443,8 @@
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.
let hasOptions = 0;
+
+ let hasFolder = 1;
}
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
index 69009ae..27a7068 100644
--- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
@@ -587,3 +587,55 @@
// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
// CHECK: return %[[CST]]
}
+
+// CHECK-LABEL: @cast_i64_to_i32
+func @cast_i64_to_i32() -> tensor<5xi32> {
+ %cst = constant dense<[-1, 0, 1, 2147483647, 2147483648]> : tensor<5xi64>
+ %0 = "tfl.cast"(%cst) : (tensor<5xi64>) -> tensor<5xi32>
+ return %0 : tensor<5xi32>
+
+// CHECK: %[[CST:.*]] = constant dense<[-1, 0, 1, 2147483647, -2147483648]> : tensor<5xi32>
+// CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_i32_to_ui8
+func @cast_i32_to_ui8() -> tensor<6xui8> {
+ %cst = constant dense<[0, -1, 256, 127, -128, -129]> : tensor<6xi32>
+ %0 = "tfl.cast"(%cst) : (tensor<6xi32>) -> tensor<6xui8>
+ return %0 : tensor<6xui8>
+
+// CHECK: %[[CST:.*]] = constant dense<[0, 255, 0, 127, 128, 127]> : tensor<6xui8>
+// CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_ui8_to_i8
+func @cast_ui8_to_i8() -> tensor<4xi8> {
+ %cst = constant dense<[0, 255, 127, 128]> : tensor<4xui8>
+ %0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi8>
+ return %0 : tensor<4xi8>
+
+// CHECK: %[[CST:.*]] = constant dense<[0, -1, 127, -128]> : tensor<4xi8>
+// CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_i8_to_i32
+func @cast_i8_to_i32() -> tensor<4xi32> {
+ %cst = constant dense<[0, 128, -1, -128]> : tensor<4xi8>
+ %0 = "tfl.cast"(%cst) : (tensor<4xi8>) -> tensor<4xi32>
+ return %0 : tensor<4xi32>
+
+// CHECK: %[[CST:.*]] = constant dense<[0, -128, -1, -128]> : tensor<4xi32>
+// CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_ui8_to_i32
+func @cast_ui8_to_i32() -> tensor<4xi32> {
+ %cst = constant dense<[0, 128, 129, 255]> : tensor<4xui8>
+ %0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi32>
+ return %0 : tensor<4xi32>
+
+// CHECK: %[[CST:.*]] = constant dense<[0, 128, 129, 255]> : tensor<4xi32>
+// CHECK: return %[[CST]]
+}
+
+
diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt
index 096033e..95b970c 100644
--- a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt
+++ b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt
@@ -78,14 +78,14 @@
}
# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
-# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
-# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
-# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
-# CHECK: %[[VAL_5:.*]] = constant unit
-# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
-# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
-# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
-# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
+# CHECK-DAG: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+# CHECK-DAG: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
+# CHECK-DAG: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
+# CHECK-DAG: %[[VAL_5:.*]] = constant unit
+# CHECK-DAG: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
+# CHECK-DAG: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
+# CHECK-DAG: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
+# CHECK-DAG: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>