[NFC] [XLA] Descriptive return type for ShapeUtil::InsertedOrDeleted1SizedDimensions
PiperOrigin-RevId: 457465355
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 98a0ded..e6e27ab 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -3479,15 +3479,12 @@
// A broadcast of a reshape which merely inserts 1-sized dimensions can
// elide its operand.
{
- bool merely_inserts_or_deletes_1_sized_dimensions;
- std::vector<int64_t> inserted_indices, deleted_indices;
- std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
- inserted_indices) =
+ std::optional<ShapeUtil::ShapeEqualityDescriptor> reshape_degenerate =
operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
- if (merely_inserts_or_deletes_1_sized_dimensions &&
- deleted_indices.empty()) {
- std::reverse(inserted_indices.begin(), inserted_indices.end());
- for (auto inserted_index : inserted_indices) {
+ if (reshape_degenerate.has_value() &&
+ reshape_degenerate->deleted_dimensions.empty()) {
+ absl::c_reverse(reshape_degenerate->inserted_dimensions);
+ for (auto inserted_index : reshape_degenerate->inserted_dimensions) {
dims.erase(dims.begin() + inserted_index);
}
return ReplaceWithNewInstruction(
@@ -4366,17 +4363,13 @@
// enable other optimizations, e.g., merging with broadcast, and sparse update
// (add(x, dus(broadcast(0), y, ...)) -> dus(x, add(ds(x), y), ...)).
if (!options_.is_layout_sensitive()) {
- bool trivial_reshape;
- std::vector<int64_t> deleted_dims;
- std::vector<int64_t> inserted_dims;
-
HloInstruction* dus;
HloInstruction* slice;
- std::tie(trivial_reshape, deleted_dims, inserted_dims) =
+ std::optional<ShapeUtil::ShapeEqualityDescriptor> trivial_reshape =
reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
// 1-sized dimensions added and removed will be one sized in both the update
// slice and the dynamic-update-slice result.
- if (trivial_reshape &&
+ if (trivial_reshape.has_value() &&
Match(reshape->mutable_operand(0),
m::Op(&dus)
.WithOpcode(HloOpcode::kDynamicUpdateSlice)
@@ -4391,10 +4384,11 @@
auto zero = MakeScalarLike(dus->mutable_operand(2), 0);
const Shape& old_slice_shape = dus->operand(1)->shape();
for (int64_t i = 0; i <= old_slice_shape.rank(); ++i) {
- if (absl::c_linear_search(deleted_dims, i)) {
+ if (absl::c_linear_search(trivial_reshape->deleted_dimensions, i)) {
continue;
}
- while (absl::c_linear_search(inserted_dims, new_slice_shape.size())) {
+ while (absl::c_linear_search(trivial_reshape->inserted_dimensions,
+ new_slice_shape.size())) {
new_slice_shape.push_back(1);
new_dus_operands.push_back(zero);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8ccd403..10e15cc 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -3991,11 +3991,10 @@
return OperandElementUse(*this, i) == UseKind::kReuse;
}
-std::tuple<bool, std::vector<int64_t>, std::vector<int64_t>>
+std::optional<ShapeUtil::ShapeEqualityDescriptor>
HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
if (HloOpcode::kReshape != opcode_) {
- return std::make_tuple(false, std::vector<int64_t>(),
- std::vector<int64_t>());
+ return std::nullopt;
}
return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
shape_);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 7c065dd..ab6dcf8 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1671,7 +1671,7 @@
// dimensions.
//
// Precondition: this op must be a reshape.
- std::tuple<bool, std::vector<int64_t>, std::vector<int64_t>>
+ std::optional<ShapeUtil::ShapeEqualityDescriptor>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
// Gets the string identifier for this instruction.
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 4d8bea7..868d7ac 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1492,7 +1492,7 @@
return true;
case HloOpcode::kReshape:
return hlo.operand(0)->shape().rank() == 1 ||
- std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
+ hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions().has_value();
case HloOpcode::kScatter:
case HloOpcode::kTranspose:
return true;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index ea8db97..76b7151 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -15,6 +15,7 @@
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include <optional>
#include <tuple>
#include <utility>
#include <vector>
@@ -218,14 +219,10 @@
CHECK_EQ(multidim_.size(), output_shape.rank());
std::vector<llvm::Value*> source_multidim_index(
input_shape.rank(), llvm::UndefValue::get(index_type_));
- auto trivial_reshape =
- ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape, output_shape);
- if (std::get<0>(trivial_reshape)) {
- // The 1-sized dimensions which only appear in 'input_shape'.
- auto deleted_dims_indices = std::get<1>(trivial_reshape);
- // The 1-sized dimensions which only appear in 'output_shape'.
- auto inserted_dims_indices = std::get<2>(trivial_reshape);
+ if (std::optional<ShapeUtil::ShapeEqualityDescriptor> trivial_reshape =
+ ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape,
+ output_shape)) {
// This is a two-way merge of 'deleted_dims_indices' with indexing into
// 'source_multidim_index', and a two-way merge of 'inserted_dims_indices'
// with indexing into 'multidim_'. When we find a dimension in
@@ -234,11 +231,12 @@
// indices that appear in 'inserted_dims_indices').
for (int64_t i = 0, j = 0, k = 0, l = 0; i < source_multidim_index.size();
++i) {
- if (j == deleted_dims_indices.size() || deleted_dims_indices[j] > i) {
+ if (j == trivial_reshape->deleted_dimensions.size() ||
+ trivial_reshape->deleted_dimensions[j] > i) {
// This is a dimension that was preserved. Take the matching value from
// multidim_.
- while (l < inserted_dims_indices.size() &&
- inserted_dims_indices[l] == k) {
+ while (l < trivial_reshape->inserted_dimensions.size() &&
+ trivial_reshape->inserted_dimensions[l] == k) {
// Skip 1-sized dimensions.
++k;
++l;
diff --git a/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc b/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc
index b5446bf..2a991c7 100644
--- a/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc
+++ b/tensorflow/compiler/xla/service/spmd/canonicalize_all_gather_for_cse.cc
@@ -41,8 +41,8 @@
// Also only do this for degenerate dimension sizes as the additional
// reshaping may not be worth the potential for CSE.
HloInstruction* real_data = ag->mutable_operand(0);
- while (std::get<0>(
- real_data->ReshapeMerelyInsertsOrDeletes1SizedDimensions())) {
+ while (real_data->ReshapeMerelyInsertsOrDeletes1SizedDimensions()
+ .has_value()) {
real_data = real_data->mutable_operand(0);
}
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index ff6ec1b..15d9afd 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1180,15 +1180,12 @@
return new_shape;
}
-/* static */ std::tuple<bool, std::vector<int64_t>, std::vector<int64_t>>
+/* static */ std::optional<ShapeUtil::ShapeEqualityDescriptor>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
const Shape& shape_post) {
CHECK(shape_pre.IsArray());
CHECK(shape_post.IsArray());
- auto nil =
- std::make_tuple(false, std::vector<int64_t>(), std::vector<int64_t>());
-
std::vector<int64_t> deleted_indices;
std::vector<int64_t> inserted_indices;
// Returns false if any input/output index between prior_unmodified_dim_pair
@@ -1234,11 +1231,11 @@
? unmodified_dims[i]
: std::make_pair(shape_pre.rank(), shape_post.rank());
if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
- return nil;
+ return std::nullopt;
}
}
- return std::make_tuple(true, deleted_indices, inserted_indices);
+ return ShapeEqualityDescriptor{deleted_indices, inserted_indices};
}
/* static */ std::vector<std::pair<int64_t, int64_t>>
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 00f1b08..9e686a8 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -505,6 +505,16 @@
static Shape PermuteDimensions(absl::Span<const int64_t> permutation,
const Shape& shape);
+ // Describes how we can go from shape A to shape B by inserting degenerate
+ // 1-sized dimensions in `added_dimensions` and removing degenerate 1-sized
+ // dimensions from B in `removed_dimensions`.
+ //
+ // Only exists if shapes A and B only differ by degenerate dimensions.
+ struct ShapeEqualityDescriptor {
+ std::vector<int64_t> deleted_dimensions;
+ std::vector<int64_t> inserted_dimensions;
+ };
+
// If we can go from `shape_pre` to `shape_post` by merely inserting or
// deleting 1-sized dimensions, return the indices in `shape_pre` of the
// deleted dimensions and the indices in `dims_post` of the inserted
@@ -515,7 +525,7 @@
// `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
// `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
// For another example, if `shape_pre = shape_post = {}`, we return `{}`.
- static std::tuple<bool, std::vector<int64_t>, std::vector<int64_t>>
+ static std::optional<ShapeEqualityDescriptor>
InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
const Shape& shape_post);
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 243e5a8..f645b38 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -571,10 +571,10 @@
Shape shape0 = ShapeUtil::MakeShape(S32, {9, 1, 4});
Shape shape1 = ShapeUtil::MakeShape(S32, {1, 9, 4, 1});
Shape shape2 = ShapeUtil::MakeShape(S32, {3, 1, 12});
- EXPECT_TRUE(std::get<0>(
- ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1)));
- EXPECT_FALSE(std::get<0>(
- ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
+ EXPECT_TRUE(
+ ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1).has_value());
+ EXPECT_FALSE(
+ ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2).has_value());
}
TEST(ShapeUtilTest, ForEachIndex) {