A fix so that a module and its clone produce the same compilation results. Specifically, this fix ensures that fingerprint(compilation(M)) == fingerprint(compilation(clone(M))).
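
A minimal sketch of the intended invariant (illustrative only: the real
fingerprint is taken over the compilation result, but Fingerprint64 over the
canonical HLO text serves as a stand-in here):

  std::unique_ptr<HloModule> clone = module->Clone();
  uint64_t module_fp = tensorflow::Fingerprint64(
      module->ToString(HloPrintOptions::Canonical()));
  uint64_t clone_fp = tensorflow::Fingerprint64(
      clone->ToString(HloPrintOptions::Canonical()));
  CHECK_EQ(module_fp, clone_fp);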

PiperOrigin-RevId: 449909501
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 5182537..a1fc3d7 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -558,6 +558,7 @@
     deps = [
         ":hlo_module_config",
         ":hlo_proto_cc",
+        ":mapped_ptr_container_sorter",
         ":name_uniquer",
         "//tensorflow/compiler/xla:array",
         "//tensorflow/compiler/xla:comparison_util",
@@ -6303,3 +6304,30 @@
         "@absl_py//absl/testing:absltest",
     ] + xla_py_test_deps(),
 )
+
+cc_library(
+    name = "mapped_ptr_container_sorter",
+    hdrs = ["mapped_ptr_container_sorter.h"],
+    deps = [
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:portable_gif_internal",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:statusor",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+tf_cc_test(
+    name = "mapped_ptr_container_sorter_test",
+    srcs = ["mapped_ptr_container_sorter_test.cc"],
+    deps = [
+        ":mapped_ptr_container_sorter",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 606a9d6..b84f72e 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -17,8 +17,10 @@
 
 #include <algorithm>
 #include <cstddef>
+#include <cstdint>
 #include <functional>
 #include <list>
+#include <memory>
 #include <queue>
 #include <set>
 #include <sstream>
@@ -35,9 +37,12 @@
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -1037,6 +1042,86 @@
                                context, suffix);
 }
 
+namespace {
+
+// Sorts unordered_instructions according to the order of ordered_instructions,
+// using MappedPtrContainerSorter. context and replace are used to map
+// instructions in ordered_instructions to instructions in
+// unordered_instructions. Unmapped parameter instructions are placed just after
+// the last parameter instruction in the sorted mapped instruction order. All
+// other unmapped instructions are placed at the end.
+void SortClonedInstructions(
+    const HloCloneContext& context,
+    const std::function<const HloInstruction*(const HloInstruction*)>& replace,
+    const HloComputation& computation,
+    const HloComputation::InstructionList& ordered_instructions,
+    std::vector<std::unique_ptr<HloInstruction>>& unordered_instructions) {
+  using InstructionSorter = MappedPtrContainerSorter<HloInstruction>;
+  InstructionSorter::MapPtrFn instruction_mapper =
+      [&context, &replace](const HloInstruction* i) {
+        return context.FindInstruction(replace(i));
+      };
+  size_t num_mapped_instructions = 0;
+  size_t mapped_index_of_last_parameter_plus_one = 0;
+  for (const auto& instruction : ordered_instructions) {
+    if (!instruction_mapper(instruction.get())) {
+      continue;
+    }
+    ++num_mapped_instructions;
+    if (!dynamic_cast<const HloParameterInstruction*>(instruction.get())) {
+      continue;
+    }
+    mapped_index_of_last_parameter_plus_one = num_mapped_instructions;
+  }
+  InstructionSorter::UnmappedPtrIndexFn unmapped_ptr_index =
+      [num_mapped_instructions,
+       mapped_index_of_last_parameter_plus_one](const HloInstruction* i) {
+        if (dynamic_cast<const HloParameterInstruction*>(i)) {
+          if (num_mapped_instructions > 0 &&
+              mapped_index_of_last_parameter_plus_one > 0) {
+            return mapped_index_of_last_parameter_plus_one - 1;
+          }
+          return InstructionSorter::IndexBeforeMappedElementsFn()(i);
+        }
+        return InstructionSorter::IndexAfterMappedElementsFn()(i);
+      };
+  auto status =
+      InstructionSorter::Sort(instruction_mapper, unmapped_ptr_index,
+                              ordered_instructions, unordered_instructions);
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed to reorder instructions while cloning computation: "
+               << computation.name() << "; " << status;
+  }
+}
+
+// For cloned instructions, sorts their users, control predecessors, and control
+// successors according to the order of those lists in the original
+// instructions, before cloning. context and replace are used both to find each
+// cloned instruction and to map original instructions to their clones when
+// sorting the lists.
+void SortClonedInstructionUsersAndControlLists(
+    const HloCloneContext& context,
+    const std::function<const HloInstruction*(const HloInstruction*)>& replace,
+    const HloComputation::InstructionList& sorted_instructions) {
+  using InstructionSorter = MappedPtrContainerSorter<HloInstruction>;
+  InstructionSorter::MapPtrFn instruction_mapper =
+      [&context, &replace](const HloInstruction* i) {
+        return context.FindInstruction(replace(i));
+      };
+  for (const std::unique_ptr<HloInstruction>& instruction :
+       sorted_instructions) {
+    HloInstruction* cloned_instruction =
+        context.FindInstruction(replace(instruction.get()));
+    if (!cloned_instruction) {
+      continue;
+    }
+    cloned_instruction->SortInstructionUsersAndControlLists(instruction_mapper,
+                                                            *instruction);
+  }
+}
+
+}  // namespace
+
 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
         replacements,
@@ -1131,6 +1216,11 @@
     }
     instructions.push_back(std::move(new_instr));
   }
+
+  // To make clone behavior match uncloned behavior, we reorder instructions to
+  // match the order in instructions_.
+  SortClonedInstructions(*context, replace, *this, instructions_, instructions);
+
   Builder builder(suffix.empty() ? name() : name() + "." + suffix);
   for (auto& instr : instructions) {
     builder.AddInstruction(std::move(instr));
@@ -1151,7 +1241,13 @@
       }
     }
   }
+
+  // To make clone behavior match uncloned behavior, we reorder the user and
+  // control lists, kept by cloned instructions.
+  SortClonedInstructionUsersAndControlLists(*context, replace, instructions_);
+
   context->MapComputation(this, result.get());
+
   return result;
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index e77e9ab..29d571f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -62,6 +62,9 @@
 // f(y), f(z)].)
 class HloComputation {
  public:
+  // Used by instructions_.
+  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
+
   // Builder class for HloComputation.
   class Builder {
    public:
@@ -681,7 +684,6 @@
   // Store instructions in std::list as they can be added and removed
   // arbitrarily and we want a stable iteration order. Keep a map from
   // instruction pointer to location in the list for fast lookup.
-  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
   InstructionList instructions_;
   absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
       instruction_iterators_;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index ca76d28..446683f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -47,6 +47,7 @@
 #include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
+#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
 #include "tensorflow/compiler/xla/service/name_uniquer.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -55,6 +56,7 @@
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/human_readable_json.h"
 #include "tensorflow/core/platform/logging.h"
 
@@ -4389,6 +4391,36 @@
   outer_dimension_partitions_ = outer_dimension_partitions;
 }
 
+void HloInstruction::SortInstructionUsersAndControlLists(
+    const MappedPtrContainerSorter<HloInstruction>::MapPtrFn& map_fn,
+    const HloInstruction& sorted_instruction) {
+  using Sorter = MappedPtrContainerSorter<HloInstruction>;
+  auto status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
+                             sorted_instruction.users_, users_);
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed to sort instruction users for " << name() << "; "
+               << status;
+  }
+  user_map_.clear();
+  for (uint64_t i = 0; i < users_.size(); ++i) {
+    user_map_[users_[i]] = i;
+  }
+  status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
+                        sorted_instruction.control_predecessors_,
+                        control_predecessors_);
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed to sort instruction control predecessors for "
+               << name() << "; " << status;
+  }
+  status =
+      Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
+                   sorted_instruction.control_successors_, control_successors_);
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed to sort instruction control successors for " << name()
+               << "; " << status;
+  }
+}
+
 // TODO(b/80131774): Remove these temporary methods after transition.
 int64_t HloInstruction::feature_index() const {
   return Cast<HloBatchNormInstruction>(this)->feature_index();
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 7df5a2f..44ee604 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -49,6 +49,7 @@
 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
 #include "tensorflow/compiler/xla/service/name_uniquer.h"
 #include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -1836,6 +1837,13 @@
   void set_outer_dimension_partitions(
       const std::vector<int64_t>& outer_dimension_partitions);
 
+  // Sorts users_, control_predecessors_, and control_successors_ according
+  // to the order of the corresponding lists in sorted_instruction. The
+  // sorting is used during cloning, to make clone behavior match uncloned
+  // behavior.
+  void SortInstructionUsersAndControlLists(
+      const MappedPtrContainerSorter<HloInstruction>::MapPtrFn& map_fn,
+      const HloInstruction& sorted_instruction);
+
   // Old methods kept for smooth subclassing transition BEGIN.
   // TODO(b/80131774): Remove this code.
 
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index c390fb2..85624c8 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -16,7 +16,10 @@
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 
 #include <algorithm>
+#include <cstdint>
+#include <functional>
 #include <iterator>
+#include <memory>
 #include <set>
 #include <sstream>
 #include <string>
@@ -33,6 +36,7 @@
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -40,6 +44,7 @@
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/stacktrace.h"
 
 namespace xla {
@@ -849,6 +854,22 @@
     const auto& indices = parameter_indices.second;
     module->AddCrossProgramPrefetch(parameter, indices);
   }
+
+  // To make clone behavior match uncloned behavior, we reorder
+  // module->computations_ to match the order in computations_.
+  using ComputationSorter = MappedPtrContainerSorter<HloComputation>;
+  ComputationSorter::MapPtrFn computation_map_fn =
+      [&context](const HloComputation* c) {
+        return context.FindComputation(c);
+      };
+  auto status = ComputationSorter::Sort(
+      computation_map_fn, ComputationSorter::IndexAfterMappedElementsFn(),
+      computations_, module->computations_);
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed to sort module computations for " << name() << "; "
+               << status;
+  }
+
   return module;
 }
 
diff --git a/tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h b/tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h
new file mode 100644
index 0000000..8e602fb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h
@@ -0,0 +1,456 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Below, we specify an example usage, in which clone is sorted according to
+// original, using map_fn to map from pointers in original to pointers in clone.
+//
+//   std::vector<std::unique_ptr<HloInstruction>> original = ...;
+//   std::vector<std::unique_ptr<HloInstruction>> clone = ...;
+//   HloCloneContext* ctx = ...;
+//   using Sorter = MappedPtrContainerSorter<HloInstruction>;
+//   Sorter::MapPtrFn map_fn = [ctx](const HloInstruction* i) {
+//       return ctx->FindInstruction(i);
+//     };
+//
+//   auto status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
+//                              original, clone);
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_
+
+#include <array>
+#include <cstddef>
+#include <functional>
+#include <limits>
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/statusor.h"
+
+namespace xla {
+
+// A class for sorting an unordered container of pointers according to the sort
+// order of an ordered container of pointers. Sorting is stable.
+//
+// Terminology:
+// - unmapped element: An element from the unordered container that does not
+//   have a corresponding element in the ordered container.
+template <typename PointedToTy>
+class MappedPtrContainerSorter {
+ public:
+  // A function to map elements from an ordered container to elements in an
+  // unordered container. Not every element in ordered_container need map to an
+  // element in unordered_container and vice versa.
+  using MapPtrFn = std::function<const PointedToTy*(const PointedToTy*)>;
+
+  // A function that maps unmapped elements (from an unordered container) to an
+  // index in the final sorted result. The returned index indicates that the
+  // unmapped element should be placed just after the mapped element at that
+  // index, in the result without unmapped elements. See
+  // IndexBeforeMappedElementsFn() and IndexAfterMappedElementsFn() for how to
+  // indicate that an unmapped element should be placed before or after all
+  // mapped elements, respectively. Unmapped elements destined for the same
+  // index will retain their order from the unordered container.
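+  //
+  // For example, if the mapped elements sort to [a, b, c] (mapped indices
+  // 0, 1, 2), an unmapped element whose index function returns 1 is placed
+  // just after b, yielding [a, b, u, c].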
+  using UnmappedPtrIndexFn = std::function<size_t(const PointedToTy*)>;
+
+  // Functions that return an UnmappedPtrIndexFn indicating that unmapped
+  // elements (from an unordered container) should be placed before or after
+  // all mapped elements, respectively.
+  static const UnmappedPtrIndexFn& IndexBeforeMappedElementsFn();
+  static const UnmappedPtrIndexFn& IndexAfterMappedElementsFn();
+
+  // The returned function always returns an invalid index, causing Sort() to
+  // return an error for any unmapped element.
+  static const UnmappedPtrIndexFn& InvalidIndexFn();
+
+  // Sorts an unordered container of pointers according to the order of an
+  // ordered container of pointers. Sorting is stable. Works with raw pointers,
+  // const raw pointers, and unique_ptrs. If an error is returned,
+  // unordered_container is not modified. Returns an error status if:
+  // - unmapped_index() returns an invalid index
+  // - An internal error occurs. (This should theoretically not happen.)
+  template <typename OrderedTy, typename UnorderedTy>
+  static Status Sort(const MapPtrFn& map_ptr,
+                     const UnmappedPtrIndexFn& unmapped_index,
+                     const OrderedTy& ordered_container,
+                     UnorderedTy& unordered_container);
+
+ private:
+  // A class for sorting the indices of the unordered_container.
+  class SortedIndices {
+   public:
+    // max_partial_order_exclusive is 1 greater than the maximum partial order
+    // value allowed to be sent to AddMappedElement().
+    SortedIndices(size_t max_partial_order_exclusive,
+                  size_t unordered_container_size)
+        : max_partial_order_exclusive_(max_partial_order_exclusive),
+          unordered_container_size_(unordered_container_size),
+          mapped_element_indices_by_partial_order_(
+              max_partial_order_exclusive) {}
+
+    // Specify the partial ordering value of a mapped element from the
+    // unordered container. The partial ordering is amongst other mapped
+    // elements.
+    Status AddMappedElement(size_t unordered_container_index,
+                            size_t partial_order);
+
+    // Specify the index (amongst mapped elements) where an unmapped element
+    // should be inserted. The unmapped element is inserted just after the
+    // mapped element with index target_index_amongst_mapped_elements.
+    void AddUnmappedElement(size_t unordered_container_index,
+                            size_t target_index_amongst_mapped_elements);
+
+    std::string ToString() const;
+
+    // The result maps each element in the unordered_container to the target
+    // index that it will occupy in the sorted result.
+    StatusOr<std::vector<size_t>> Flatten() const;
+
+   private:
+    SortedIndices() = delete;
+
+    size_t max_partial_order_exclusive_;
+    size_t unordered_container_size_;
+    std::vector<std::vector<size_t>> mapped_element_indices_by_partial_order_;
+    absl::flat_hash_map<size_t, std::vector<size_t>>
+        target_index_to_unmapped_element_index_;
+  };
+
+  static size_t IndexBeforeMappedElements() {
+    return std::numeric_limits<size_t>::max() - 2;
+  }
+
+  static size_t IndexAfterMappedElements() {
+    return std::numeric_limits<size_t>::max() - 1;
+  }
+
+  static size_t InvalidIndex() { return std::numeric_limits<size_t>::max(); }
+
+  // Returns a mapping in which the element at index i indicates the target
+  // index that unordered_container[i] should occupy in the sorted result.
+  template <typename OrderedTy, typename UnorderedTy>
+  static StatusOr<std::vector<size_t>> ComputeNewIndices(
+      const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index,
+      const OrderedTy& ordered_container,
+      const UnorderedTy& unordered_container);
+
+  // Reorders unordered_container according to the indices in new_indices. See
+  // ComputeNewIndices() for how to interpret new_indices.
+  template <typename UnorderedTy>
+  static void Reorder(std::vector<size_t> new_indices,
+                      UnorderedTy& unordered_container);
+};
+
+///// Template implementation below /////
+
+namespace mapped_ptr_container_sorter_internal {
+
+template <typename I, typename O>
+struct PtrGetter {
+  // Extracts a pointer of type O from i.
+  static O Get(I i);
+};
+
+template <typename T>
+struct PtrGetter<T* const&, const T*> {
+  static const T* Get(T* const& p) { return p; }
+};
+
+template <typename T>
+struct PtrGetter<T const* const&, const T*> {
+  static const T* Get(T const* const& p) { return p; }
+};
+
+template <typename T>
+struct PtrGetter<T*&, T*> {
+  static T* Get(T*& p) { return p; }
+};
+
+template <typename T>
+struct PtrGetter<const std::unique_ptr<T>&, const T*> {
+  static const T* Get(const std::unique_ptr<T>& p) { return p.get(); }
+};
+
+template <typename T>
+struct PtrGetter<std::unique_ptr<T>&, T*> {
+  static T* Get(std::unique_ptr<T>& p) { return p.get(); }
+};
+
+}  // namespace mapped_ptr_container_sorter_internal
+
+template <typename PointedToTy>
+const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn&
+MappedPtrContainerSorter<PointedToTy>::IndexBeforeMappedElementsFn() {
+  static const UnmappedPtrIndexFn* fn = new UnmappedPtrIndexFn(
+      [](const PointedToTy*) { return IndexBeforeMappedElements(); });
+  return *fn;
+}
+
+template <typename PointedToTy>
+const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn&
+MappedPtrContainerSorter<PointedToTy>::IndexAfterMappedElementsFn() {
+  static const UnmappedPtrIndexFn* fn = new UnmappedPtrIndexFn(
+      [](const PointedToTy*) { return IndexAfterMappedElements(); });
+  return *fn;
+}
+
+template <typename PointedToTy>
+const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn&
+MappedPtrContainerSorter<PointedToTy>::InvalidIndexFn() {
+  static const UnmappedPtrIndexFn* fn =
+      new UnmappedPtrIndexFn([](const PointedToTy*) { return InvalidIndex(); });
+  return *fn;
+}
+
+template <typename PointedToTy>
+Status MappedPtrContainerSorter<PointedToTy>::SortedIndices::AddMappedElement(
+    size_t unordered_container_index, size_t partial_order) {
+  if (partial_order >= mapped_element_indices_by_partial_order_.size()) {
+    return InternalErrorStrCat(
+        "invalid partial order: ", partial_order, " v max(",
+        mapped_element_indices_by_partial_order_.size(), ")");
+  }
+
+  mapped_element_indices_by_partial_order_[partial_order].push_back(
+      unordered_container_index);
+  return Status::OK();
+}
+
+template <typename PointedToTy>
+void MappedPtrContainerSorter<PointedToTy>::SortedIndices::AddUnmappedElement(
+    size_t unordered_container_index,
+    size_t target_index_amongst_mapped_elements) {
+  target_index_to_unmapped_element_index_[target_index_amongst_mapped_elements]
+      .push_back(unordered_container_index);
+}
+
+template <typename PointedToTy>
+std::string MappedPtrContainerSorter<PointedToTy>::SortedIndices::ToString()
+    const {
+  std::vector<std::string> mapped_element_strs;
+  mapped_element_strs.reserve(mapped_element_indices_by_partial_order_.size());
+  for (const auto& indices : mapped_element_indices_by_partial_order_) {
+    mapped_element_strs.push_back(
+        absl::StrCat("[", absl::StrJoin(indices, ", "), "]"));
+  }
+  std::vector<std::string> unmapped_element_strs;
+  unmapped_element_strs.reserve(target_index_to_unmapped_element_index_.size());
+  for (const auto& kv : target_index_to_unmapped_element_index_) {
+    std::string key = absl::StrCat(kv.first);
+    if (kv.first == IndexBeforeMappedElements()) {
+      key = "before_mapped";
+    }
+    if (kv.first == IndexAfterMappedElements()) {
+      key = "after_mapped";
+    }
+    if (kv.first == InvalidIndex()) {
+      key = "invalid";
+    }
+    unmapped_element_strs.push_back(
+        absl::StrCat(key, ": [", absl::StrJoin(kv.second, ", "), "]"));
+  }
+
+  return absl::StrCat(
+      "max_partial_order_exclusive_: ", max_partial_order_exclusive_, "\n",
+      "unordered_container_size_: ", unordered_container_size_, "\n",
+      "mapped_element_indices_by_partial_order_: [",
+      absl::StrJoin(mapped_element_strs, ", "), "]\n",
+      "target_index_to_unmapped_element_index_: {",
+      absl::StrJoin(unmapped_element_strs, ", "), "}\n");
+}
+
+template <typename PointedToTy>
+StatusOr<std::vector<size_t>>
+MappedPtrContainerSorter<PointedToTy>::SortedIndices::Flatten() const {
+  std::vector<size_t> result(unordered_container_size_, InvalidIndex());
+  size_t next_available_index = 0;
+  auto next_index_fn = [&]() -> StatusOr<size_t> {
+    if (next_available_index >= unordered_container_size_) {
+      return InternalErrorStrCat(
+          "invalid unordered_container index: ", next_available_index,
+          " v size(", unordered_container_size_, ")");
+    }
+    return next_available_index++;
+  };
+
+  if (target_index_to_unmapped_element_index_.contains(
+          IndexBeforeMappedElements())) {
+    const auto& indices =
+        target_index_to_unmapped_element_index_.at(IndexBeforeMappedElements());
+    for (size_t index : indices) {
+      TF_ASSIGN_OR_RETURN(result[index], next_index_fn());
+    }
+  }
+  size_t num_inserted_mapped_elements = 0;
+  for (const auto& mapped_element_indices :
+       mapped_element_indices_by_partial_order_) {
+    for (size_t mapped_element_index : mapped_element_indices) {
+      TF_ASSIGN_OR_RETURN(result[mapped_element_index], next_index_fn());
+      ++num_inserted_mapped_elements;
+      if (target_index_to_unmapped_element_index_.contains(
+              num_inserted_mapped_elements - 1)) {
+        const auto& unmapped_element_indices =
+            target_index_to_unmapped_element_index_.at(
+                num_inserted_mapped_elements - 1);
+        for (size_t unmapped_element_index : unmapped_element_indices) {
+          TF_ASSIGN_OR_RETURN(result[unmapped_element_index], next_index_fn());
+        }
+      }
+    }
+  }
+  if (target_index_to_unmapped_element_index_.contains(
+          IndexAfterMappedElements())) {
+    const auto& indices =
+        target_index_to_unmapped_element_index_.at(IndexAfterMappedElements());
+    for (size_t index : indices) {
+      TF_ASSIGN_OR_RETURN(result[index], next_index_fn());
+    }
+  }
+
+  // Ensure that every element in unordered_container has a valid new index.
+  absl::flat_hash_set<size_t> used_indices;
+  for (size_t index : result) {
+    if (used_indices.contains(index)) {
+      return InternalErrorStrCat(
+          "2 elements in unordered_container are destined for the same "
+          "index: ",
+          index);
+    }
+    if (index >= unordered_container_size_) {
+      return InvalidArgumentStrCat("invalid unordered_container index: ", index,
+                                   " v size(", unordered_container_size_, ")");
+    }
+    used_indices.insert(index);
+  }
+
+  return result;
+}
+
+template <typename PointedToTy>
+template <typename OrderedTy, typename UnorderedTy>
+StatusOr<std::vector<size_t>>
+MappedPtrContainerSorter<PointedToTy>::ComputeNewIndices(
+    const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index,
+    const OrderedTy& ordered_container,
+    const UnorderedTy& unordered_container) {
+  using UnorderedPtrGetter = mapped_ptr_container_sorter_internal::PtrGetter<
+      typename UnorderedTy::const_reference, const PointedToTy*>;
+  using OrderedPtrGetter = mapped_ptr_container_sorter_internal::PtrGetter<
+      typename OrderedTy::const_reference, const PointedToTy*>;
+
+  if (unordered_container.size() >= IndexBeforeMappedElements()) {
+    return InvalidArgumentStrCat("Unordered container is too large to sort.");
+  }
+
+  // Step 1: build a set of the ptrs in unordered_container
+  absl::flat_hash_set<const PointedToTy*> unordered_ptrs;
+  for (const auto& unordered_element : unordered_container) {
+    const PointedToTy* ptr = UnorderedPtrGetter::Get(unordered_element);
+    unordered_ptrs.insert(ptr);
+  }
+
+  // Step 2: for mapped elements (in unordered_container), create a map from
+  // mapped ptr -> partial ordering
+  absl::flat_hash_map<const PointedToTy*, std::list<size_t>>
+      mapped_ptr_to_partial_order;
+  size_t next_partial_order_value = 0;
+  for (const auto& ordered_element : ordered_container) {
+    const PointedToTy* ordered_ptr = OrderedPtrGetter::Get(ordered_element);
+    const PointedToTy* unordered_ptr = map_ptr(ordered_ptr);
+    if (!unordered_ptr) {
+      // A corresponding unordered element does not exist.
+      continue;
+    }
+    if (!unordered_ptrs.contains(unordered_ptr)) {
+      // The ordered element maps to a pointer, but that pointer is not in our
+      // unordered_container.
+      continue;
+    }
+    mapped_ptr_to_partial_order[unordered_ptr].push_back(
+        next_partial_order_value);
+    ++next_partial_order_value;
+  }
+
+  // Step 3: create sorted unordered element indices
+  SortedIndices result(next_partial_order_value, unordered_container.size());
+  for (size_t i = 0; i < unordered_container.size(); ++i) {
+    const PointedToTy* ptr = UnorderedPtrGetter::Get(unordered_container[i]);
+    if (!mapped_ptr_to_partial_order.contains(ptr)) {
+      // ptr is unmapped
+      result.AddUnmappedElement(i, unmapped_index(ptr));
+      continue;
+    }
+
+    // ptr is mapped
+    //
+    // Potentially, several elements in ordered_container map to ptr.
+    // We assign ptr the index corresponding to the next such ordered element.
+    auto& index_list = mapped_ptr_to_partial_order[ptr];
+    TF_RETURN_IF_ERROR(result.AddMappedElement(i, index_list.front()));
+    // Do not map more than one unordered element to the same index, unless we
+    // have no choice.
+    if (index_list.size() > 1) {
+      // We never remove the last ordered index, in case ptr appears more times
+      // in the unordered_container than in the ordered container.
+      index_list.pop_front();
+    }
+  }
+
+  VLOG(5) << "Pre flatten unordered_container result:\n" << result.ToString();
+  return result.Flatten();
+}
+
+template <typename PointedToTy>
+template <typename UnorderedTy>
+void MappedPtrContainerSorter<PointedToTy>::Reorder(
+    std::vector<size_t> new_indices, UnorderedTy& unordered_container) {
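+  // Follows permutation cycles in place: while the element at old_pos is not
+  // yet at its target position, swap it (and its index entry) into that
+  // target. Each swap settles one element at its final position, so the loop
+  // performs at most new_indices.size() swaps.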
+  size_t old_pos = 0;
+  while (old_pos < new_indices.size()) {
+    size_t new_pos = new_indices[old_pos];
+    if (old_pos == new_pos) {
+      ++old_pos;
+      continue;
+    }
+    std::swap(new_indices[old_pos], new_indices[new_pos]);
+    std::swap(unordered_container[old_pos], unordered_container[new_pos]);
+  }
+}
+
+template <typename PointedToTy>
+template <typename OrderedTy, typename UnorderedTy>
+Status MappedPtrContainerSorter<PointedToTy>::Sort(
+    const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index,
+    const OrderedTy& ordered_container, UnorderedTy& unordered_container) {
+  std::vector<size_t> indices;
+  TF_ASSIGN_OR_RETURN(
+      indices, ComputeNewIndices(map_ptr, unmapped_index, ordered_container,
+                                 unordered_container));
+  Reorder(std::move(indices), unordered_container);
+  return Status::OK();
+}
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_
diff --git a/tensorflow/compiler/xla/service/mapped_ptr_container_sorter_test.cc b/tensorflow/compiler/xla/service/mapped_ptr_container_sorter_test.cc
new file mode 100644
index 0000000..f45d49a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/mapped_ptr_container_sorter_test.cc
@@ -0,0 +1,291 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
+
+#include <cstddef>
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Pointee;
+
+std::vector<std::unique_ptr<std::string>> CreateUniquePtrContainer(
+    const std::vector<std::string>& values) {
+  std::vector<std::unique_ptr<std::string>> container;
+  for (auto value : values) {
+    container.push_back(std::make_unique<std::string>(value));
+  }
+  return container;
+}
+
+class MappedPtrContainerSorterTest : public ::testing::Test {
+ public:
+  using Sorter = MappedPtrContainerSorter<std::string>;
+
+  MappedPtrContainerSorterTest()
+      : map_ptr_(
+            [this](const std::string* ordered) { return MapPtr(ordered); }),
+        ordered_unique_ptrs_(CreateUniquePtrContainer(
+            {"m0", "m1", "m2", "m3", "not_in_unordered"})),
+        unordered_unique_ptrs_(
+            CreateUniquePtrContainer({"m3", "m1", "m0", "m2"})) {
+    for (auto& unique : ordered_unique_ptrs_) {
+      ordered_raw_ptrs_.push_back(unique.get());
+      ordered_const_raw_ptrs_.push_back(unique.get());
+    }
+    for (auto& unique : unordered_unique_ptrs_) {
+      unordered_raw_ptrs_.push_back(unique.get());
+      unordered_const_raw_ptrs_.push_back(unique.get());
+    }
+  }
+
+ protected:
+  const std::string* MapPtr(const std::string* ordered) {
+    for (size_t i = 0; i < unordered_unique_ptrs_.size(); ++i) {
+      if (*ordered == *unordered_unique_ptrs_[i]) {
+        return unordered_unique_ptrs_[i].get();
+      }
+    }
+    return nullptr;
+  }
+
+  // unordered_unique_ptrs_: u0, m3, u1, u2, m1, m0, m2, u3
+  void AddUnmappedElementsToUnorderedUniquePtrs() {
+    unordered_unique_ptrs_.insert(unordered_unique_ptrs_.begin(),
+                                  std::make_unique<std::string>("u0"));
+    unordered_unique_ptrs_.insert(unordered_unique_ptrs_.begin() + 2,
+                                  std::make_unique<std::string>("u1"));
+    unordered_unique_ptrs_.insert(unordered_unique_ptrs_.begin() + 3,
+                                  std::make_unique<std::string>("u2"));
+    unordered_unique_ptrs_.insert(unordered_unique_ptrs_.end(),
+                                  std::make_unique<std::string>("u3"));
+  }
+
+  Sorter::MapPtrFn map_ptr_;
+  std::vector<std::unique_ptr<std::string>> ordered_unique_ptrs_;
+  std::vector<std::unique_ptr<std::string>> unordered_unique_ptrs_;
+  std::vector<std::string*> ordered_raw_ptrs_;
+  std::vector<std::string*> unordered_raw_ptrs_;
+  std::vector<const std::string*> ordered_const_raw_ptrs_;
+  std::vector<const std::string*> unordered_const_raw_ptrs_;
+};
+
+TEST_F(MappedPtrContainerSorterTest, SortUniquePtrs) {
+  TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::InvalidIndexFn(),
+                            ordered_unique_ptrs_, unordered_unique_ptrs_));
+  EXPECT_THAT(
+      unordered_unique_ptrs_,
+      ElementsAre(Pointee(std::string("m0")), Pointee(std::string("m1")),
+                  Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, RawPtrs) {
+  TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::InvalidIndexFn(),
+                            ordered_raw_ptrs_, unordered_raw_ptrs_));
+  EXPECT_THAT(
+      unordered_raw_ptrs_,
+      ElementsAre(Pointee(std::string("m0")), Pointee(std::string("m1")),
+                  Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, ConstRawPtrs) {
+  TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::InvalidIndexFn(),
+                            ordered_const_raw_ptrs_,
+                            unordered_const_raw_ptrs_));
+  EXPECT_THAT(
+      unordered_const_raw_ptrs_,
+      ElementsAre(Pointee(std::string("m0")), Pointee(std::string("m1")),
+                  Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, DifferentContainerTypes) {
+  std::list<std::unique_ptr<std::string>> ordered_ptrs;
+  for (auto& ptr : ordered_unique_ptrs_) {
+    ordered_ptrs.push_back(std::move(ptr));
+  }
+
+  TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::InvalidIndexFn(), ordered_ptrs,
+                            unordered_unique_ptrs_));
+  EXPECT_THAT(
+      unordered_unique_ptrs_,
+      ElementsAre(Pointee(std::string("m0")), Pointee(std::string("m1")),
+                  Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, WithUnmappedPtrsAfterMappedPtrs) {
+  AddUnmappedElementsToUnorderedUniquePtrs();
+
+  TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::IndexAfterMappedElementsFn(),
+                            ordered_unique_ptrs_, unordered_unique_ptrs_));
+  EXPECT_THAT(
+      unordered_unique_ptrs_,
+      ElementsAre(Pointee(std::string("m0")), Pointee(std::string("m1")),
+                  Pointee(std::string("m2")), Pointee(std::string("m3")),
+                  // Unmapped pointers come after mapped ptrs
+                  Pointee(std::string("u0")), Pointee(std::string("u1")),
+                  Pointee(std::string("u2")), Pointee(std::string("u3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, WithUnmappedPtrsBeforeMappedPtrs) {
+  AddUnmappedElementsToUnorderedUniquePtrs();
+
+  TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::IndexBeforeMappedElementsFn(),
+                            ordered_unique_ptrs_, unordered_unique_ptrs_));
+  EXPECT_THAT(unordered_unique_ptrs_,
+              ElementsAre(
+                  // Unmapped pointers come before mapped ptrs
+                  Pointee(std::string("u0")), Pointee(std::string("u1")),
+                  Pointee(std::string("u2")), Pointee(std::string("u3")),
+                  Pointee(std::string("m0")), Pointee(std::string("m1")),
+                  Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, WithUnmappedPtrsInCustomLocations) {
+  Sorter::UnmappedPtrIndexFn unmapped_ptr_index =
+      [](const std::string* s) -> size_t {
+    if (*s == "u0") {
+      return Sorter::IndexAfterMappedElementsFn()(s);
+    }
+    if (*s == "u1") {
+      return 2;
+    }
+    if (*s == "u2") {
+      return 2;
+    }
+    if (*s == "u3") {
+      return Sorter::IndexBeforeMappedElementsFn()(s);
+    }
+    LOG(FATAL) << "We should not be getting an unmapped ptr index for " << *s;
+  };
+  AddUnmappedElementsToUnorderedUniquePtrs();
+
+  TF_EXPECT_OK(Sorter::Sort(map_ptr_, unmapped_ptr_index, ordered_unique_ptrs_,
+                            unordered_unique_ptrs_));
+  EXPECT_THAT(
+      unordered_unique_ptrs_,
+      ElementsAre(
+          Pointee(std::string("u3")),  // unmapped u3 comes before mapped ptrs
+          Pointee(std::string("m0")),  // mapped index 0
+          Pointee(std::string("m1")),  // mapped index 1
+          Pointee(std::string("m2")),  // mapped index 2
+          Pointee(std::string("u1")),  // unmapped u1 comes after mapped index 2
+          Pointee(std::string("u2")),  // unmapped u2 comes after mapped index 2
+          Pointee(std::string("m3")),  // mapped index 3
+          Pointee(std::string("u0"))   // unmapped u0 comes after mapped ptrs
+          ));
+}
+
+TEST_F(MappedPtrContainerSorterTest,
+       ManyOrderedElementsMapToFewUnorderedElements) {
+  std::string* ordered_m1 = nullptr;
+  for (auto ptr : ordered_raw_ptrs_) {
+    if (*ptr == "m1") {
+      ordered_m1 = ptr;
+      break;
+    }
+  }
+  ASSERT_NE(ordered_m1, nullptr);
+  std::string* unordered_m1 = nullptr;
+  for (auto ptr : unordered_raw_ptrs_) {
+    if (*ptr == "m1") {
+      unordered_m1 = ptr;
+      break;
+    }
+  }
+  ASSERT_NE(unordered_m1, nullptr);
+
+  // Add 2 more instances of m1 to the ordered container and 1 more to the
+  // unordered container.
+  ordered_raw_ptrs_.insert(ordered_raw_ptrs_.begin(), ordered_m1);
+  ordered_raw_ptrs_.push_back(ordered_m1);
+  unordered_raw_ptrs_.push_back(unordered_m1);
+
+  TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::IndexBeforeMappedElementsFn(),
+                            ordered_raw_ptrs_, unordered_raw_ptrs_));
+  EXPECT_THAT(
+      unordered_raw_ptrs_,
+      ElementsAre(
+          Pointee(std::string("m1")),  // Corresponds to 1st m1 in ordered
+          Pointee(std::string("m0")),
+          Pointee(std::string("m1")),  // Corresponds to 2nd m1 in ordered
+          Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest,
+       FewOrderedElementsMapToManyUnorderedElements) {
+  std::string* ordered_m1 = nullptr;
+  for (auto ptr : ordered_raw_ptrs_) {
+    if (*ptr == "m1") {
+      ordered_m1 = ptr;
+      break;
+    }
+  }
+  ASSERT_NE(ordered_m1, nullptr);
+  std::string* unordered_m1 = nullptr;
+  for (auto ptr : unordered_raw_ptrs_) {
+    if (*ptr == "m1") {
+      unordered_m1 = ptr;
+      break;
+    }
+  }
+  ASSERT_NE(unordered_m1, nullptr);
+
+  // Add 1 more instance of m1 to the ordered container and 2 more to the
+  // unordered container.
+  ordered_raw_ptrs_.insert(ordered_raw_ptrs_.begin(), ordered_m1);
+  unordered_raw_ptrs_.push_back(unordered_m1);
+  unordered_raw_ptrs_.push_back(unordered_m1);
+
+  TF_EXPECT_OK(Sorter::Sort(map_ptr_, Sorter::IndexBeforeMappedElementsFn(),
+                            ordered_raw_ptrs_, unordered_raw_ptrs_));
+  EXPECT_THAT(
+      unordered_raw_ptrs_,
+      ElementsAre(
+          Pointee(std::string("m1")),  // Corresponds to 1st m1 in ordered
+          Pointee(std::string("m0")),
+          Pointee(std::string("m1")),  // Corresponds to 2nd m1 in ordered
+          Pointee(std::string("m1")),  // Reuse position of 2nd m1 in ordered
+          Pointee(std::string("m2")), Pointee(std::string("m3"))));
+}
+
+TEST_F(MappedPtrContainerSorterTest, InvalidUnmappedIndex) {
+  unordered_unique_ptrs_.push_back(std::make_unique<std::string>("u0"));
+  Sorter::UnmappedPtrIndexFn unmapped_index_fn =
+      [](const std::string* unmapped) -> size_t {
+    if (*unmapped == "u0") {
+      // There are 4 mapped elements, so index 3 is the highest valid index
+      // (excluding special indices).
+      return 4;
+    }
+    return Sorter::IndexBeforeMappedElementsFn()(unmapped);
+  };
+
+  EXPECT_FALSE(Sorter::Sort(map_ptr_, unmapped_index_fn, ordered_unique_ptrs_,
+                            unordered_unique_ptrs_)
+                   .ok());
+}
+
+}  // namespace
+}  // namespace xla