Add pattern to replace
(float)OneHot($X, $N, $on_val, $off_val, $axis)
with
OneHot($X, $N, (float)$on_val, (float)$off_val, $axis)
PiperOrigin-RevId: 380084267
Change-Id: I6b0eb4ec81e592ec19935c5c531736701569f561
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index cdd923a..71ff55f 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -2204,3 +2204,28 @@
// CHECK: %[[TMP:.*]] = "tfl.reshape"(%arg0, %[[CST1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x1xi32>
// CHECK: %[[RES:.*]] = "tfl.equal"(%[[TMP]], %[[CST2]]) : (tensor<2x1xi32>, tensor<3xi32>) -> tensor<2x3xi1>
}
+
+// CHECK-LABEL: fuseOneHotCast
+func @fuseOneHotCast(%arg: tensor<2xi32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
+ %depth = constant dense<3> : tensor<i32>
+ %bool_on = constant dense<true> : tensor<i1>
+ %bool_off = constant dense<false> : tensor<i1>
+ %int_on = constant dense<5> : tensor<i32>
+ %int_off = constant dense<7> : tensor<i32>
+
+ %tmp_bool = "tfl.one_hot"(%arg, %depth, %bool_on, %bool_off) {axis = -1 : i32} : (tensor<2xi32>, tensor<i32>, tensor<i1>, tensor<i1>) -> tensor<2x3xi1>
+ %result_bool = "tfl.cast"(%tmp_bool) : (tensor<2x3xi1>) -> tensor<2x3xf32>
+
+ %tmp_int = "tfl.one_hot"(%arg, %depth, %int_on, %int_off) {axis = -1 : i32} : (tensor<2xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<2x3xi1>
+ %result_int = "tfl.cast"(%tmp_int) : (tensor<2x3xi1>) -> tensor<2x3xf32>
+
+ return %result_bool, %result_int : tensor<2x3xf32>, tensor<2x3xf32>
+
+ // CHECK: %[[CST1:.*]] = constant dense<3> : tensor<i32>
+ // CHECK: %[[CST2:.*]] = constant dense<1.000000e+00> : tensor<f32>
+ // CHECK: %[[CST3:.*]] = constant dense<0.000000e+00> : tensor<f32>
+ // CHECK: %[[CST4:.*]] = constant dense<5.000000e+00> : tensor<f32>
+ // CHECK: %[[CST5:.*]] = constant dense<7.000000e+00> : tensor<f32>
+ // CHECK: %[[RES1:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<2x3xf32>
+ // CHECK: %[[RES2:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST4]], %[[CST5]]) {axis = -1 : i32} : (tensor<2xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<2x3xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index b20e2e2..a1cac6f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -416,6 +416,36 @@
return true;
}
+// Converts an Attribute with a single value of float or integral type to an
+// Attribute holding a single value of float type. If attr has no elements, the
+// result is 0.0f.
+Attribute ConvertSingleElementAttrToFloatAttr(Attribute attr) {
+ const auto dense_fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
+ if (dense_fp_attr) {
+ // Already float => return
+ return dense_fp_attr;
+ }
+
+ OpBuilder builder(attr.getContext());
+
+ const auto dense_int_attr = attr.dyn_cast<DenseIntElementsAttr>();
+ const auto int_values = dense_int_attr.getIntValues();
+ float float_val = 0.0f;
+ if (!int_values.empty()) {
+ const APInt apint_val = *int_values.begin();
+ if (dense_int_attr.getType().getElementType().isSignedInteger()) {
+ // Get the sign-extended value (=>int64) if the type is signed.
+ float_val = apint_val.getSExtValue();
+ } else {
+ // Get the zero-extended value (=>uint64) if unsigned or signless.
+ float_val = apint_val.getZExtValue();
+ }
+ }
+ return DenseFPElementsAttr::get(
+ RankedTensorType::get({}, builder.getF32Type()),
+ {llvm::APFloat(float_val)});
+}
+
#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
// Fuse Add with proceeding FullyConnected.
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 06669f6..b52f649 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -897,3 +897,27 @@
ConstantAttr<I32Attr, "-1">),
[(IsLastElementEqualsOne $shape),
(IsOneHotIndexAttribute $series)]>;
+
+def F32ElementsVal : Constraint<CPred<
+ "$0.getType().cast<TensorType>().getElementType().isF32()">,
+ "32 bit float tensor">;
+
+def ConvertSingleElementAttrToFloatAttr :
+ NativeCodeCall<"ConvertSingleElementAttrToFloatAttr($0)">;
+
+// Replace
+// (float)OneHot(index, depth, on_val, off_val, axis)
+// With
+// OneHot(index, depth, (float)on_val, (float)off_val, axis)
+def FuseOneHotAndCastToFloat : Pat<
+ (TFL_CastOp:$output (TFL_OneHotOp $indices,
+ $depth,
+ (ConstantOp $on_val),
+ (ConstantOp $off_val),
+ $axis)),
+ (TFL_OneHotOp $indices,
+ $depth,
+ (ConstantOp (ConvertSingleElementAttrToFloatAttr $on_val)),
+ (ConstantOp (ConvertSingleElementAttrToFloatAttr $off_val)),
+ $axis),
+ [(F32ElementsVal $output)]>;
\ No newline at end of file