Add conversion rule of segment_sum op for Tensorflow Lite MLIR and TOCO

PiperOrigin-RevId: 289545360
Change-Id: Ibe258ccdd660f28bbcf25ef03eddb845640fd9e4
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index e877714..a27589f 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -3355,4 +3355,18 @@
   }];
 }
 
+def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
+  let summary = "SegmentSum operator";
+
+  let description = [{
+    Computes the sum along segments of a tensor.
+  }];
+
+  let arguments = (ins
+    TensorOf<[F32, I32]>:$data,
+    I32Tensor:$segment_ids
+  );
+  let results = (outs TensorOf<[F32, I32]>:$output);
+}
+
 #endif // TFL_OPS
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
index 596809d..45e427c 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
@@ -150,6 +150,7 @@
 def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
 def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
 def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
+def : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids), (TFL_SegmentSumOp $data, $segment_ids)>;
 def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
 def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>;
 def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>;
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 4b1a6fa..fa2119e 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -2442,6 +2442,8 @@
       // MatrixSetDiagV3 operators are converted to MatrixSetDiag, after which
       // their shapes are propagated.
       break;
+    case OperatorType::kSegmentSum:
+      break;
     default:
       // Unimplemented, another graph transformation should drop it.
       LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index 26ce2af..457e06c 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -2608,6 +2608,7 @@
       {"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
       {"Round", ConvertRoundOperator},
       {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
+      {"SegmentSum", ConvertSimpleOperator<SegmentSumOperator, 2, 1>},
       {"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
       {"SelectV2", ConvertSimpleOperator<SelectOperator, 3, 1>},
       {"Shape", ConvertShapeOperator},
diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h
index 7b07b1b..21236fe 100644
--- a/tensorflow/lite/toco/model.h
+++ b/tensorflow/lite/toco/model.h
@@ -2191,6 +2191,10 @@
   MatrixSetDiagV3Operator() : Operator(OperatorType::kMatrixSetDiagV3) {}
 };
 
+struct SegmentSumOperator : Operator {
+  SegmentSumOperator() : Operator(OperatorType::kSegmentSum) {}
+};
+
 // Alloc's are used for transient arrays only. An Alloc specifies which interval
 // of the "transient_data" workspace buffer passed to inference functions, is to
 // be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index f106e4c..7204504 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -1987,6 +1987,8 @@
       ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
       ::tflite::BuiltinOperator_RANK, OperatorType::kRank));
+  ops.emplace_back(new SimpleOperator<SegmentSumOperator>(
+      ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
   return ops;
 }
 }  // namespace
diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc
index 40313f8..3bd5386 100644
--- a/tensorflow/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/lite/toco/tflite/operator_test.cc
@@ -727,6 +727,13 @@
   EXPECT_EQ(output_toco_op->idx_out_type, op.idx_out_type);
 }
 
+TEST_F(OperatorTest, BuiltinSegmentSum) {
+  SegmentSumOperator op;
+  auto output_toco_op = SerializeAndDeserialize(
+      GetOperator("SEGMENT_SUM", OperatorType::kSegmentSum), op);
+  ASSERT_NE(nullptr, output_toco_op.get());
+}
+
 TEST_F(OperatorTest, BuiltinReverseSequence) {
   ReverseSequenceOperator op;
   op.seq_dim = 3;
diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc
index ebcb175..fc666f1 100644
--- a/tensorflow/lite/toco/tooling_util.cc
+++ b/tensorflow/lite/toco/tooling_util.cc
@@ -387,6 +387,7 @@
     HANDLE_OPERATORTYPENAME_CASE(Reshape)
     HANDLE_OPERATORTYPENAME_CASE(Squeeze)
     HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
+    HANDLE_OPERATORTYPENAME_CASE(SegmentSum)
     HANDLE_OPERATORTYPENAME_CASE(Shape)
     HANDLE_OPERATORTYPENAME_CASE(Slice)
     HANDLE_OPERATORTYPENAME_CASE(Split)