Add pattern to replace
  (float)OneHot($X, $N, $on_val, $off_val, $axis)
with
  OneHot($X, $N, (float)$on_val, (float)$off_val, $axis)

PiperOrigin-RevId: 380084267
Change-Id: I6b0eb4ec81e592ec19935c5c531736701569f561
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index cdd923a..71ff55f 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -2204,3 +2204,28 @@
   // CHECK: %[[TMP:.*]] = "tfl.reshape"(%arg0, %[[CST1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x1xi32>
   // CHECK: %[[RES:.*]] = "tfl.equal"(%[[TMP]], %[[CST2]]) : (tensor<2x1xi32>, tensor<3xi32>) -> tensor<2x3xi1>
 }
+
+// CHECK-LABEL: fuseOneHotCast
+func @fuseOneHotCast(%arg: tensor<2xi32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
+  %depth = constant dense<3> : tensor<i32>
+  %bool_on = constant dense<true> : tensor<i1>
+  %bool_off = constant dense<false> : tensor<i1>
+  %int_on = constant dense<5> : tensor<i32>
+  %int_off = constant dense<7> : tensor<i32>
+
+  %tmp_bool = "tfl.one_hot"(%arg, %depth, %bool_on, %bool_off) {axis = -1 : i32} : (tensor<2xi32>, tensor<i32>, tensor<i1>, tensor<i1>) -> tensor<2x3xi1>
+  %result_bool = "tfl.cast"(%tmp_bool) : (tensor<2x3xi1>) -> tensor<2x3xf32>
+
+  %tmp_int = "tfl.one_hot"(%arg, %depth, %int_on, %int_off) {axis = -1 : i32} : (tensor<2xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<2x3xi32>
+  %result_int = "tfl.cast"(%tmp_int) : (tensor<2x3xi32>) -> tensor<2x3xf32>
+
+  return %result_bool, %result_int : tensor<2x3xf32>, tensor<2x3xf32>
+
+  // CHECK: %[[CST1:.*]] = constant dense<3> : tensor<i32>
+  // CHECK: %[[CST2:.*]] = constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: %[[CST3:.*]] = constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: %[[CST4:.*]] = constant dense<5.000000e+00> : tensor<f32>
+  // CHECK: %[[CST5:.*]] = constant dense<7.000000e+00> : tensor<f32>
+  // CHECK: %[[RES1:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<2x3xf32>
+  // CHECK: %[[RES2:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST4]], %[[CST5]]) {axis = -1 : i32} : (tensor<2xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<2x3xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index b20e2e2..a1cac6f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -416,6 +416,36 @@
   return true;
 }
 
+// Converts an Attribute with a single value of float or integral type to an
+// Attribute holding a single value of float type. If attr has no elements, the
+// result is 0.0f. If attr is null or holds neither float nor integer
+// elements, it is returned unchanged rather than dereferencing null.
+Attribute ConvertSingleElementAttrToFloatAttr(Attribute attr) {
+  // Already float => return as-is.
+  if (const auto dense_fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>())
+    return dense_fp_attr;
+
+  // Guard: the previous code did a plain dyn_cast and dereferenced the
+  // (possibly null) result, crashing on non-integer elements attributes.
+  const auto dense_int_attr = attr.dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!dense_int_attr) return attr;
+
+  OpBuilder builder(attr.getContext());
+
+  const auto int_values = dense_int_attr.getIntValues();
+  float float_val = 0.0f;
+  if (!int_values.empty()) {
+    const APInt apint_val = *int_values.begin();
+    if (dense_int_attr.getType().getElementType().isSignedInteger()) {
+      // Get the sign-extended value (=>int64) if the type is signed.
+      float_val = apint_val.getSExtValue();
+    } else {
+      // Get the zero-extended value (=>uint64) if unsigned or signless.
+      float_val = apint_val.getZExtValue();
+    }
+  }
+  return DenseFPElementsAttr::get(
+      RankedTensorType::get({}, builder.getF32Type()),
+      {llvm::APFloat(float_val)});
+}
+
 #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
 
 // Fuse Add with proceeding FullyConnected.
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 06669f6..b52f649 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -897,3 +897,27 @@
                 ConstantAttr<I32Attr, "-1">),
   [(IsLastElementEqualsOne $shape),
    (IsOneHotIndexAttribute $series)]>;
+
+def F32ElementsVal : Constraint<CPred<
+  "$0.getType().cast<TensorType>().getElementType().isF32()">,
+  "32 bit float tensor">;
+
+def ConvertSingleElementAttrToFloatAttr :
+  NativeCodeCall<"ConvertSingleElementAttrToFloatAttr($0)">;
+
+// Replace
+//   (float)OneHot(index, depth, on_val, off_val, axis)
+// With
+//   OneHot(index, depth, (float)on_val, (float)off_val, axis)
+def FuseOneHotAndCastToFloat : Pat<
+  (TFL_CastOp:$output (TFL_OneHotOp $indices,
+                                    $depth,
+                                    (ConstantOp $on_val),
+                                    (ConstantOp $off_val),
+                                    $axis)),
+  (TFL_OneHotOp $indices,
+                $depth,
+                (ConstantOp (ConvertSingleElementAttrToFloatAttr $on_val)),
+                (ConstantOp (ConvertSingleElementAttrToFloatAttr $off_val)),
+                $axis),
+  [(F32ElementsVal $output)]>;