[XLA] Clean up and optimize `ShapeTree`.

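ShapeTree nodes are now plain std::pair<ShapeIndex, T> values stored in an
absl::InlinedVector, the index table lives in a small non-templated
internal::IndexTable (new shape_tree.cc) and is built lazily on first lookup,
the ForEach* / Map methods take absl::FunctionRef, and replace_shape_ptr()
takes a const Shape& instead of a Shape*. HloLivenessAnalysis stores its
ShapeTree<bool> values by unique_ptr.

A minimal usage sketch of the updated API (the shape, the values, and the
ShapeTreeSketch wrapper below are illustrative only, not part of this change):

  #include "tensorflow/compiler/xla/shape_tree.h"
  #include "tensorflow/compiler/xla/shape_util.h"

  void ShapeTreeSketch() {
    // A tuple shape (f32[2], (f32[3], f32[4])), used purely for illustration.
    const xla::Shape shape = xla::ShapeUtil::MakeTupleShape(
        {xla::ShapeUtil::MakeShape(xla::F32, {2}),
         xla::ShapeUtil::MakeTupleShape(
             {xla::ShapeUtil::MakeShape(xla::F32, {3}),
              xla::ShapeUtil::MakeShape(xla::F32, {4})})});

    xla::ShapeTree<int> tree(shape, /*init_value=*/0);
    *tree.mutable_element({1, 0}) = 42;

    // Iteration yields std::pair<ShapeIndex, T> nodes in pre-order.
    for (const auto& node : tree) {
      LOG(INFO) << node.first << " = " << node.second;
    }

    // ForEach* and Map now take absl::FunctionRef instead of std::function.
    xla::ShapeTree<bool> nonzero =
        tree.Map<bool>([](const int& value) { return value != 0; });
    CHECK_EQ(nonzero.leaf_count(), 3);
  }

Map() reuses the source tree's shape storage and (lazily built) index table,
so only the per-node payloads are copied.
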
PiperOrigin-RevId: 434889286
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index e7c8411..3641111 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -800,16 +800,17 @@
 
 cc_library(
     name = "shape_tree",
+    srcs = ["shape_tree.cc"],
     hdrs = ["shape_tree.h"],
     visibility = ["//visibility:public"],
     deps = [
         ":shape_util",
         ":status_macros",
-        ":xla_data_proto_cc",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/functional:function_ref",
         "@com_google_absl//absl/types:span",
     ],
 )
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 0a3cad5..d32a1b4 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -16,11 +16,11 @@
 #include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
 
 #include <deque>
+#include <functional>
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
-#include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -41,9 +41,8 @@
 
 void AddToWorklist(const HloInstruction* instruction, Worklist* worklist,
                    Workset* workset) {
-  if (!workset->contains(instruction)) {
+  if (workset->insert(instruction).second) {
     worklist->push_back(instruction);
-    workset->insert(instruction);
     VLOG(3) << "ADD instruction: " << instruction->name();
   }
 }
@@ -67,18 +66,17 @@
                      const ShapeIndex& shape_index,
                      HloLivenessAnalysis::HloIndexMap* live_index_map,
                      Worklist* worklist, Workset* workset) {
-  auto it = live_index_map->find(instruction);
-  if (it == live_index_map->end()) {
-    auto it_added = live_index_map->emplace(
-        std::piecewise_construct, std::forward_as_tuple(instruction),
-        std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
-    it = it_added.first;
+  std::unique_ptr<ShapeTree<bool>>& liveness = (*live_index_map)[instruction];
+  if (liveness == nullptr) {
+    liveness = std::make_unique<ShapeTree<bool>>(instruction->shape(),
+                                                 /*init_value=*/false);
   }
-  if (it->second.element(shape_index) == false) {
+  bool& alive = *liveness->mutable_element(shape_index);
+  if (!alive) {
     AddToWorklist(instruction, worklist, workset);
-    *it->second.mutable_element(shape_index) = true;
+    alive = true;
     VLOG(3) << "MARK instruction: " << instruction->name()
-            << " shape_index: " << shape_index.ToString();
+            << " shape_index: " << shape_index;
   }
 }
 
@@ -87,23 +85,21 @@
                           HloLivenessAnalysis::HloIndexMap* live_index_map,
                           Worklist* worklist, Workset* workset) {
   bool add_to_worklist = false;
-  auto it = live_index_map->find(instruction);
-  if (it == live_index_map->end()) {
-    live_index_map->emplace(
-        std::piecewise_construct, std::forward_as_tuple(instruction),
-        std::forward_as_tuple(instruction->shape(), /*init_value=*/true));
+
+  std::unique_ptr<ShapeTree<bool>>& liveness = (*live_index_map)[instruction];
+  if (liveness == nullptr) {
+    liveness = std::make_unique<ShapeTree<bool>>(instruction->shape(),
+                                                 /*init_value=*/true);
     add_to_worklist = true;
   } else {
-    ShapeUtil::ForEachSubshape(
-        instruction->shape(),
-        [&](const Shape& sub_shape, const ShapeIndex& shape_index) {
-          if (it->second.element(shape_index) == false) {
-            add_to_worklist = true;
-            *it->second.mutable_element(shape_index) = true;
-            VLOG(3) << "MARK instruction: " << instruction->name()
-                    << " shape_index: " << shape_index.ToString();
-          }
-        });
+    for (auto& entry : *liveness) {
+      if (!entry.second) {
+        add_to_worklist = true;
+        entry.second = true;
+        VLOG(3) << "MARK instruction: " << instruction->name()
+                << " shape_index: " << entry.first;
+      }
+    }
   }
   if (add_to_worklist) {
     AddToWorklist(instruction, worklist, workset);
@@ -122,7 +118,7 @@
   CHECK_EQ(instruction->opcode(), HloOpcode::kTuple);
   for (int64_t operand_index = 0; operand_index < instruction->operand_count();
        ++operand_index) {
-    const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+    const ShapeTree<bool>& index_tree = *live_index_map->at(instruction);
     ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
       if (shape_index.empty() || shape_index[0] != operand_index) {
         return;
@@ -152,7 +148,7 @@
   // Mark operand top-level index.
   MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist,
                   workset);
-  const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+  const ShapeTree<bool>& index_tree = *live_index_map->at(instruction);
   // Propagate live shape indices along GTE -> Tuple edge.
   ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
     ShapeIndex operand_shape_index(shape_index);
@@ -171,7 +167,7 @@
     HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
     Workset* workset) {
   CHECK_EQ(instruction->opcode(), HloOpcode::kWhile);
-  const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+  const ShapeTree<bool>& index_tree = *live_index_map->at(instruction);
 
   ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
     // Propagate liveness to while body computation root instruction.
@@ -202,8 +198,7 @@
     for (const CallSite& callsite : call_graph_node.caller_callsites()) {
       if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
         auto* xla_while = callsite.instruction();
-        const ShapeTree<bool>& index_tree =
-            FindOrDie(*live_index_map, instruction);
+        const ShapeTree<bool>& index_tree = *live_index_map->at(instruction);
         ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
           // Propagate liveness to while result{shape_index}
           MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist,
@@ -256,7 +251,7 @@
               // If 'instruction' is a parameter, propagate live shape indices
               // to the associated callsite's argument shape indices.
               const ShapeTree<bool>& index_tree =
-                  FindOrDie(*live_index_map, instruction);
+                  *live_index_map->at(instruction);
               ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
                 MarkLiveAtIndex(caller->operand(operand_index), shape_index,
                                 live_index_map, worklist, workset);
@@ -334,10 +329,8 @@
 
 bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction,
                                  const ShapeIndex& shape_index) const {
-  if (ContainsKey(live_index_map_, instruction)) {
-    return FindOrDie(live_index_map_, instruction).element(shape_index);
-  }
-  return false;
+  auto it = live_index_map_.find(instruction);
+  return (it != live_index_map_.end()) && it->second->element(shape_index);
 }
 
 /* static */
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.h b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h
index c780153..f990f8c 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_
 
+#include <memory>
+
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -37,8 +39,8 @@
 class HloLivenessAnalysis {
  public:
   // Maps from an HloInstruction to its live/dead output shape indices.
-  using HloIndexMap =
-      absl::flat_hash_map<const HloInstruction*, ShapeTree<bool>>;
+  using HloIndexMap = absl::flat_hash_map<const HloInstruction*,
+                                          std::unique_ptr<ShapeTree<bool>>>;
 
   // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object
   // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'.
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 0f2e141..4c7b7d9 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -50,7 +50,7 @@
   // s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_
   // into buffers_, we also need to update this pointer so that buffers_ doesn't
   // point into s.
-  buffers_.replace_shape_ptr(&on_device_shape_);
+  buffers_.replace_shape_ptr(on_device_shape_);
 }
 
 ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
@@ -61,7 +61,7 @@
   // buffers_ has a pointer to its on_device_shape_. When we move s.buffers_
   // into buffers_, we also need to update this pointer so that buffers_ doesn't
   // point into s.
-  buffers_.replace_shape_ptr(&on_device_shape_);
+  buffers_.replace_shape_ptr(on_device_shape_);
   return *this;
 }
 
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
index e708216..5cc6976 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.h
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -91,7 +91,7 @@
   void set_buffers(ShapeTree<se::DeviceMemoryBase> buffers) {
     CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_));
     buffers_ = std::move(buffers);
-    buffers_.replace_shape_ptr(&on_device_shape_);
+    buffers_.replace_shape_ptr(on_device_shape_);
   }
 
   // Reset the shape of this shaped buffer and underlying buffer structure.
@@ -103,7 +103,7 @@
         << ", old: " << on_device_shape_;
     on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape);
     on_device_shape_ = on_device_shape;
-    buffers_.replace_shape_ptr(&on_device_shape_);
+    buffers_.replace_shape_ptr(on_device_shape_);
   }
   // TODO(b/170310047): remove this overload.
   void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
diff --git a/tensorflow/compiler/xla/shape_tree.cc b/tensorflow/compiler/xla/shape_tree.cc
new file mode 100644
index 0000000..33d99bb
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_tree.cc
@@ -0,0 +1,54 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/shape_tree.h"
+
+namespace xla {
+namespace internal {
+
+IndexTable::IndexTable(const Shape& shape) : entries_(1) {
+  size_t next_node_id = 0;
+  CreateEntry(entries_[0], shape, next_node_id);
+}
+
+// TODO(cjfj): Index table cache?
+void IndexTable::CreateEntry(Entry& entry, const Shape& shape,
+                             size_t& next_node_id) {
+  entry.node_id = next_node_id++;
+  if (!shape.IsTuple()) return;
+
+  // The nodes are stored in depth-first pre-order. To look up indices
+  // efficiently, the index table stores each node's children contiguously.
+  size_t children_start_id = entries_.size();
+  entry.children_start_id = children_start_id;
+  // Add entries for all children before recursing, so they are consecutive.
+  entries_.resize(entries_.size() + shape.tuple_shapes_size());
+  for (size_t i = 0; i < shape.tuple_shapes_size(); ++i) {
+    CreateEntry(entries_[children_start_id + i], shape.tuple_shapes(i),
+                next_node_id);
+  }
+}
+
+const IndexTable::Entry& IndexTable::operator[](ShapeIndexView index) const {
+  const Entry* result = &entries_.front();
+  for (int64_t i : index) {
+    CHECK_GE(result->children_start_id, 0);
+    result = &entries_[result->children_start_id + i];
+  }
+  return *result;
+}
+
+}  // namespace internal
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 178e05b..03ccbd2 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -16,19 +16,20 @@
 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
 
+#include <algorithm>
 #include <functional>
 #include <iterator>
 #include <memory>
+#include <type_traits>
+#include <utility>
 #include <vector>
 
-#include "absl/memory/memory.h"
-#include "absl/types/optional.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/functional/function_ref.h"
 #include "absl/types/span.h"
-#include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/iterator_range.h"
 #include "tensorflow/core/platform/logging.h"
@@ -37,41 +38,32 @@
 
 namespace internal {
 
-// Internal representation of each node in a ShapeTree.
-template <typename T>
-struct ShapeTreeNode {
-  // Data corresponding to this node.
-  std::pair<ShapeIndex, T> data;
+class IndexTable {
+ public:
+  // Use indices, rather than pointers, so the index table can be copied
+  // between ShapeTrees.
+  struct Entry {
+    // Index of the node in the nodes vector.
+    size_t node_id;
+    // Index of the first child of this node in the index table (-1 for leaves).
+    std::make_signed_t<size_t> children_start_id = -1;
+  };
 
-  bool is_leaf = true;
+  IndexTable() = default;
+  explicit IndexTable(const Shape& shape);
 
-  explicit ShapeTreeNode(ShapeIndex index)
-      : ShapeTreeNode(std::move(index), T()) {}
-  ShapeTreeNode(ShapeIndex index, T data)
-      : data(std::move(index), std::move(data)) {}
-};
+  bool empty() const { return entries_.empty(); }
 
-// Internal representation of an index table entry.
-struct IndexTableEntry {
-  // Index of the node in the ShapeTreeNode vector.
-  uint32_t index;
-  // Index of the first child in a IndexTableEntry vector. In the index
-  // table all children entries for a given node will be placed next to each
-  // other. This allows us to use a single field to index them.
-  uint32_t children_start;
-#ifndef NDEBUG
-  // Number of children, used for bounds checking.
-  uint32_t children_count;
-#endif
+  const Entry& operator[](ShapeIndexView index) const;
+
+ private:
+  void CreateEntry(Entry& entry, const Shape& shape, size_t& next_node_id);
+
+  absl::InlinedVector<Entry, 1> entries_;
 };
 
 }  // namespace internal
 
-template <typename ContainerType, typename IteratorType, typename ValueType>
-class ShapeTreeIterator;
-template <typename ContainerType, typename IteratorType, typename ValueType>
-class ShapeTreeLeafIterator;
-
 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a
 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
 // in the shape. For array shapes, a ShapeTree trivially holds a single value of
@@ -88,14 +80,22 @@
 //
 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
 // it's helpful not to copy a Shape just to make a ShapeTree.  In these cases,
-// you can pass a Shape* instead of a Shape& to the ShapeTree constructor.  It's
-// then up to you to ensure that the pointed-to Shape doesn't die or mutate
-// before its ShapeTree goes away.
+// you can pass a Shape* instead of a Shape to the ShapeTree constructor.  It's
+// then up to you to ensure that the pointed-to Shape isn't freed, moved or
+// modified before its ShapeTree goes away.
 template <typename T>
 class ShapeTree {
+  template <typename U>
+  friend class ShapeTree;
+
  public:
-  using Node = internal::ShapeTreeNode<T>;
-  using Index = internal::IndexTableEntry;
+  // TODO(cjfj): Don't store ShapeIndex with data. Generate it or cache it?
+  using Node = std::pair<ShapeIndex, T>;
+  using Nodes = absl::InlinedVector<Node, 1>;
+  using IndexTable = internal::IndexTable;
+
+  template <typename Iterator, typename ValueType>
+  class LeafIterator;
 
   // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
   ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
@@ -106,19 +106,23 @@
   // The version that takes a pointer may be cheaper because it doesn't require
   // any Shape copies, but then it's up to you to ensure that the pointer stays
   // alive longer than this ShapeTree.
-  explicit ShapeTree(Shape shape);
-  explicit ShapeTree(const Shape* shape);
-  explicit ShapeTree(const std::shared_ptr<Shape>& shape);
+  explicit ShapeTree(Shape shape)
+      : ShapeTree(std::make_shared<Shape>(std::move(shape))) {}
+
+  explicit ShapeTree(const Shape* shape)
+      : ShapeTree(shape, CreateNodes(*shape)) {}
 
   // Create ShapeTree with the given shape, and init_value for all nodes.
-  ShapeTree(Shape shape, const T& init_value);
-  ShapeTree(const Shape* shape, const T& init_value);
-  ShapeTree(const std::shared_ptr<Shape>& shape, const T& init_value);
+  ShapeTree(Shape shape, const T& init_value)
+      : ShapeTree(std::make_shared<Shape>(std::move(shape)), init_value) {}
+
+  ShapeTree(const Shape* shape, const T& init_value)
+      : ShapeTree(shape, CreateNodes(*shape, [&] { return init_value; })) {}
 
   // Returns the data element associated with the array in the shape at the
   // given index (see ShapeUtil::GetSubshape for how indexes are defined).
-  const T& element(ShapeIndexView index) const;
-  T* mutable_element(ShapeIndexView index);
+  const T& element(ShapeIndexView index) const { return find(index)->second; }
+  T* mutable_element(ShapeIndexView index) { return &find(index)->second; }
 
   // Return the shape represented with this ShapeTree.
   const Shape& shape() const { return *shape_; }
@@ -128,76 +132,57 @@
   // This API replaces the underlying Shape object with the one supplied by
   // the caller, who must ensure that the object remains valid for the whole
   // lifetime of this ShapeTree, and that the Shape is consistent with it.
-  void replace_shape_ptr(const Shape* shape) {
+  void replace_shape_ptr(const Shape& shape) {
     if (shape_storage_ != nullptr) {
-      DCHECK_EQ(*shape, *shape_storage_);
+      DCHECK_EQ(shape, *shape_storage_);
       shape_storage_ = nullptr;
     }
-    shape_ = shape;
+    shape_ = &shape;
   }
 
   // Returns true if the node at the given index is a leaf node (an array
   // shape).
-  bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
+  bool IsLeaf(ShapeIndexView index) const {
+    return Lookup(index).children_start_id == -1;
+  }
 
-  ShapeTree(const ShapeTree&) = default;
-  ShapeTree& operator=(const ShapeTree&) = default;
-  ShapeTree(ShapeTree&&) = default;
-  ShapeTree& operator=(ShapeTree&& other) = default;
+  using iterator = typename Nodes::iterator;
+  using const_iterator = typename Nodes::const_iterator;
+  using reverse_iterator = typename Nodes::reverse_iterator;
+  using const_reverse_iterator = typename Nodes::const_reverse_iterator;
 
-  // iterator implements a bidirectional_iterator with
-  //  value_type = std::pair<ShapeIndex, T>.
-  //
-  // The iteration order is guaranteed to be a pre-order walk of the ShapeTree.
-  using iterator =
-      ShapeTreeIterator<std::vector<Node>, typename std::vector<Node>::iterator,
-                        std::pair<ShapeIndex, T>>;
-  using const_iterator =
-      ShapeTreeIterator<const std::vector<Node>,
-                        typename std::vector<Node>::const_iterator,
-                        const std::pair<ShapeIndex, T>>;
-  using reverse_iterator = std::reverse_iterator<iterator>;
-  using const_reverse_iterator = std::reverse_iterator<const_iterator>;
-
-  using leaf_iterator =
-      ShapeTreeLeafIterator<std::vector<Node>,
-                            typename std::vector<Node>::iterator,
-                            std::pair<ShapeIndex, T>>;
-  using const_leaf_iterator =
-      ShapeTreeLeafIterator<const std::vector<Node>,
-                            typename std::vector<Node>::const_iterator,
-                            const std::pair<ShapeIndex, T>>;
+  using leaf_iterator = LeafIterator<iterator, Node>;
+  using const_leaf_iterator = LeafIterator<const_iterator, const Node>;
   using reverse_leaf_iterator = std::reverse_iterator<leaf_iterator>;
   using const_reverse_leaf_iterator =
       std::reverse_iterator<const_leaf_iterator>;
 
-  // begin/end for iterating over all nodes.
-  iterator begin() { return iterator(&nodes_, nodes_.begin()); }
-  iterator end() { return iterator(&nodes_, nodes_.end()); }
-  const_iterator begin() const {
-    return const_iterator(&nodes_, nodes_.begin());
-  }
-  const_iterator end() const { return const_iterator(&nodes_, nodes_.end()); }
+  iterator begin() { return nodes_.begin(); }
+  iterator end() { return nodes_.end(); }
+  const_iterator begin() const { return nodes_.begin(); }
+  const_iterator end() const { return nodes_.end(); }
 
-  // rbegin/rend for iterating over all nodes in reverse.
-  reverse_iterator rbegin() { return reverse_iterator(end()); }
-  reverse_iterator rend() { return reverse_iterator(begin()); }
-  const_reverse_iterator rbegin() const {
-    return const_reverse_iterator(end());
-  }
-  const_reverse_iterator rend() const {
-    return const_reverse_iterator(begin());
-  }
+  reverse_iterator rbegin() { return nodes_.rbegin(); }
+  reverse_iterator rend() { return nodes_.rend(); }
+  const_reverse_iterator rbegin() const { return nodes_.rbegin(); }
+  const_reverse_iterator rend() const { return nodes_.rend(); }
 
   // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
-  // children).
-  leaf_iterator leaf_begin() { return leaf_iterator(&nodes_, nodes_.begin()); }
-  leaf_iterator leaf_end() { return leaf_iterator(&nodes_, nodes_.end()); }
+  // children). We check for empty tuples here, so the leaf iterator's `IsLeaf`
+  // check is simpler.
+  leaf_iterator leaf_begin() {
+    return ShapeUtil::IsEmptyTuple(*shape_)
+               ? leaf_end()
+               : leaf_iterator(nodes_, nodes_.begin());
+  }
+  leaf_iterator leaf_end() { return leaf_iterator(nodes_, nodes_.end()); }
   const_leaf_iterator leaf_begin() const {
-    return const_leaf_iterator(&nodes_, nodes_.begin());
+    return ShapeUtil::IsEmptyTuple(*shape_)
+               ? leaf_end()
+               : const_leaf_iterator(nodes_, nodes_.begin());
   }
   const_leaf_iterator leaf_end() const {
-    return const_leaf_iterator(&nodes_, nodes_.end());
+    return const_leaf_iterator(nodes_, nodes_.end());
   }
   // range-based iterator for leaf_begin()/leaf_end().
   tensorflow::gtl::iterator_range<leaf_iterator> leaves() {
@@ -223,108 +208,152 @@
   // Returns an iterator pointing to the given ShapeIndex.
   // REQUIRES: index must exist in the ShapeTree.
   iterator find(ShapeIndexView index) {
-    Node* element = Lookup(index);
-    auto element_iter = nodes_.begin() + (element - &nodes_[0]);
-    return iterator(&nodes_, element_iter);
+    return nodes_.begin() + Lookup(index).node_id;
   }
   const_iterator find(ShapeIndexView index) const {
-    const Node* element = Lookup(index);
-    auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
-    return const_iterator(&nodes_, element_iter);
+    return nodes_.begin() + Lookup(index).node_id;
   }
 
   // Returns the number of leaf nodes in the tree.
   int64_t leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }
 
+  // TODO(cjfj): Remove the `ForEach...` methods. They are redundant.
   // Recursively traverses the shape and calls the given function at each
-  // element. The function has the following arguments:
-  //
-  //   Fn :    A callable of type void(const ShapeIndex& index, const T& data)
-  //           (or compatible).
-  //   index : the index of the element in the shape. See ShapeUtil::GetSubshape
-  //           for definition of index.
-  //   data : The data value at this element.
-  template <typename Fn>
-  void ForEachElement(const Fn& func) const;
+  // element.
+  void ForEachElement(
+      absl::FunctionRef<void(const ShapeIndex&, const T&)> func) const {
+    for (const Node& node : nodes_) {
+      func(node.first, node.second);
+    }
+  }
 
-  // Like ForEachElement, but the callable has type
-  //
-  //   void (const ShapeIndex& index, T* data).
-  //
-  template <typename Fn>
-  void ForEachMutableElement(const Fn& func);
+  void ForEachMutableElement(
+      absl::FunctionRef<void(const ShapeIndex&, T*)> func) {
+    for (Node& node : nodes_) {
+      func(node.first, &node.second);
+    }
+  }
 
   // Like ForEach(Mutable)Element, but the callable returns a Status instead of
   // void.  The first non-OK return value is returned by the ForEach* function.
-  template <typename Fn>
-  Status ForEachElementWithStatus(const Fn& func) const;
-  template <typename Fn>
-  Status ForEachMutableElementWithStatus(const Fn& func);
+  Status ForEachElementWithStatus(
+      absl::FunctionRef<Status(const ShapeIndex&, const T&)> func) const {
+    for (const Node& node : nodes_) {
+      TF_RETURN_IF_ERROR(func(node.first, node.second));
+    }
+    return Status::OK();
+  }
+
+  Status ForEachMutableElementWithStatus(
+      absl::FunctionRef<Status(const ShapeIndex&, T*)> func) {
+    for (Node& node : nodes_) {
+      TF_RETURN_IF_ERROR(func(node.first, &node.second));
+    }
+    return Status::OK();
+  }
 
   // Maps each element to generate a new tree with the same shape.
   template <typename U>
-  ShapeTree<U> Map(const std::function<U(const T&)>& func) {
-    ShapeTree<U> result(shape_storage_);
-    ForEachElement([&](const ShapeIndex& index, const T& t) {
-      *result.mutable_element(index) = func(t);
-    });
+  ShapeTree<U> Map(absl::FunctionRef<U(const T&)> func) {
+    typename ShapeTree<U>::Nodes result_nodes;
+    result_nodes.reserve(nodes_.size());
+    for (const Node& node : nodes_) {
+      result_nodes.push_back({node.first, func(node.second)});
+    }
+
+    ShapeTree<U> result(shape_, std::move(result_nodes));
+    result.index_table_ = index_table_;
+    result.shape_storage_ = shape_storage_;
     return result;
   }
 
-  template <typename U>
-  ShapeTree<U> Map(const std::function<U(T*)>& func) {
-    ShapeTree<U> result(shape_storage_);
-    ForEachMutableElement([&](const ShapeIndex& index, T* t) {
-      *result.mutable_element(index) = func(t);
-    });
-    return result;
-  }
-
-  // Copy the subtree of values from 'other' rooted at ShapeIndex
-  // 'source_base_index' into the subtree of value in this ShapeTree rooted at
-  // 'target_base_index'.
+  // Copy the subtree of values from 'other' rooted at ShapeIndex 'src_index'
+  // into the subtree of values in this ShapeTree rooted at 'dst_index'.
   //
-  // Precondition: The subshape of other.shape() at index source_base_index must
-  // be compatible with the subshape of shape() at index target_base_index.
-  void CopySubtreeFrom(const ShapeTree<T>& other,
-                       const ShapeIndex& source_base_index,
-                       const ShapeIndex& target_base_index);
+  // Precondition: The subshape of other.shape() at index src_index must be
+  // compatible with the subshape of shape() at index dst_index.
+  void CopySubtreeFrom(const ShapeTree<T>& other, const ShapeIndex& src_index,
+                       const ShapeIndex& dst_index) {
+    const Shape& src_shape = ShapeUtil::GetSubshape(other.shape(), src_index);
+    const Shape& dst_shape = ShapeUtil::GetSubshape(shape(), dst_index);
+    CHECK(ShapeUtil::Compatible(src_shape, dst_shape))
+        << src_shape << ", " << dst_shape;
 
-  StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const;
+    // Replace the prefix `src_index` with `dst_index`.
+    auto replace_shape_index_prefix = [&](const ShapeIndex& index) {
+      auto without_prefix = ShapeIndexView(index).subspan(src_index.size());
+      ShapeIndex result;
+      result.reserve(dst_index.size() + without_prefix.size());
+      result.insert(result.end(), dst_index.begin(), dst_index.end());
+      result.insert(result.end(), without_prefix.begin(), without_prefix.end());
+      return result;
+    };
 
-  bool operator==(const ShapeTree<T>& other) const;
+    auto first = other.find(src_index);
+    auto last = first + ShapeUtil::SubshapeCount(src_shape);
+
+    std::transform(first, last, find(dst_index), [&](const Node& node) -> Node {
+      return {replace_shape_index_prefix(node.first), node.second};
+    });
+  }
+
+  StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const {
+    TF_ASSIGN_OR_RETURN(const Shape* sub_shape,
+                        ShapeUtil::TryGetSubshape(shape(), index));
+    size_t count = ShapeUtil::SubshapeCount(*sub_shape);
+    Nodes sub_tree_nodes;
+    sub_tree_nodes.reserve(count);
+    for (auto it = find(index), end = it + count; it != end; ++it) {
+      // For each shape index, remove the prefix `index`.
+      auto without_prefix = ShapeIndexView(it->first).subspan(index.size());
+      sub_tree_nodes.push_back(Node{without_prefix, it->second});
+    }
+    return ShapeTree(sub_shape, std::move(sub_tree_nodes));
+  }
+
+  bool operator==(const ShapeTree<T>& other) const {
+    return nodes_ == other.nodes_;
+  }
   bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
 
  private:
-  // Initialize node->children based on 'shape'. All children are assigned the
-  // the given 'init_value'.
-  void InitChildren(const Shape& shape, const T& init_value, Node* node,
-                    Index* index);
+  explicit ShapeTree(std::shared_ptr<Shape> shape) : ShapeTree(shape.get()) {
+    shape_storage_.swap(shape);
+  }
 
-  // Initialize node->children based on 'shape'. All children have
-  // default-constructed data values.
-  void InitChildren(const Shape& shape, Node* node, Index* index);
+  ShapeTree(std::shared_ptr<Shape> shape, const T& init_value)
+      : ShapeTree(shape.get(), init_value) {
+    shape_storage_.swap(shape);
+  }
 
-  // Returns the number of subshapes, including interior nodes, in shape.
-  int64_t CountSubshapes(const Shape& shape);
+  ShapeTree(const Shape* shape, Nodes nodes)
+      : nodes_(std::move(nodes)), shape_(shape) {
+    DCHECK_EQ(nodes_.size(), ShapeUtil::SubshapeCount(*shape));
+  }
 
-  // Helpers for traversing the shape via ForEachElement. The helpers
-  // recursively traverse the subtree rooted at "index" (defined as in
-  // ShapeUtil::GetSubshape).
-  template <typename Fn>
-  static Status ForEachHelper(const Fn& func, const std::vector<Node>& nodes);
-  template <typename Fn>
-  static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
+  static Nodes CreateNodes(
+      const Shape& shape, absl::FunctionRef<T()> gen = [] { return T(); }) {
+    Nodes nodes;
+    ShapeUtil::ForEachSubshape(shape,
+                               [&](const Shape&, const ShapeIndex& index) {
+                                 nodes.push_back({index, gen()});
+                               });
+    return nodes;
+  }
 
-  // Return the tree node at the given index.
-  Node* Lookup(ShapeIndexView index);
-  const Node* Lookup(ShapeIndexView index) const;
+  // Returns the index table entry for the given shape index.
+  const IndexTable::Entry& Lookup(ShapeIndexView index) const {
+    // The index table is evaluated lazily.
+    if (index_table_.empty()) index_table_ = IndexTable(*shape_);
+    return index_table_[index];
+  }
 
   // The nodes in this shape tree.
-  std::vector<Node> nodes_;
+  Nodes nodes_;
 
-  // Index table for node lookups.
-  std::vector<Index> index_table_;
+  // Index table for node lookups. Each entry holds its node's id and the
+  // position of the node's first child entry (-1 for leaves). Built lazily.
+  mutable IndexTable index_table_;
 
   // If we own our Shape, this field contains it, and shape_ is a pointer into
   // here.  Otherwise if we don't own our shape, this is nullptr.
@@ -335,399 +364,60 @@
   const Shape* shape_;
 };
 
-// Internal iterator that performs a pre-order walk. This is cheap to copy.
-// The iterator value_type is equivalent to a
-// std::pair<ShapeIndex,T>&, similar to std::map.
-template <typename ContainerType, typename IteratorType, typename ValueType>
-class ShapeTreeIterator
-    : public std::iterator<std::bidirectional_iterator_tag, ValueType> {
- public:
-  ShapeTreeIterator(ContainerType* nodes, IteratorType node)
-      : nodes_(nodes), node_(std::move(node)) {}
-
-  ShapeTreeIterator& operator++() {
-    ++node_;
-    return *this;
-  }
-  ShapeTreeIterator operator++(int) {
-    auto i = *this;
-    ++(*this);
-    return i;
-  }
-
-  ShapeTreeIterator& operator--() {
-    --node_;
-    return *this;
-  }
-  ShapeTreeIterator operator--(int) {
-    auto i = *this;
-    --(*this);
-    return i;
-  }
-
-  bool operator==(const ShapeTreeIterator& other) const {
-    return node_ == other.node_;
-  }
-  bool operator!=(const ShapeTreeIterator& other) const {
-    return node_ != other.node_;
-  }
-  ValueType& operator*() const { return node_->data; }
-  ValueType* operator->() const { return &node_->data; }
-
- private:
-  ContainerType* nodes_;
-  IteratorType node_;
-};
-
 // Internal iterator that performs a pre-order walk of the leaves. This is cheap
 // to copy. The iterator value_type is equivalent to a std::pair<ShapeIndex,T>&,
 // similar to std::map.
-template <typename ContainerType, typename IteratorType, typename ValueType>
-class ShapeTreeLeafIterator
+template <typename T>
+template <typename Iterator, typename ValueType>
+class ShapeTree<T>::LeafIterator
     : public std::iterator<std::bidirectional_iterator_tag, ValueType> {
  public:
-  ShapeTreeLeafIterator(ContainerType* nodes, IteratorType node)
-      : nodes_(nodes), node_(std::move(node)) {
-    while (node_ != nodes_->end() && !node_->is_leaf) {
-      ++node_;
-    }
+  LeafIterator(const Nodes& nodes, Iterator it) : nodes_(nodes), it_(it) {
+    while ((it_ != nodes.end()) && !IsLeaf()) ++it_;
   }
 
-  ShapeTreeLeafIterator& operator++() {
-    ++node_;
-    while (node_ != nodes_->end() && !node_->is_leaf) {
-      ++node_;
-    }
+  LeafIterator& operator++() {
+    do {
+      ++it_;
+    } while ((it_ != nodes_.end()) && !IsLeaf());
     return *this;
   }
-  ShapeTreeLeafIterator operator++(int) {
-    auto i = *this;
+
+  LeafIterator operator++(int) {
+    auto prev = *this;
     ++(*this);
-    return i;
+    return prev;
   }
 
-  ShapeTreeLeafIterator& operator--() {
-    --node_;
-    while (node_ > nodes_->begin() && !node_->is_leaf) {
-      --node_;
-    }
+  LeafIterator& operator--() {
+    do {
+      --it_;
+    } while ((it_ != nodes_.begin()) && !IsLeaf());
     return *this;
   }
-  ShapeTreeLeafIterator operator--(int) {
-    auto i = *this;
+
+  LeafIterator operator--(int) {
+    auto prev = *this;
     --(*this);
-    return i;
+    return prev;
   }
 
-  bool operator==(const ShapeTreeLeafIterator& other) const {
-    return node_ == other.node_;
-  }
-  bool operator!=(const ShapeTreeLeafIterator& other) const {
-    return node_ != other.node_;
-  }
-  ValueType& operator*() const { return node_->data; }
-  ValueType* operator->() const { return &node_->data; }
+  bool operator==(const LeafIterator& other) const { return it_ == other.it_; }
+  bool operator!=(const LeafIterator& other) const { return !(*this == other); }
+  ValueType& operator*() const { return *it_; }
+  ValueType* operator->() const { return &*it_; }
 
  private:
-  ContainerType* nodes_;
-  IteratorType node_;
+  bool IsLeaf() const {
+    auto next = it_ + 1;
+    // If the node is not a leaf, the next node will have a longer shape index.
+    return (next == nodes_.end()) || (it_->first.size() >= next->first.size());
+  }
+
+  const Nodes& nodes_;
+  Iterator it_;
 };
 
-template <typename T>
-int64_t ShapeTree<T>::CountSubshapes(const Shape& shape) {
-  int64_t current_count = 1;
-  if (shape.IsTuple()) {
-    int64_t count = ShapeUtil::TupleElementCount(shape);
-    for (int i = 0; i < count; ++i) {
-      current_count += CountSubshapes(shape.tuple_shapes(i));
-    }
-  }
-  return current_count;
-}
-
-template <typename T>
-void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
-                                Node* node, Index* index) {
-  if (shape.IsTuple()) {
-    const int64_t size = ShapeUtil::TupleElementCount(shape);
-#ifndef NDEBUG
-    index->children_count = size;
-#endif
-    node->is_leaf = false;
-    ShapeIndex shape_index = node->data.first;
-    shape_index.push_back(0);
-
-    // At the end of the index_table, reserve a continuous space to hold the
-    // children of current node. In order to enforce the invariant that all
-    // children of a given node are placed together, we need to do the
-    // reservation before we recurse into any of its children.
-    int64_t children_start_position = index_table_.size();
-    index_table_.resize(index_table_.size() + size);
-
-    for (int i = 0; i < size; ++i) {
-      shape_index[shape_index.size() - 1] = i;
-      index_table_[children_start_position + i].index = nodes_.size();
-      // The first child of the node in the index table is placed at the end of
-      // the table.
-      index_table_[children_start_position + i].children_start =
-          index_table_.size();
-      nodes_.emplace_back(shape_index, init_value);
-      InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
-                   &index_table_[children_start_position + i]);
-    }
-  } else {
-#ifndef NDEBUG
-    index->children_count = 0;
-#endif
-  }
-}
-
-template <typename T>
-void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
-  if (shape.IsTuple()) {
-    const int64_t size = ShapeUtil::TupleElementCount(shape);
-#ifndef NDEBUG
-    index->children_count = size;
-#endif
-    node->is_leaf = false;
-    ShapeIndex shape_index = node->data.first;
-    shape_index.push_back(0);
-
-    // At the end of the index_table, reserve a continuous space to hold the
-    // children of current node. In order to enforce the invariant that all
-    // children of a given node are placed together, we need to do the
-    // reservation before we recurse into any of its children.
-    int64_t children_start_position = index_table_.size();
-    index_table_.resize(index_table_.size() + size);
-
-    for (int i = 0; i < size; ++i) {
-      shape_index[shape_index.size() - 1] = i;
-      index_table_[children_start_position + i].index = nodes_.size();
-      // The first child of the node in the index table is placed at the end of
-      // the table.
-      index_table_[children_start_position + i].children_start =
-          index_table_.size();
-      nodes_.emplace_back(shape_index);
-      InitChildren(shape.tuple_shapes(i), &nodes_.back(),
-                   &index_table_[children_start_position + i]);
-    }
-  } else {
-#ifndef NDEBUG
-    index->children_count = 0;
-#endif
-  }
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(Shape shape)
-    : shape_storage_(std::make_shared<Shape>(std::move(shape))),
-      shape_(shape_storage_.get()) {
-  const int64_t count = CountSubshapes(*shape_);
-  nodes_.reserve(count);
-  nodes_.emplace_back(ShapeIndex{});
-
-  index_table_.reserve(count);
-  index_table_.emplace_back(Index{0, 1});
-  InitChildren(*shape_, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
-  const int64_t count = CountSubshapes(*shape_);
-  nodes_.reserve(count);
-  nodes_.emplace_back(ShapeIndex{});
-
-  index_table_.reserve(count);
-  index_table_.emplace_back(Index{0, 1});
-  InitChildren(*shape_, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
-    : shape_storage_(shape), shape_(shape_storage_.get()) {
-  const int64_t count = CountSubshapes(*shape_);
-  nodes_.reserve(count);
-  nodes_.emplace_back(ShapeIndex{});
-
-  index_table_.reserve(count);
-  index_table_.emplace_back(Index{0, 1});
-  InitChildren(*shape_, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
-    : shape_storage_(std::make_shared<Shape>(std::move(shape))),
-      shape_(shape_storage_.get()) {
-  const int64_t count = CountSubshapes(*shape_);
-  nodes_.reserve(count);
-  nodes_.emplace_back(ShapeIndex{}, init_value);
-
-  index_table_.reserve(count);
-  index_table_.emplace_back(Index{0, 1});
-  InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
-    : shape_(shape) {
-  const int64_t count = CountSubshapes(*shape_);
-  nodes_.reserve(count);
-  nodes_.emplace_back(ShapeIndex{}, init_value);
-
-  index_table_.reserve(count);
-  index_table_.emplace_back(Index{0, 1});
-  InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
-                        const T& init_value)
-    : shape_storage_(shape), shape_(shape_storage_.get()) {
-  const int64_t count = CountSubshapes(*shape_);
-  nodes_.reserve(count);
-  nodes_.emplace_back(ShapeIndex{}, init_value);
-
-  index_table_.reserve(count);
-  index_table_.emplace_back(Index{0, 1});
-  InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-const T& ShapeTree<T>::element(ShapeIndexView index) const {
-  return Lookup(index)->data.second;
-}
-
-template <typename T>
-T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
-  return &Lookup(index)->data.second;
-}
-
-template <typename T>
-internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
-  Index* iter = &index_table_[0];
-  for (const int64_t i : index) {
-    CHECK_GE(i, 0);
-#ifndef NDEBUG
-    CHECK_LT(i, iter->children_count);
-#endif
-    iter = &index_table_[iter->children_start + i];
-  }
-
-  return &nodes_[iter->index];
-}
-
-template <typename T>
-const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
-    ShapeIndexView index) const {
-  return const_cast<ShapeTree*>(this)->Lookup(index);
-}
-
-/* static */
-template <typename T>
-template <typename Fn>
-Status ShapeTree<T>::ForEachHelper(const Fn& func,
-                                   const std::vector<Node>& nodes) {
-  for (const auto& node : nodes) {
-    TF_RETURN_IF_ERROR(func(node.data.first, node.data.second));
-  }
-  return Status::OK();
-}
-
-/* static */
-template <typename T>
-template <typename Fn>
-Status ShapeTree<T>::ForEachMutableHelper(const Fn& func,
-                                          std::vector<Node>* nodes) {
-  for (auto& node : *nodes) {
-    TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second));
-  }
-  return Status::OK();
-}
-
-template <typename T>
-template <typename Fn>
-Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const {
-  return ForEachHelper(func, nodes_);
-}
-
-template <typename T>
-template <typename Fn>
-Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) {
-  return ForEachMutableHelper(func, &nodes_);
-}
-
-template <typename T>
-template <typename Fn>
-void ShapeTree<T>::ForEachElement(const Fn& func) const {
-  return ForEachHelper(
-             [&func](const ShapeIndex& index, const T& data) {
-               func(index, data);
-               return Status::OK();
-             },
-             nodes_)
-      .IgnoreError();
-}
-
-template <typename T>
-template <typename Fn>
-void ShapeTree<T>::ForEachMutableElement(const Fn& func) {
-  return ForEachMutableHelper(
-             [&func](const ShapeIndex& index, T* data) {
-               func(index, data);
-               return Status::OK();
-             },
-             &nodes_)
-      .IgnoreError();
-}
-
-template <typename T>
-void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
-                                   const ShapeIndex& source_base_index,
-                                   const ShapeIndex& target_base_index) {
-  CHECK(ShapeUtil::Compatible(
-      ShapeUtil::GetSubshape(shape(), target_base_index),
-      ShapeUtil::GetSubshape(other.shape(), source_base_index)))
-      << ShapeUtil::GetSubshape(shape(), target_base_index) << " vs "
-      << ShapeUtil::GetSubshape(other.shape(), source_base_index);
-  ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
-                            const ShapeIndex& index, T* data) {
-    // Copy the data element only if index is in the
-    // subtree rooted at target_base_index.
-    for (int i = 0; i < target_base_index.size(); ++i) {
-      if (i >= index.size() || index[i] != target_base_index[i]) {
-        return;
-      }
-    }
-    // Construct source element index to copy from.
-    ShapeIndex source_index = source_base_index;
-    for (int i = target_base_index.size(); i < index.size(); ++i) {
-      source_index.push_back(index[i]);
-    }
-    *data = other.element(source_index);
-  });
-}
-
-template <typename T>
-StatusOr<ShapeTree<T>> ShapeTree<T>::SubShapeTree(
-    const ShapeIndex& index) const {
-  TF_ASSIGN_OR_RETURN(const Shape* sub_shape,
-                      ShapeUtil::TryGetSubshape(shape(), index));
-  ShapeTree<T> sub_shape_tree(*sub_shape);
-  sub_shape_tree.CopySubtreeFrom(*this, index, {});
-  return std::move(sub_shape_tree);
-}
-
-template <typename T>
-bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
-  bool equal = true;
-  ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) {
-    if (data != other.element(index)) {
-      equal = false;
-    }
-  });
-  return equal;
-}
-
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
index 4f3610e..fec042b 100644
--- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
@@ -174,9 +174,8 @@
     xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
                                     device_ordinal);
     shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
-        [](xla::MaybeOwningDeviceMemory* buffer) {
-          CHECK(buffer);
-          return buffer->AsDeviceMemoryBase();
+        [](const xla::MaybeOwningDeviceMemory& buffer) {
+          return buffer.AsDeviceMemoryBase();
         }));
     return shaped_buffer;
   }
diff --git a/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.cc b/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.cc
index 6b1d657..dd5ae6f 100644
--- a/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.cc
@@ -187,9 +187,8 @@
   xla::ShapedBuffer shaped_buffer(std::move(host_shape), input_buffers.shape(),
                                   device_ordinal);
   shaped_buffer.set_buffers(input_buffers.Map<se::DeviceMemoryBase>(
-      [](xla::MaybeOwningDeviceMemory* buffer) {
-        CHECK(buffer);
-        return buffer->AsDeviceMemoryBase();
+      [](const xla::MaybeOwningDeviceMemory& buffer) {
+        return buffer.AsDeviceMemoryBase();
       }));
 
   // Write input root tuple.