Add conversion rule of segment_sum op for Tensorflow Lite MLIR and TOCO
PiperOrigin-RevId: 289545360
Change-Id: Ibe258ccdd660f28bbcf25ef03eddb845640fd9e4
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index e877714..a27589f 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -3355,4 +3355,18 @@
}];
}
+def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
+ let summary = "SegmentSum operator";
+
+ let description = [{
+ Computes the sum along segments of a tensor.
+ }];
+
+ let arguments = (ins
+ TensorOf<[F32, I32]>:$data,
+ I32Tensor:$segment_ids
+ );
+ let results = (outs TensorOf<[F32, I32]>:$output);
+}
+
#endif // TFL_OPS
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
index 596809d..45e427c 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
@@ -150,6 +150,7 @@
def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
+def : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids), (TFL_SegmentSumOp $data, $segment_ids)>;
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>;
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>;
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 4b1a6fa..fa2119e 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -2442,6 +2442,8 @@
// MatrixSetDiagV3 operators are converted to MatrixSetDiag, after which
// their shapes are propagated.
break;
+ case OperatorType::kSegmentSum:
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index 26ce2af..457e06c 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -2608,6 +2608,7 @@
{"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
{"Round", ConvertRoundOperator},
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
+ {"SegmentSum", ConvertSimpleOperator<SegmentSumOperator, 2, 1>},
{"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
{"SelectV2", ConvertSimpleOperator<SelectOperator, 3, 1>},
{"Shape", ConvertShapeOperator},
diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h
index 7b07b1b..21236fe 100644
--- a/tensorflow/lite/toco/model.h
+++ b/tensorflow/lite/toco/model.h
@@ -2191,6 +2191,10 @@
MatrixSetDiagV3Operator() : Operator(OperatorType::kMatrixSetDiagV3) {}
};
+struct SegmentSumOperator : Operator {
+ SegmentSumOperator() : Operator(OperatorType::kSegmentSum) {}
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index f106e4c..7204504 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -1987,6 +1987,8 @@
::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
::tflite::BuiltinOperator_RANK, OperatorType::kRank));
+ ops.emplace_back(new SimpleOperator<SegmentSumOperator>(
+ ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
return ops;
}
} // namespace
diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc
index 40313f8..3bd5386 100644
--- a/tensorflow/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/lite/toco/tflite/operator_test.cc
@@ -727,6 +727,13 @@
EXPECT_EQ(output_toco_op->idx_out_type, op.idx_out_type);
}
+TEST_F(OperatorTest, BuiltinSegmentSum) {
+ SegmentSumOperator op;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("SEGMENT_SUM", OperatorType::kSegmentSum), op);
+ ASSERT_NE(nullptr, output_toco_op.get());
+}
+
TEST_F(OperatorTest, BuiltinReverseSequence) {
ReverseSequenceOperator op;
op.seq_dim = 3;
diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc
index ebcb175..fc666f1 100644
--- a/tensorflow/lite/toco/tooling_util.cc
+++ b/tensorflow/lite/toco/tooling_util.cc
@@ -387,6 +387,7 @@
HANDLE_OPERATORTYPENAME_CASE(Reshape)
HANDLE_OPERATORTYPENAME_CASE(Squeeze)
HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
+ HANDLE_OPERATORTYPENAME_CASE(SegmentSum)
HANDLE_OPERATORTYPENAME_CASE(Shape)
HANDLE_OPERATORTYPENAME_CASE(Slice)
HANDLE_OPERATORTYPENAME_CASE(Split)