[TFLite/MLIR] Adds constant folders for integer casts.

PiperOrigin-RevId: 341658771
Change-Id: Ia35978215c05e68c66e85e1d1ff8c1531e6099d8
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 7b99e5f..215812a 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -1973,6 +1973,43 @@
 }
 
 //===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1);
+  // For now, only supports cast between integer types.
+  auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!elements_attr) {
+    return nullptr;
+  }
+
+  auto result_element_type =
+      getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
+  auto operand_element_type = input()
+                                  .getType()
+                                  .cast<ShapedType>()
+                                  .getElementType()
+                                  .dyn_cast<IntegerType>();
+  // Returns nullptr if either result/operand element type is not integer.
+  if (!result_element_type || !operand_element_type) {
+    return nullptr;
+  }
+
+  const bool is_input_unsigned = operand_element_type.isUnsigned();
+  const int output_bitwidth = result_element_type.getWidth();
+  // The integer cast op is the same as C integer cast. Depends on the operand
+  // type's signedness, we will determine whether or not sign extension is
+  // needed.
+  auto cast = [&](APInt value) {
+    return is_input_unsigned ? value.zextOrTrunc(output_bitwidth)
+                             : value.sextOrTrunc(output_bitwidth);
+  };
+
+  return elements_attr.mapValues(result_element_type, cast);
+}
+
+//===----------------------------------------------------------------------===//
 // SelectV2Op
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 84a46c3..5f1d9ea 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -3443,6 +3443,8 @@
   // TFLite's cast op does not utilize CastOptions, instead derives types
   // from the TfLiteTensors.
   let hasOptions = 0;
+
+  let hasFolder = 1;
 }
 
 def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
index 69009ae..27a7068 100644
--- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
@@ -587,3 +587,55 @@
 // CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
 // CHECK:  return %[[CST]]
 }
+
+// CHECK-LABEL: @cast_i64_to_i32
+func @cast_i64_to_i32() -> tensor<5xi32> {
+  %cst = constant dense<[-1, 0, 1, 2147483647, 2147483648]> : tensor<5xi64>
+  %0 = "tfl.cast"(%cst) : (tensor<5xi64>) -> tensor<5xi32>
+  return %0 : tensor<5xi32>
+
+// CHECK: %[[CST:.*]] = constant dense<[-1, 0, 1, 2147483647, -2147483648]> : tensor<5xi32>
+// CHECK:  return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_i32_to_ui8
+func @cast_i32_to_ui8() -> tensor<6xui8> {
+  %cst = constant dense<[0, -1, 256, 127, -128, -129]> : tensor<6xi32>
+  %0 = "tfl.cast"(%cst) : (tensor<6xi32>) -> tensor<6xui8>
+  return %0 : tensor<6xui8>
+
+// CHECK: %[[CST:.*]] = constant dense<[0, 255, 0, 127, 128, 127]> : tensor<6xui8>
+// CHECK:  return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_ui8_to_i8
+func @cast_ui8_to_i8() -> tensor<4xi8> {
+  %cst = constant dense<[0, 255, 127, 128]> : tensor<4xui8>
+  %0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi8>
+  return %0 : tensor<4xi8>
+
+// CHECK: %[[CST:.*]] = constant dense<[0, -1, 127, -128]> : tensor<4xi8>
+// CHECK:  return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_i8_to_i32
+func @cast_i8_to_i32() -> tensor<4xi32> {
+  %cst = constant dense<[0, 128, -1, -128]> : tensor<4xi8>
+  %0 = "tfl.cast"(%cst) : (tensor<4xi8>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+
+// CHECK: %[[CST:.*]] = constant dense<[0, -128, -1, -128]> : tensor<4xi32>
+// CHECK:  return %[[CST]]
+}
+
+// CHECK-LABEL: @cast_ui8_to_i32
+func @cast_ui8_to_i32() -> tensor<4xi32> {
+  %cst = constant dense<[0, 128, 129, 255]> : tensor<4xui8>
+  %0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+
+// CHECK: %[[CST:.*]] = constant dense<[0, 128, 129, 255]> : tensor<4xi32>
+// CHECK:  return %[[CST]]
+}
+
+
diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt
index 096033e..95b970c 100644
--- a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt
+++ b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt
@@ -78,14 +78,14 @@
 }
 
 # CHECK:       func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
-# CHECK:           %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
-# CHECK:           %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
-# CHECK:           %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
-# CHECK:           %[[VAL_5:.*]] = constant unit
-# CHECK:           %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
-# CHECK:           %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
-# CHECK:           %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
-# CHECK:           %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
+# CHECK-DAG:       %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+# CHECK-DAG:       %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
+# CHECK-DAG:       %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
+# CHECK-DAG:       %[[VAL_5:.*]] = constant unit
+# CHECK-DAG:       %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
+# CHECK-DAG:       %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
+# CHECK-DAG:       %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
+# CHECK-DAG:       %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
 # CHECK:           %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
 # CHECK:           %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
 # CHECK:           %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>