[NFC] [XLA] Descriptive return type for ShapeUtil::InsertedOrDeleted1SizedDimensions

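The tuple return value of ShapeUtil::InsertedOrDeleted1SizedDimensions
(and of HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions)
forced callers to unpack positional std::get<0..2> elements. Returning
std::optional<ShapeUtil::ShapeEqualityDescriptor> instead makes the
success case and the two index lists self-describing at each call site.

A rough before/after sketch of a call site (shapes taken from
shape_util_test.cc; the variable name `desc` is illustrative and not
part of this change):

    Shape shape0 = ShapeUtil::MakeShape(S32, {9, 1, 4});
    Shape shape1 = ShapeUtil::MakeShape(S32, {1, 9, 4, 1});

    // Before: positional access into an anonymous tuple.
    auto trivial_reshape =
        ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1);
    if (std::get<0>(trivial_reshape)) {
      const auto& deleted = std::get<1>(trivial_reshape);   // indices in shape0
      const auto& inserted = std::get<2>(trivial_reshape);  // indices in shape1
    }

    // After: named fields behind an optional.
    if (std::optional<ShapeUtil::ShapeEqualityDescriptor> desc =
            ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1)) {
      const auto& deleted = desc->deleted_dimensions;    // indices in shape0
      const auto& inserted = desc->inserted_dimensions;  // indices in shape1
    }
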
PiperOrigin-RevId: 457465355
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 98a0ded..e6e27ab 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -3479,15 +3479,12 @@
   // A broadcast of a reshape which merely inserts 1-sized dimensions can
   // elide its operand.
   {
-    bool merely_inserts_or_deletes_1_sized_dimensions;
-    std::vector<int64_t> inserted_indices, deleted_indices;
-    std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
-             inserted_indices) =
+    std::optional<ShapeUtil::ShapeEqualityDescriptor> reshape_degenerate =
         operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
-    if (merely_inserts_or_deletes_1_sized_dimensions &&
-        deleted_indices.empty()) {
-      std::reverse(inserted_indices.begin(), inserted_indices.end());
-      for (auto inserted_index : inserted_indices) {
+    if (reshape_degenerate.has_value() &&
+        reshape_degenerate->deleted_dimensions.empty()) {
+      absl::c_reverse(reshape_degenerate->inserted_dimensions);
+      for (auto inserted_index : reshape_degenerate->inserted_dimensions) {
         dims.erase(dims.begin() + inserted_index);
       }
       return ReplaceWithNewInstruction(
@@ -4366,17 +4363,13 @@
   // enable other optimizations, e.g., merging with broadcast, and sparse update
   // (add(x, dus(broadcast(0), y, ...)) -> dus(x, add(ds(x), y), ...)).
   if (!options_.is_layout_sensitive()) {
-    bool trivial_reshape;
-    std::vector<int64_t> deleted_dims;
-    std::vector<int64_t> inserted_dims;
-
     HloInstruction* dus;
     HloInstruction* slice;
-    std::tie(trivial_reshape, deleted_dims, inserted_dims) =
+    std::optional<ShapeUtil::ShapeEqualityDescriptor> trivial_reshape =
         reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
     // 1-sized dimensions added and removed will be one sized in both the update
     // slice and the dynamic-update-slice result.
-    if (trivial_reshape &&
+    if (trivial_reshape.has_value() &&
         Match(reshape->mutable_operand(0),
               m::Op(&dus)
                   .WithOpcode(HloOpcode::kDynamicUpdateSlice)
@@ -4391,10 +4384,11 @@
       auto zero = MakeScalarLike(dus->mutable_operand(2), 0);
       const Shape& old_slice_shape = dus->operand(1)->shape();
       for (int64_t i = 0; i <= old_slice_shape.rank(); ++i) {
-        if (absl::c_linear_search(deleted_dims, i)) {
+        if (absl::c_linear_search(trivial_reshape->deleted_dimensions, i)) {
           continue;
         }
-        while (absl::c_linear_search(inserted_dims, new_slice_shape.size())) {
+        while (absl::c_linear_search(trivial_reshape->inserted_dimensions,
+                                     new_slice_shape.size())) {
           new_slice_shape.push_back(1);
           new_dus_operands.push_back(zero);
         }
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8ccd403..10e15cc 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -3991,11 +3991,10 @@
   return OperandElementUse(*this, i) == UseKind::kReuse;
 }
 
-std::tuple<bool, std::vector<int64_t>, std::vector<int64_t>>
+std::optional<ShapeUtil::ShapeEqualityDescriptor>
 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
   if (HloOpcode::kReshape != opcode_) {
-    return std::make_tuple(false, std::vector<int64_t>(),
-                           std::vector<int64_t>());
+    return std::nullopt;
   }
   return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
                                                       shape_);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 7c065dd..ab6dcf8 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1671,7 +1671,7 @@
   // dimensions.
   //
   // Precondition: this op must be a reshape.
-  std::tuple<bool, std::vector<int64_t>, std::vector<int64_t>>
+  std::optional<ShapeUtil::ShapeEqualityDescriptor>
   ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
 
   // Gets the string identifier for this instruction.
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 4d8bea7..868d7ac 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1492,7 +1492,7 @@
       return true;
     case HloOpcode::kReshape:
       return hlo.operand(0)->shape().rank() == 1 ||
-             std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
+             hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions().has_value();
     case HloOpcode::kScatter:
     case HloOpcode::kTranspose:
       return true;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index ea8db97..76b7151 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 
+#include <optional>
 #include <tuple>
 #include <utility>
 #include <vector>
@@ -218,14 +219,10 @@
   CHECK_EQ(multidim_.size(), output_shape.rank());
   std::vector<llvm::Value*> source_multidim_index(
       input_shape.rank(), llvm::UndefValue::get(index_type_));
-  auto trivial_reshape =
-      ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape, output_shape);
-  if (std::get<0>(trivial_reshape)) {
-    // The 1-sized dimensions which only appear in 'input_shape'.
-    auto deleted_dims_indices = std::get<1>(trivial_reshape);
-    // The 1-sized dimensions which only appear in 'output_shape'.
-    auto inserted_dims_indices = std::get<2>(trivial_reshape);
 
+  if (std::optional<ShapeUtil::ShapeEqualityDescriptor> trivial_reshape =
+          ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape,
+                                                       output_shape)) {
     // This is a two-way merge of 'deleted_dims_indices' with indexing into
     // 'source_multidim_index', and a two-way merge of 'inserted_dims_indices'
     // with indexing into 'multidim_'. When we find a dimension in
@@ -234,11 +231,12 @@
     // indices that appear in 'inserted_dims_indices').
     for (int64_t i = 0, j = 0, k = 0, l = 0; i < source_multidim_index.size();
          ++i) {
-      if (j == deleted_dims_indices.size() || deleted_dims_indices[j] > i) {
+      if (j == trivial_reshape->deleted_dimensions.size() ||
+          trivial_reshape->deleted_dimensions[j] > i) {
         // This is a dimension that was preserved. Take the matching value from
         // multidim_.
-        while (l < inserted_dims_indices.size() &&
-               inserted_dims_indices[l] == k) {
+        while (l < trivial_reshape->inserted_dimensions.size() &&
+               trivial_reshape->inserted_dimensions[l] == k) {
           // Skip 1-sized dimensions.
           ++k;
           ++l;
diff --git a/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc b/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc
index b5446bf..2a991c7 100644
--- a/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc
+++ b/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc
@@ -41,8 +41,8 @@
     // Also only do this for degenerate dimension sizes as the additional
     // reshaping may not be worth the potential for CSE.
     HloInstruction* real_data = ag->mutable_operand(0);
-    while (std::get<0>(
-        real_data->ReshapeMerelyInsertsOrDeletes1SizedDimensions())) {
+    while (real_data->ReshapeMerelyInsertsOrDeletes1SizedDimensions()
+               .has_value()) {
       real_data = real_data->mutable_operand(0);
     }
 
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index ff6ec1b..15d9afd 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1180,15 +1180,12 @@
   return new_shape;
 }
 
-/* static */ std::tuple<bool, std::vector<int64_t>, std::vector<int64_t>>
+/* static */ std::optional<ShapeUtil::ShapeEqualityDescriptor>
 ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                              const Shape& shape_post) {
   CHECK(shape_pre.IsArray());
   CHECK(shape_post.IsArray());
 
-  auto nil =
-      std::make_tuple(false, std::vector<int64_t>(), std::vector<int64_t>());
-
   std::vector<int64_t> deleted_indices;
   std::vector<int64_t> inserted_indices;
   // Returns false if any input/output index between prior_unmodified_dim_pair
@@ -1234,11 +1231,11 @@
             ? unmodified_dims[i]
             : std::make_pair(shape_pre.rank(), shape_post.rank());
     if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
-      return nil;
+      return std::nullopt;
     }
   }
 
-  return std::make_tuple(true, deleted_indices, inserted_indices);
+  return ShapeEqualityDescriptor{deleted_indices, inserted_indices};
 }
 
 /* static */ std::vector<std::pair<int64_t, int64_t>>
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 00f1b08..9e686a8 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -505,6 +505,16 @@
   static Shape PermuteDimensions(absl::Span<const int64_t> permutation,
                                  const Shape& shape);
 
+  // Describes how we can go from shape A to shape B by deleting the
+  // degenerate (1-sized) dimensions of A listed in `deleted_dimensions` and
+  // inserting the degenerate dimensions of B listed in `inserted_dimensions`.
+  //
+  // Such a descriptor exists only if A and B differ by degenerate dimensions.
+  struct ShapeEqualityDescriptor {
+    std::vector<int64_t> deleted_dimensions;
+    std::vector<int64_t> inserted_dimensions;
+  };
+
   // If we can go from `shape_pre` to `shape_post` by merely inserting or
   // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
@@ -515,7 +525,7 @@
   // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
   // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
   // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
-  static std::tuple<bool, std::vector<int64_t>, std::vector<int64_t>>
+  static std::optional<ShapeEqualityDescriptor>
   InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                     const Shape& shape_post);
 
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 243e5a8..f645b38 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -571,10 +571,10 @@
   Shape shape0 = ShapeUtil::MakeShape(S32, {9, 1, 4});
   Shape shape1 = ShapeUtil::MakeShape(S32, {1, 9, 4, 1});
   Shape shape2 = ShapeUtil::MakeShape(S32, {3, 1, 12});
-  EXPECT_TRUE(std::get<0>(
-      ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1)));
-  EXPECT_FALSE(std::get<0>(
-      ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
+  EXPECT_TRUE(
+      ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1).has_value());
+  EXPECT_FALSE(
+      ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2).has_value());
 }
 
 TEST(ShapeUtilTest, ForEachIndex) {