A fix so that a module and its clone produce identical compilation results. This fix ensures that fingerprint(compilation(M)) == fingerprint(compilation(clone(M))).
PiperOrigin-RevId: 449909501
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 5182537..a1fc3d7 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -558,6 +558,7 @@
deps = [
":hlo_module_config",
":hlo_proto_cc",
+ ":mapped_ptr_container_sorter",
":name_uniquer",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:comparison_util",
@@ -6303,3 +6304,30 @@
"@absl_py//absl/testing:absltest",
] + xla_py_test_deps(),
)
+
+cc_library(
+ name = "mapped_ptr_container_sorter",
+ hdrs = ["mapped_ptr_container_sorter.h"],
+ deps = [
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:portable_gif_internal",
+ "//tensorflow/core/platform:errors",
+ "//tensorflow/core/platform:statusor",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cc_test(
+ name = "mapped_ptr_container_sorter_test",
+ srcs = ["mapped_ptr_container_sorter_test.cc"],
+ deps = [
+ ":mapped_ptr_container_sorter",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 606a9d6..b84f72e 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -17,8 +17,10 @@
#include <algorithm>
#include <cstddef>
+#include <cstdint>
#include <functional>
#include <list>
+#include <memory>
#include <queue>
#include <set>
#include <sstream>
@@ -35,9 +37,12 @@
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -1037,6 +1042,86 @@
context, suffix);
}
+namespace {
+
+// Sorts unordered_instructions according to the order of ordered_instructions,
+// using MappedPtrContainerSorter. context and replace are used to map
+// instructions in ordered_instructions to instructions in
+// unordered_instructions. Unmapped parameter instructions are placed just after
+// the last parameter instruction in the sorted mapped instruction order. All
+// other mapped instructions are placed at the end.
+void SortClonedInstructions(
+ const HloCloneContext& context,
+ const std::function<const HloInstruction*(const HloInstruction*)>& replace,
+ const HloComputation& computation,
+ const HloComputation::InstructionList& ordered_instructions,
+ std::vector<std::unique_ptr<HloInstruction>>& unordered_instructions) {
+ using InstructionSorter = MappedPtrContainerSorter<HloInstruction>;
+ InstructionSorter::MapPtrFn instruction_mapper =
+ [&context, &replace](const HloInstruction* i) {
+ return context.FindInstruction(replace(i));
+ };
+ size_t num_mapped_instructions = 0;
+ size_t mapped_index_of_last_parameter_plus_one = 0;
+ for (const auto& instruction : ordered_instructions) {
+ if (!instruction_mapper(instruction.get())) {
+ continue;
+ }
+ ++num_mapped_instructions;
+ if (!dynamic_cast<const HloParameterInstruction*>(instruction.get())) {
+ continue;
+ }
+ mapped_index_of_last_parameter_plus_one = num_mapped_instructions;
+ }
+ InstructionSorter::UnmappedPtrIndexFn unmapped_ptr_index =
+ [num_mapped_instructions,
+ mapped_index_of_last_parameter_plus_one](const HloInstruction* i) {
+ if (dynamic_cast<const HloParameterInstruction*>(i)) {
+ if (num_mapped_instructions > 0 &&
+ mapped_index_of_last_parameter_plus_one > 0) {
+ return mapped_index_of_last_parameter_plus_one - 1;
+ }
+ return InstructionSorter::IndexBeforeMappedElementsFn()(i);
+ }
+ return InstructionSorter::IndexAfterMappedElementsFn()(i);
+ };
+ auto status =
+ InstructionSorter::Sort(instruction_mapper, unmapped_ptr_index,
+ ordered_instructions, unordered_instructions);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed to reorder instructions while cloning computation: "
+ << computation.name() << "; " << status;
+ }
+}
+
+// For cloned instructions, sorts their users, control predecessors, and control
+// successors, according to the orders of those lists in the original
+// instructions, before cloning. context and replace help us to map original
+// instructions to cloned instructions, in addition to creating a list of
+// cloned instructions.
+void SortClonedInstructionUsersAndControlLists(
+ const HloCloneContext& context,
+ const std::function<const HloInstruction*(const HloInstruction*)>& replace,
+ const HloComputation::InstructionList& sorted_instructions) {
+ using InstructionSorter = MappedPtrContainerSorter<HloInstruction>;
+ InstructionSorter::MapPtrFn instruction_mapper =
+ [context, &replace](const HloInstruction* i) {
+ return context.FindInstruction(replace(i));
+ };
+ for (const std::unique_ptr<HloInstruction>& instruction :
+ sorted_instructions) {
+ HloInstruction* cloned_instruction =
+ context.FindInstruction(replace(instruction.get()));
+ if (!cloned_instruction) {
+ continue;
+ }
+ cloned_instruction->SortInstructionUsersAndControlLists(instruction_mapper,
+ *instruction);
+ }
+}
+
+} // namespace
+
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
@@ -1131,6 +1216,11 @@
}
instructions.push_back(std::move(new_instr));
}
+
+ // To make clone behavior match uncloned behavior, we reorder instructions to
+ // match the order in instructions_.
+ SortClonedInstructions(*context, replace, *this, instructions_, instructions);
+
Builder builder(suffix.empty() ? name() : name() + "." + suffix);
for (auto& instr : instructions) {
builder.AddInstruction(std::move(instr));
@@ -1151,7 +1241,13 @@
}
}
}
+
+ // To make clone behavior match uncloned behavior, we reorder the user and
+ // control lists, kept by cloned instructions.
+ SortClonedInstructionUsersAndControlLists(*context, replace, instructions_);
+
context->MapComputation(this, result.get());
+
return result;
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index e77e9ab..29d571f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -62,6 +62,9 @@
// f(y), f(z)].)
class HloComputation {
public:
+ // Used by instructions_.
+ using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
+
// Builder class for HloComputation.
class Builder {
public:
@@ -681,7 +684,6 @@
// Store instructions in std::list as they can be added and removed
// arbitrarily and we want a stable iteration order. Keep a map from
// instruction pointer to location in the list for fast lookup.
- using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
InstructionList instructions_;
absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
instruction_iterators_;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index ca76d28..446683f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -47,6 +47,7 @@
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
+#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -55,6 +56,7 @@
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
@@ -4389,6 +4391,36 @@
outer_dimension_partitions_ = outer_dimension_partitions;
}
+void HloInstruction::SortInstructionUsersAndControlLists(
+ const MappedPtrContainerSorter<HloInstruction>::MapPtrFn& map_fn,
+ const HloInstruction& sorted_instruction) {
+ using Sorter = MappedPtrContainerSorter<HloInstruction>;
+ auto status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
+ sorted_instruction.users_, users_);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed to sort instruction users for " << name() << "; "
+ << status;
+ }
+ user_map_.clear();
+ for (uint64_t i = 0; i < users_.size(); ++i) {
+ user_map_[users_[i]] = i;
+ }
+ status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
+ sorted_instruction.control_predecessors_,
+ control_predecessors_);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed to sort instruction control predecessors for "
+ << name() << "; " << status;
+ }
+ status =
+ Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
+ sorted_instruction.control_successors_, control_successors_);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed to sort instruction control successors for " << name()
+ << "; " << status;
+ }
+}
+
// TODO(b/80131774): Remove these temporary methods after transition.
int64_t HloInstruction::feature_index() const {
return Cast<HloBatchNormInstruction>(this)->feature_index();
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 7df5a2f..44ee604 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -49,6 +49,7 @@
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/types.h"
@@ -1836,6 +1837,13 @@
void set_outer_dimension_partitions(
const std::vector<int64_t>& outer_dimension_partitions);
+ // A method that sorts users_, control_predecessors_, and control_successors_
+ // according to the orders used in sorted_instruction. The sorting is used
+ // during cloning, to make clone behavior match uncloned behavior.
+ void SortInstructionUsersAndControlLists(
+ const MappedPtrContainerSorter<HloInstruction>::MapPtrFn& map_fn,
+ const HloInstruction& sorted_instruction);
+
// Old methods kept for smooth subclassing transition BEGIN.
// TODO(b/80131774): Remove this code.
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index c390fb2..85624c8 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -16,7 +16,10 @@
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include <algorithm>
+#include <cstdint>
+#include <functional>
#include <iterator>
+#include <memory>
#include <set>
#include <sstream>
#include <string>
@@ -33,6 +36,7 @@
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -40,6 +44,7 @@
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stacktrace.h"
namespace xla {
@@ -849,6 +854,22 @@
const auto& indices = parameter_indices.second;
module->AddCrossProgramPrefetch(parameter, indices);
}
+
+ // To make clone behavior match uncloned behavior, we reorder
+ // module->computations_ to match the order in computations_.
+ using ComputationSorter = MappedPtrContainerSorter<HloComputation>;
+ ComputationSorter::MapPtrFn computation_map_fn =
+ [&context](const HloComputation* c) {
+ return context.FindComputation(c);
+ };
+ auto status = ComputationSorter::Sort(
+ computation_map_fn, ComputationSorter::IndexAfterMappedElementsFn(),
+ computations_, module->computations_);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed to sort module computations for " << name() << "; "
+ << status;
+ }
+
return module;
}
diff --git a/tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h b/tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h
new file mode 100644
index 0000000..8e602fb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h
@@ -0,0 +1,456 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Below, we specify an example usage, in which clone is sorted according to
+// original, using map_fn to map from pointers in original to pointers in clone.
+//
+// std::vector<std::unique_ptr<HloInstruction>> original = ...;
+// std::vector<std::unique_ptr<HloInstruction>> clone = ...;
+// HloCloneContext* ctx = ...;
+// using Sorter = MappedPtrContainerSorter<HloInstruction>;
+// Sorter::MapPtrFn map_fn = [ctx](const HloInstruction* i) {
+// return ctx->FindInstruction(i);
+// };
+//
+// auto status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
+// original, clone);
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_
+
+#include <array>
+#include <cstddef>
+#include <functional>
+#include <limits>
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/statusor.h"
+
+namespace xla {
+
+// A class for sorting an unordered container of pointers according to the sort
+// order of an ordered container of pointers. Sorting is stable.
+//
+// Terminology:
+// - unmapped element: An element from the unordered container that does not
+// have a corresponding element in the ordered container.
+template <typename PointedToTy>
+class MappedPtrContainerSorter {
+ public:
+ // A function to map elements from an ordered container to elements in an
+ // unordered container. Not every element in ordered_container need map to an
+ // element in unordered_container and vice versa.
+ using MapPtrFn = std::function<const PointedToTy*(const PointedToTy*)>;
+
+ // A function that maps unmapped elements (from an unordered container) to an
+ // index in the final sorted result. The returned index indicates that the
+ // unmapped element should be placed just after the mapped element at that
+ // index, in the result without unmapped elements. See
+ // IndexBeforeMappedElementsFn() and IndexAfterMappedElementsFn() for how to
+ // indicate that an unmapped element should be placed before or after all
+ // mapped elements, respectively. Unmapped elements destined for the same
+ // index will retain their order from the unordered container.
+ using UnmappedPtrIndexFn = std::function<size_t(const PointedToTy*)>;
+
+  // Functions that return an UnmappedPtrIndexFn that indicates that
+  // unmapped elements (from an unordered container) should be placed before or
+  // after all mapped elements, respectively.
+ static const UnmappedPtrIndexFn& IndexBeforeMappedElementsFn();
+ static const UnmappedPtrIndexFn& IndexAfterMappedElementsFn();
+
+ // Returned function always returns an error.
+ static const UnmappedPtrIndexFn& InvalidIndexFn();
+
+ // Sorts an unordered container of pointers according to the order of an
+ // ordered container of pointers. Sorting is stable. Works with POD pointers,
+ // const POD pointers, and unique_ptrs. If an error is returned,
+ // unordered_container is not modified. Returns an error status if:
+ // - unmapped_index() returns an invalid index
+ // - An internal error occurs. (This should theoretically not happen.)
+ template <typename OrderedTy, typename UnorderedTy>
+ static Status Sort(const MapPtrFn& map_ptr,
+ const UnmappedPtrIndexFn& unmapped_index,
+ const OrderedTy& ordered_container,
+ UnorderedTy& unordered_container);
+
+ private:
+ // A class for sorting the indices of the unordered_container.
+ class SortedIndices {
+ public:
+ // max_partial_order_exclusive is 1 greater than the maximum partial order
+ // value allowed to be sent to AddMappedElement().
+ SortedIndices(size_t max_partial_order_exclusive,
+ size_t unordered_container_size)
+ : max_partial_order_exclusive_(max_partial_order_exclusive),
+ unordered_container_size_(unordered_container_size),
+ mapped_element_indices_by_partial_order_(
+ max_partial_order_exclusive) {}
+
+ // Specify the partial ordering value of a mapped element from the
+ // unordered container. The partial ordering is amongst other mapped
+ // elements.
+ Status AddMappedElement(size_t unordered_container_index,
+ size_t partial_order);
+
+ // Specify the index (amongst mapped elements), where an unmapped element
+ // should be inserted. The unmapped element is inserted just after the
+ // mapped element with index target_index_amongst_mapped_elements.
+ void AddUnmappedElement(size_t unordered_container_index,
+ size_t target_index_amongst_mapped_elements);
+
+ std::string ToString() const;
+
+ // The result maps each element in the unordered_container to the target
+ // index that it will occupy in the sorted result.
+ StatusOr<std::vector<size_t>> Flatten() const;
+
+ private:
+ SortedIndices() = delete;
+
+ size_t max_partial_order_exclusive_;
+ size_t unordered_container_size_;
+ std::vector<std::vector<size_t>> mapped_element_indices_by_partial_order_;
+ absl::flat_hash_map<size_t, std::vector<size_t>>
+ target_index_to_unmapped_element_index_;
+ };
+
+ static size_t IndexBeforeMappedElements() {
+ return std::numeric_limits<size_t>::max() - 2;
+ }
+
+ static size_t IndexAfterMappedElements() {
+ return std::numeric_limits<size_t>::max() - 1;
+ }
+
+ static size_t InvalidIndex() { return std::numeric_limits<size_t>::max(); }
+
+ // Returns a mapping in which the element at index i indicates the target
+ // index that unordered_container[i] should occupy in the sorted result.
+ template <typename OrderedTy, typename UnorderedTy>
+ static StatusOr<std::vector<size_t>> ComputeNewIndices(
+ const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index,
+ const OrderedTy& ordered_container,
+ const UnorderedTy& unordered_container);
+
+ // Reorders unordered_container according to the indices in new_indices. See
+ // ComputeNewIndices() for how to interpret new_indices.
+ template <typename UnorderedTy>
+ static void Reorder(std::vector<size_t> new_indices,
+ UnorderedTy& unordered_container);
+};
+
+///// Template implementation below /////
+
+namespace mapped_ptr_container_sorter_internal {
+
+template <typename I, typename O>
+struct PtrGetter {
+ // Extracts a pointer of type O from i.
+ static O Get(I i);
+};
+
+template <typename T>
+struct PtrGetter<T* const&, const T*> {
+ static const T* Get(T* const& p) { return p; }
+};
+
+template <typename T>
+struct PtrGetter<T const* const&, const T*> {
+ static const T* Get(T const* const& p) { return p; }
+};
+
+template <typename T>
+struct PtrGetter<T*&, T*> {
+ static T* Get(T*& p) { return p; }
+};
+
+template <typename T>
+struct PtrGetter<const std::unique_ptr<T>&, const T*> {
+ static const T* Get(const std::unique_ptr<T>& p) { return p.get(); }
+};
+
+template <typename T>
+struct PtrGetter<std::unique_ptr<T>&, T*> {
+ static T* Get(std::unique_ptr<T>& p) { return p.get(); }
+};
+
+} // namespace mapped_ptr_container_sorter_internal
+
+template <typename PointedToTy>
+const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn&
+MappedPtrContainerSorter<PointedToTy>::IndexBeforeMappedElementsFn() {
+ static const UnmappedPtrIndexFn* fn = new UnmappedPtrIndexFn(
+ [](const PointedToTy*) { return IndexBeforeMappedElements(); });
+ return *fn;
+}
+
+template <typename PointedToTy>
+const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn&
+MappedPtrContainerSorter<PointedToTy>::IndexAfterMappedElementsFn() {
+ static const UnmappedPtrIndexFn* fn = new UnmappedPtrIndexFn(
+ [](const PointedToTy*) { return IndexAfterMappedElements(); });
+ return *fn;
+}
+
+template <typename PointedToTy>
+const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn&
+MappedPtrContainerSorter<PointedToTy>::InvalidIndexFn() {
+ static const UnmappedPtrIndexFn* fn =
+ new UnmappedPtrIndexFn([](const PointedToTy*) { return InvalidIndex(); });
+ return *fn;
+}
+
+template <typename PointedToTy>
+Status MappedPtrContainerSorter<PointedToTy>::SortedIndices::AddMappedElement(
+ size_t unordered_container_index, size_t partial_order) {
+ if (partial_order >= mapped_element_indices_by_partial_order_.size()) {
+ return InternalErrorStrCat(
+ "invalid partial order: ", partial_order, " v max(",
+ mapped_element_indices_by_partial_order_.size(), ")");
+ }
+
+ mapped_element_indices_by_partial_order_[partial_order].push_back(
+ unordered_container_index);
+ return Status::OK();
+}
+
+template <typename PointedToTy>
+void MappedPtrContainerSorter<PointedToTy>::SortedIndices::AddUnmappedElement(
+ size_t unordered_container_index,
+ size_t target_index_amongst_mapped_elements) {
+ target_index_to_unmapped_element_index_[target_index_amongst_mapped_elements]
+ .push_back(unordered_container_index);
+}
+
+template <typename PointedToTy>
+std::string MappedPtrContainerSorter<PointedToTy>::SortedIndices::ToString()
+ const {
+ std::vector<std::string> mapped_element_strs;
+ mapped_element_strs.reserve(mapped_element_indices_by_partial_order_.size());
+ for (const auto& indices : mapped_element_indices_by_partial_order_) {
+ mapped_element_strs.push_back(
+ absl::StrCat("[", absl::StrJoin(indices, ", "), "]"));
+ }
+ std::vector<std::string> unmapped_element_strs;
+ unmapped_element_strs.reserve(target_index_to_unmapped_element_index_.size());
+ for (const auto& kv : target_index_to_unmapped_element_index_) {
+ std::string key = absl::StrCat(kv.first);
+ if (kv.first == IndexBeforeMappedElements()) {
+ key = "before_mapped";
+ }
+ if (kv.first == IndexAfterMappedElements()) {
+ key = "after_mapped";
+ }
+ if (kv.first == InvalidIndex()) {
+ key = "invalid";
+ }
+ unmapped_element_strs.push_back(
+ absl::StrCat(key, ": [", absl::StrJoin(kv.second, ", "), "]"));
+ }
+
+ return absl::StrCat(
+ "max_partial_order_exclusive_: ", max_partial_order_exclusive_, "\n",
+ "unordered_container_size_: ", unordered_container_size_, "\n",
+ "mapped_element_indices_by_partial_order_: [",
+ absl::StrJoin(mapped_element_strs, ", "), "]\n",
+ "target_index_to_unmapped_element_index_: {",
+ absl::StrJoin(unmapped_element_strs, ", "), "}\n");
+}
+
+template <typename PointedToTy>
+StatusOr<std::vector<size_t>>
+MappedPtrContainerSorter<PointedToTy>::SortedIndices::Flatten() const {
+ std::vector<size_t> result(unordered_container_size_, InvalidIndex());
+ size_t next_available_index = 0;
+ auto next_index_fn = [&]() -> StatusOr<size_t> {
+ if (next_available_index >= unordered_container_size_) {
+ return InternalErrorStrCat(
+ "invalid unordered_container index: ", next_available_index,
+ " v size(", unordered_container_size_, ")");
+ }
+ return next_available_index++;
+ };
+
+ if (target_index_to_unmapped_element_index_.contains(
+ IndexBeforeMappedElements())) {
+ const auto& indices =
+ target_index_to_unmapped_element_index_.at(IndexBeforeMappedElements());
+ for (size_t index : indices) {
+ TF_ASSIGN_OR_RETURN(result[index], next_index_fn());
+ }
+ }
+ size_t num_inserted_mapped_elements = 0;
+ for (const auto& mapped_element_indices :
+ mapped_element_indices_by_partial_order_) {
+ for (size_t mapped_element_index : mapped_element_indices) {
+ TF_ASSIGN_OR_RETURN(result[mapped_element_index], next_index_fn());
+ ++num_inserted_mapped_elements;
+ if (target_index_to_unmapped_element_index_.contains(
+ num_inserted_mapped_elements - 1)) {
+ const auto& unmapped_element_indices =
+ target_index_to_unmapped_element_index_.at(
+ num_inserted_mapped_elements - 1);
+ for (size_t unmapped_element_index : unmapped_element_indices) {
+ TF_ASSIGN_OR_RETURN(result[unmapped_element_index], next_index_fn());
+ }
+ }
+ }
+ }
+ if (target_index_to_unmapped_element_index_.contains(
+ IndexAfterMappedElements())) {
+ const auto& indices =
+ target_index_to_unmapped_element_index_.at(IndexAfterMappedElements());
+ for (size_t index : indices) {
+ TF_ASSIGN_OR_RETURN(result[index], next_index_fn());
+ }
+ }
+
+ // Ensure that every element in unordered_container has a valid new index.
+ absl::flat_hash_set<size_t> used_indices;
+ for (size_t index : result) {
+ if (used_indices.contains(index)) {
+ return InternalErrorStrCat(
+ "2 elements in unordered_container are destined for the same "
+ "index: ",
+ index);
+ }
+ if (index >= unordered_container_size_) {
+ return InvalidArgumentStrCat("invalid unordered_container index: ", index,
+ " v size(", unordered_container_size_, ")");
+ }
+ }
+
+ return result;
+}
+
+template <typename PointedToTy>
+template <typename OrderedTy, typename UnorderedTy>
+StatusOr<std::vector<size_t>>
+MappedPtrContainerSorter<PointedToTy>::ComputeNewIndices(
+ const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index,
+ const OrderedTy& ordered_container,
+ const UnorderedTy& unordered_container) {
+ using UnorderedPtrGetter = mapped_ptr_container_sorter_internal::PtrGetter<
+ typename UnorderedTy::const_reference, const PointedToTy*>;
+ using OrderedPtrGetter = mapped_ptr_container_sorter_internal::PtrGetter<
+ typename OrderedTy::const_reference, const PointedToTy*>;
+
+ if (unordered_container.size() >= IndexBeforeMappedElements()) {
+ return InvalidArgumentStrCat("Unordered container is too large to sort.");
+ }
+
+ // Step 1: build a set of the ptrs in unordered_container
+ absl::flat_hash_set<const PointedToTy*> unordered_ptrs;
+ for (const auto& unordered_element : unordered_container) {
+ const PointedToTy* ptr = UnorderedPtrGetter::Get(unordered_element);
+ unordered_ptrs.insert(ptr);
+ }
+
+ // Step 2: for mapped elements (in unordered_container), create a map from
+ // mapped ptr -> partial ordering
+ absl::flat_hash_map<const PointedToTy*, std::list<size_t>>
+ mapped_ptr_to_partial_order;
+ size_t next_partial_order_value = 0;
+ for (const auto& ordered_element : ordered_container) {
+ const PointedToTy* ordered_ptr = OrderedPtrGetter::Get(ordered_element);
+ const PointedToTy* unordered_ptr = map_ptr(ordered_ptr);
+ if (!unordered_ptr) {
+ // A corresponding unordered element does not exist.
+ continue;
+ }
+ if (!unordered_ptrs.contains(unordered_ptr)) {
+ // A pointer exists that maps to the ordered element, but it's not in our
+ // unordered_container.
+ continue;
+ }
+ mapped_ptr_to_partial_order[unordered_ptr].push_back(
+ next_partial_order_value);
+ ++next_partial_order_value;
+ }
+
+ // Step 3: create sorted unordered element indices
+ SortedIndices result(next_partial_order_value, unordered_container.size());
+ for (size_t i = 0; i < unordered_container.size(); ++i) {
+ const PointedToTy* ptr = UnorderedPtrGetter::Get(unordered_container[i]);
+ if (!mapped_ptr_to_partial_order.contains(ptr)) {
+ // ptr is unmapped
+ result.AddUnmappedElement(i, unmapped_index(ptr));
+ continue;
+ }
+
+ // ptr is mapped
+ //
+ // Potentially, several elements in ordered_container map to ptr.
+  // We assign ptr the index corresponding to the next such ordered element.
+ auto& index_list = mapped_ptr_to_partial_order[ptr];
+ TF_RETURN_IF_ERROR(result.AddMappedElement(i, index_list.front()));
+ // Do not map more than one unordered element to the same index, unless we
+ // have no choice.
+ if (index_list.size() > 1) {
+ // We never remove the last ordered index, in case ptr appears in the
+ // unordered_container more times than the ordered container.
+ index_list.pop_front();
+ }
+ }
+
+ VLOG(5) << "Pre flatten unordered_container result:\n" << result.ToString();
+ return result.Flatten();
+}
+
+template <typename PointedToTy>
+template <typename UnorderedTy>
+void MappedPtrContainerSorter<PointedToTy>::Reorder(
+ std::vector<size_t> new_indices, UnorderedTy& unordered_container) {
+ size_t old_pos = 0;
+ while (old_pos < new_indices.size()) {
+ size_t new_pos = new_indices[old_pos];
+ if (old_pos == new_pos) {
+ ++old_pos;
+ continue;
+ }
+ std::swap(new_indices[old_pos], new_indices[new_pos]);
+ std::swap(unordered_container[old_pos], unordered_container[new_pos]);
+ }
+}
+
+template <typename PointedToTy>
+template <typename OrderedTy, typename UnorderedTy>
+Status MappedPtrContainerSorter<PointedToTy>::Sort(
+ const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index,
+ const OrderedTy& ordered_container, UnorderedTy& unordered_container) {
+ std::vector<size_t> indices;
+ TF_ASSIGN_OR_RETURN(
+ indices, ComputeNewIndices(map_ptr, unmapped_index, ordered_container,
+ unordered_container));
+ Reorder(std::move(indices), unordered_container);
+ return Status::OK();
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_
diff --git a/tensorflow/compiler/xla/service/mapped_ptr_container_sorter_test.cc b/tensorflow/compiler/xla/service/mapped_ptr_container_sorter_test.cc
new file mode 100644
index 0000000..f45d49a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/mapped_ptr_container_sorter_test.cc
@@ -0,0 +1,291 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
+
+#include <cstddef>
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Pointee;
+
+std::vector<std::unique_ptr<std::string>> CreateUniquePtrContainer(
+ const std::vector<std::string>& values) {  // Wraps each input value in a heap-allocated string.
+ std::vector<std::unique_ptr<std::string>> container;
+ for (auto value : values) {  // NOTE(review): `auto value` copies each string; harmless for these tiny test fixtures.
+ container.push_back(std::make_unique<std::string>(value));
+ }
+ return container;
+}
+
+class MappedPtrContainerSorterTest : public ::testing::Test {  // Fixture: ordered/unordered containers over the same strings, in unique_ptr, raw, and const-raw forms.
+ public:
+ using Sorter = MappedPtrContainerSorter<std::string>;
+
+ MappedPtrContainerSorterTest()
+ : map_ptr_(
+ [this](const std::string* ordered) { return MapPtr(ordered); }),
+ ordered_unique_ptrs_(CreateUniquePtrContainer(
+ {"m0", "m1", "m2", "m3", "not_in_unordered"})),
+ unordered_unique_ptrs_(
+ CreateUniquePtrContainer({"m3", "m1", "m0", "m2"})) {
+ for (auto& unique : ordered_unique_ptrs_) {  // Raw-pointer views alias the unique_ptr containers.
+ ordered_raw_ptrs_.push_back(unique.get());
+ ordered_const_raw_ptrs_.push_back(unique.get());
+ }
+ for (auto& unique : unordered_unique_ptrs_) {
+ unordered_raw_ptrs_.push_back(unique.get());
+ unordered_const_raw_ptrs_.push_back(unique.get());
+ }
+ }
+
+ protected:
+ const std::string* MapPtr(const std::string* ordered) {  // Maps an ordered element to the unordered element with equal value, or nullptr if none ("unmapped").
+ for (size_t i = 0; i < unordered_unique_ptrs_.size(); ++i) {
+ if (*ordered == *unordered_unique_ptrs_[i]) {
+ return unordered_unique_ptrs_[i].get();
+ }
+ }
+ return nullptr;
+ }
+
+ // unordered_unique_ptrs_ becomes: u0, m3, u1, u2, m1, m0, m2, u3
+ void AddUnmappedElementsToUnorderedUniquePtrs() {
+ unordered_unique_ptrs_.insert(unordered_unique_ptrs_.begin(),
+ std::make_unique<std::string>("u0"));
+ unordered_unique_ptrs_.insert(unordered_unique_ptrs_.begin() + 2,
+ std::make_unique<std::string>("u1"));
+ unordered_unique_ptrs_.insert(unordered_unique_ptrs_.begin() + 3,
+ std::make_unique<std::string>("u2"));
+ unordered_unique_ptrs_.insert(unordered_unique_ptrs_.end(),
+ std::make_unique<std::string>("u3"));
+ }
+
+ Sorter::MapPtrFn map_ptr_;  // Forwards to MapPtr() above.
+ std::vector<std::unique_ptr<std::string>> ordered_unique_ptrs_;
+ std::vector<std::unique_ptr<std::string>> unordered_unique_ptrs_;
+ std::vector<std::string*> ordered_raw_ptrs_;
+ std::vector<std::string*> unordered_raw_ptrs_;
+ std::vector<const std::string*> ordered_const_raw_ptrs_;
+ std::vector<const std::string*> unordered_const_raw_ptrs_;
+};
+
+TEST_F(MappedPtrContainerSorterTest, SortUniquePtrs) {  // All-mapped sort of unique_ptr containers.
+ TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::InvalidIndexFn(),
+ ordered_unique_ptrs_, unordered_unique_ptrs_));
+ EXPECT_THAT(
+ unordered_unique_ptrs_,
+ ElementsAre(Pointee(std::string("m0")), Pointee(std::string("m1")),
+ Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, RawPtrs) {  // Same sort, over non-owning raw pointers.
+ TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::InvalidIndexFn(),
+ ordered_raw_ptrs_, unordered_raw_ptrs_));
+ EXPECT_THAT(
+ unordered_raw_ptrs_,
+ ElementsAre(Pointee(std::string("m0")), Pointee(std::string("m1")),
+ Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, ConstRawPtrs) {  // Same sort, over const raw pointers.
+ TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::InvalidIndexFn(),
+ ordered_const_raw_ptrs_,
+ unordered_const_raw_ptrs_));
+ EXPECT_THAT(
+ unordered_const_raw_ptrs_,
+ ElementsAre(Pointee(std::string("m0")), Pointee(std::string("m1")),
+ Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, DifferentContainerTypes) {  // Ordered container is a std::list; unordered is still a vector.
+ std::list<std::unique_ptr<std::string>> ordered_ptrs;
+ for (auto& ptr : ordered_unique_ptrs_) {
+ ordered_ptrs.push_back(std::move(ptr));
+ }
+
+ TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::InvalidIndexFn(), ordered_ptrs,
+ unordered_unique_ptrs_));
+ EXPECT_THAT(
+ unordered_unique_ptrs_,
+ ElementsAre(Pointee(std::string("m0")), Pointee(std::string("m1")),
+ Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, WithUnmappedPtrsAfterMappedPtrs) {  // Unmapped elements (u0..u3) placed after all mapped elements.
+ AddUnmappedElementsToUnorderedUniquePtrs();
+
+ TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::IndexAfterMappedElementsFn(),
+ ordered_unique_ptrs_, unordered_unique_ptrs_));
+ EXPECT_THAT(
+ unordered_unique_ptrs_,
+ ElementsAre(Pointee(std::string("m0")), Pointee(std::string("m1")),
+ Pointee(std::string("m2")), Pointee(std::string("m3")),
+ // Unmapped pointers come after mapped ptrs
+ Pointee(std::string("u0")), Pointee(std::string("u1")),
+ Pointee(std::string("u2")), Pointee(std::string("u3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, WithUnmappedPtrsBeforeMappedPtrs) {  // Unmapped elements placed before all mapped elements.
+ AddUnmappedElementsToUnorderedUniquePtrs();
+
+ TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::IndexBeforeMappedElementsFn(),
+ ordered_unique_ptrs_, unordered_unique_ptrs_));
+ EXPECT_THAT(unordered_unique_ptrs_,
+ ElementsAre(
+ // Unmapped pointers come before mapped ptrs
+ Pointee(std::string("u0")), Pointee(std::string("u1")),
+ Pointee(std::string("u2")), Pointee(std::string("u3")),
+ Pointee(std::string("m0")), Pointee(std::string("m1")),
+ Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, WithUnmappedPtrsInCustomLocations) {  // Per-element unmapped placement: before, after, or following mapped index k.
+ Sorter::UnmappedPtrIndexFn unmapped_ptr_index =
+ [](const std::string* s) -> size_t {
+ if (*s == "u0") {
+ return Sorter::IndexAfterMappedElementsFn()(s);
+ }
+ if (*s == "u1") {
+ return 2;
+ }
+ if (*s == "u2") {
+ return 2;
+ }
+ if (*s == "u3") {
+ return Sorter::IndexBeforeMappedElementsFn()(s);
+ }
+ LOG(FATAL) << "We should not be getting an unmapped ptr index for " << *s;  // LOG(FATAL) does not return, so no value is needed here.
+ };
+ AddUnmappedElementsToUnorderedUniquePtrs();
+
+ TF_EXPECT_OK(Sorter::Sort(map_ptr_, unmapped_ptr_index, ordered_unique_ptrs_,
+ unordered_unique_ptrs_));
+ EXPECT_THAT(
+ unordered_unique_ptrs_,
+ ElementsAre(
+ Pointee(std::string("u3")), // unmapped u3 comes before mapped ptrs
+ Pointee(std::string("m0")), // mapped index 0
+ Pointee(std::string("m1")), // mapped index 1
+ Pointee(std::string("m2")), // mapped index 2
+ Pointee(std::string("u1")), // unmapped u1 comes after mapped index 2
+ Pointee(std::string("u2")), // unmapped u2 comes after mapped index 2
+ Pointee(std::string("m3")), // mapped index 3
+ Pointee(std::string("u0")) // unmapped u0 comes after mapped ptrs
+ ));
+}
+
+TEST_F(MappedPtrContainerSorterTest,
+ ManyOrderedElementsMapToFewUnorderedElements) {  // 3 ordered m1 occurrences vs 2 unordered: unordered m1s take the 1st and 2nd ordered positions.
+ std::string* ordered_m1 = nullptr;
+ for (auto ptr : ordered_raw_ptrs_) {
+ if (*ptr == "m1") {
+ ordered_m1 = ptr;
+ break;
+ }
+ }
+ ASSERT_NE(ordered_m1, nullptr);
+ std::string* unordered_m1 = nullptr;
+ for (auto ptr : unordered_raw_ptrs_) {
+ if (*ptr == "m1") {
+ unordered_m1 = ptr;
+ break;
+ }
+ }
+ ASSERT_NE(unordered_m1, nullptr);
+
+ // Add 2 more instances of m1 to the ordered container and 1 more to the
+ // unordered container.
+ ordered_raw_ptrs_.insert(ordered_raw_ptrs_.begin(), ordered_m1);
+ ordered_raw_ptrs_.push_back(ordered_m1);
+ unordered_raw_ptrs_.push_back(unordered_m1);
+
+ TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::IndexBeforeMappedElementsFn(),
+ ordered_raw_ptrs_, unordered_raw_ptrs_));
+ EXPECT_THAT(
+ unordered_raw_ptrs_,
+ ElementsAre(
+ Pointee(std::string("m1")), // Corresponds to 1st m1 in ordered
+ Pointee(std::string("m0")),
+ Pointee(std::string("m1")), // Corresponds to 2nd m1 in ordered
+ Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest,
+ FewOrderedElementsMapToManyUnorderedElements) {  // 2 ordered m1 occurrences vs 3 unordered: the last ordered index is reused for the surplus element.
+ std::string* ordered_m1 = nullptr;
+ for (auto ptr : ordered_raw_ptrs_) {
+ if (*ptr == "m1") {
+ ordered_m1 = ptr;
+ break;
+ }
+ }
+ ASSERT_NE(ordered_m1, nullptr);
+ std::string* unordered_m1 = nullptr;
+ for (auto ptr : unordered_raw_ptrs_) {
+ if (*ptr == "m1") {
+ unordered_m1 = ptr;
+ break;
+ }
+ }
+ ASSERT_NE(unordered_m1, nullptr);
+
+ // Add 1 more instances of m1 to the ordered container and 2 more to the
+ // unordered container.
+ ordered_raw_ptrs_.insert(ordered_raw_ptrs_.begin(), ordered_m1);
+ unordered_raw_ptrs_.push_back(unordered_m1);
+ unordered_raw_ptrs_.push_back(unordered_m1);
+
+ TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::IndexBeforeMappedElementsFn(),
+ ordered_raw_ptrs_, unordered_raw_ptrs_));
+ EXPECT_THAT(
+ unordered_raw_ptrs_,
+ ElementsAre(
+ Pointee(std::string("m1")), // Corresponds to 1st m1 in ordered
+ Pointee(std::string("m0")),
+ Pointee(std::string("m1")), // Corresponds to 2nd m1 in ordered
+ Pointee(std::string("m1")), // Reuse position of 2nd m1 in ordered
+ Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, InvalidUnmappedIndex) {  // An out-of-range custom index must make Sort() return an error status.
+ unordered_unique_ptrs_.push_back(std::make_unique<std::string>("u0"));
+ Sorter::UnmappedPtrIndexFn unmapped_index_fn =
+ [](const std::string* unmapped) -> size_t {
+ if (*unmapped == "u0") {
+ // There are 4 mapped elements, so index 3 is the highest valid index,
+ // (excluding special indices)
+ return 4;
+ }
+ return Sorter::IndexBeforeMappedElementsFn()(unmapped);
+ };
+
+ EXPECT_FALSE(Sorter::Sort(map_ptr_, unmapped_index_fn, ordered_unique_ptrs_,
+ unordered_unique_ptrs_)
+ .ok());
+}
+
+} // namespace
+} // namespace xla