[XLA] Cleanup / optimize `ShapeTree`.
PiperOrigin-RevId: 434889286
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index e7c8411..3641111 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -800,16 +800,17 @@
cc_library(
name = "shape_tree",
+ srcs = ["shape_tree.cc"],
hdrs = ["shape_tree.h"],
visibility = ["//visibility:public"],
deps = [
":shape_util",
":status_macros",
- ":xla_data_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 0a3cad5..d32a1b4 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -16,11 +16,11 @@
#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
#include <deque>
+#include <functional>
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
-#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -41,9 +41,8 @@
void AddToWorklist(const HloInstruction* instruction, Worklist* worklist,
Workset* workset) {
- if (!workset->contains(instruction)) {
+ if (workset->insert(instruction).second) {
worklist->push_back(instruction);
- workset->insert(instruction);
VLOG(3) << "ADD instruction: " << instruction->name();
}
}
@@ -67,18 +66,17 @@
const ShapeIndex& shape_index,
HloLivenessAnalysis::HloIndexMap* live_index_map,
Worklist* worklist, Workset* workset) {
- auto it = live_index_map->find(instruction);
- if (it == live_index_map->end()) {
- auto it_added = live_index_map->emplace(
- std::piecewise_construct, std::forward_as_tuple(instruction),
- std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
- it = it_added.first;
+ std::unique_ptr<ShapeTree<bool>>& liveness = (*live_index_map)[instruction];
+ if (liveness == nullptr) {
+ liveness = std::make_unique<ShapeTree<bool>>(instruction->shape(),
+ /*init_value=*/false);
}
- if (it->second.element(shape_index) == false) {
+ bool& alive = *liveness->mutable_element(shape_index);
+ if (!alive) {
AddToWorklist(instruction, worklist, workset);
- *it->second.mutable_element(shape_index) = true;
+ alive = true;
VLOG(3) << "MARK instruction: " << instruction->name()
- << " shape_index: " << shape_index.ToString();
+ << " shape_index: " << shape_index;
}
}
@@ -87,23 +85,21 @@
HloLivenessAnalysis::HloIndexMap* live_index_map,
Worklist* worklist, Workset* workset) {
bool add_to_worklist = false;
- auto it = live_index_map->find(instruction);
- if (it == live_index_map->end()) {
- live_index_map->emplace(
- std::piecewise_construct, std::forward_as_tuple(instruction),
- std::forward_as_tuple(instruction->shape(), /*init_value=*/true));
+
+ std::unique_ptr<ShapeTree<bool>>& liveness = (*live_index_map)[instruction];
+ if (liveness == nullptr) {
+ liveness = std::make_unique<ShapeTree<bool>>(instruction->shape(),
+ /*init_value=*/true);
add_to_worklist = true;
} else {
- ShapeUtil::ForEachSubshape(
- instruction->shape(),
- [&](const Shape& sub_shape, const ShapeIndex& shape_index) {
- if (it->second.element(shape_index) == false) {
- add_to_worklist = true;
- *it->second.mutable_element(shape_index) = true;
- VLOG(3) << "MARK instruction: " << instruction->name()
- << " shape_index: " << shape_index.ToString();
- }
- });
+ for (auto& entry : *liveness) {
+ if (!entry.second) {
+ add_to_worklist = true;
+ entry.second = true;
+ VLOG(3) << "MARK instruction: " << instruction->name()
+ << " shape_index: " << entry.first;
+ }
+ }
}
if (add_to_worklist) {
AddToWorklist(instruction, worklist, workset);
@@ -122,7 +118,7 @@
CHECK_EQ(instruction->opcode(), HloOpcode::kTuple);
for (int64_t operand_index = 0; operand_index < instruction->operand_count();
++operand_index) {
- const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+ const ShapeTree<bool>& index_tree = *live_index_map->at(instruction);
ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
if (shape_index.empty() || shape_index[0] != operand_index) {
return;
@@ -152,7 +148,7 @@
// Mark operand top-level index.
MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist,
workset);
- const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+ const ShapeTree<bool>& index_tree = *live_index_map->at(instruction);
// Propagate live shape indices along GTE -> Tuple edge.
ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
ShapeIndex operand_shape_index(shape_index);
@@ -171,7 +167,7 @@
HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
Workset* workset) {
CHECK_EQ(instruction->opcode(), HloOpcode::kWhile);
- const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+ const ShapeTree<bool>& index_tree = *live_index_map->at(instruction);
ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
// Propagate liveness to while body computation root instruction.
@@ -202,8 +198,7 @@
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
auto* xla_while = callsite.instruction();
- const ShapeTree<bool>& index_tree =
- FindOrDie(*live_index_map, instruction);
+ const ShapeTree<bool>& index_tree = *live_index_map->at(instruction);
ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
// Propagate liveness to while result{shape_index}
MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist,
@@ -256,7 +251,7 @@
// If 'instruction' is a parameter, propagate live shape indices
// to the associated callsite's argument shape indices.
const ShapeTree<bool>& index_tree =
- FindOrDie(*live_index_map, instruction);
+ *live_index_map->at(instruction);
ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
MarkLiveAtIndex(caller->operand(operand_index), shape_index,
live_index_map, worklist, workset);
@@ -334,10 +329,8 @@
bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction,
const ShapeIndex& shape_index) const {
- if (ContainsKey(live_index_map_, instruction)) {
- return FindOrDie(live_index_map_, instruction).element(shape_index);
- }
- return false;
+ auto it = live_index_map_.find(instruction);
+ return (it != live_index_map_.end()) && it->second->element(shape_index);
}
/* static */
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.h b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h
index c780153..f990f8c 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h
@@ -16,6 +16,8 @@
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_
+#include <memory>
+
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -37,8 +39,8 @@
class HloLivenessAnalysis {
public:
// Maps from an HloInstruction to its live/dead output shape indices.
- using HloIndexMap =
- absl::flat_hash_map<const HloInstruction*, ShapeTree<bool>>;
+ using HloIndexMap = absl::flat_hash_map<const HloInstruction*,
+ std::unique_ptr<ShapeTree<bool>>>;
// Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object
// which exports liveness for each {HloInstruction, ShapeIndex} in 'module'.
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 0f2e141..4c7b7d9 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -50,7 +50,7 @@
// s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_
// into buffers_, we also need to update this pointer so that buffers_ doesn't
// point into s.
- buffers_.replace_shape_ptr(&on_device_shape_);
+ buffers_.replace_shape_ptr(on_device_shape_);
}
ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
@@ -61,7 +61,7 @@
// buffers_ has a pointer to its on_device_shape_. When we move s.buffers_
// into buffers_, we also need to update this pointer so that buffers_ doesn't
// point into s.
- buffers_.replace_shape_ptr(&on_device_shape_);
+ buffers_.replace_shape_ptr(on_device_shape_);
return *this;
}
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
index e708216..5cc6976 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.h
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -91,7 +91,7 @@
void set_buffers(ShapeTree<se::DeviceMemoryBase> buffers) {
CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_));
buffers_ = std::move(buffers);
- buffers_.replace_shape_ptr(&on_device_shape_);
+ buffers_.replace_shape_ptr(on_device_shape_);
}
// Reset the shape of this shaped buffer and underlying buffer structure.
@@ -103,7 +103,7 @@
<< ", old: " << on_device_shape_;
on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape);
on_device_shape_ = on_device_shape;
- buffers_.replace_shape_ptr(&on_device_shape_);
+ buffers_.replace_shape_ptr(on_device_shape_);
}
// TODO(b/170310047): remove this overload.
void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
diff --git a/tensorflow/compiler/xla/shape_tree.cc b/tensorflow/compiler/xla/shape_tree.cc
new file mode 100644
index 0000000..33d99bb
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_tree.cc
@@ -0,0 +1,54 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/shape_tree.h"
+
+namespace xla {
+namespace internal {
+
+IndexTable::IndexTable(const Shape& shape) : entries_(1) {
+ size_t next_node_id = 0;
+ CreateEntry(entries_[0], shape, next_node_id);
+}
+
+// TODO(cjfj): Index table cache?
+void IndexTable::CreateEntry(Entry& entry, const Shape& shape,
+ size_t& next_node_id) {
+ entry.node_id = next_node_id++;
+ if (!shape.IsTuple()) return;
+
+  // The nodes are in depth-first pre-order. However, to look up indices
+  // efficiently, we generate the index table using breadth-first order.
+ size_t children_start_id = entries_.size();
+ entry.children_start_id = children_start_id;
+ // Add entry for children first, before recursing, so they are consecutive.
+ entries_.resize(entries_.size() + shape.tuple_shapes_size());
+ for (size_t i = 0; i < shape.tuple_shapes_size(); ++i) {
+ CreateEntry(entries_[children_start_id + i], shape.tuple_shapes(i),
+ next_node_id);
+ }
+}
+
+const IndexTable::Entry& IndexTable::operator[](ShapeIndexView index) const {
+ const Entry* result = &entries_.front();
+ for (int64_t i : index) {
+ CHECK_GE(result->children_start_id, 0);
+ result = &entries_[result->children_start_id + i];
+ }
+ return *result;
+}
+
+} // namespace internal
+} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 178e05b..03ccbd2 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -16,19 +16,20 @@
#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
+#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
+#include <type_traits>
+#include <utility>
#include <vector>
-#include "absl/memory/memory.h"
-#include "absl/types/optional.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/functional/function_ref.h"
#include "absl/types/span.h"
-#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
@@ -37,41 +38,32 @@
namespace internal {
-// Internal representation of each node in a ShapeTree.
-template <typename T>
-struct ShapeTreeNode {
- // Data corresponding to this node.
- std::pair<ShapeIndex, T> data;
+class IndexTable {
+ public:
+  // Use indices, rather than pointers, so the index table can be copied
+  // between ShapeTrees.
+ struct Entry {
+ // Index of the node in the nodes vector.
+ size_t node_id;
+ // Index of the first child of this node in the index table (-1 for leaves).
+ std::make_signed_t<size_t> children_start_id = -1;
+ };
- bool is_leaf = true;
+ IndexTable() = default;
+ explicit IndexTable(const Shape& shape);
- explicit ShapeTreeNode(ShapeIndex index)
- : ShapeTreeNode(std::move(index), T()) {}
- ShapeTreeNode(ShapeIndex index, T data)
- : data(std::move(index), std::move(data)) {}
-};
+ bool empty() const { return entries_.empty(); }
-// Internal representation of an index table entry.
-struct IndexTableEntry {
- // Index of the node in the ShapeTreeNode vector.
- uint32_t index;
- // Index of the first child in a IndexTableEntry vector. In the index
- // table all children entries for a given node will be placed next to each
- // other. This allows us to use a single field to index them.
- uint32_t children_start;
-#ifndef NDEBUG
- // Number of children, used for bounds checking.
- uint32_t children_count;
-#endif
+ const Entry& operator[](ShapeIndexView index) const;
+
+ private:
+ void CreateEntry(Entry& entry, const Shape& shape, size_t& next_node_id);
+
+ absl::InlinedVector<Entry, 1> entries_;
};
} // namespace internal
-template <typename ContainerType, typename IteratorType, typename ValueType>
-class ShapeTreeIterator;
-template <typename ContainerType, typename IteratorType, typename ValueType>
-class ShapeTreeLeafIterator;
-
// A ShapeTree<T> is a recursive data structure which mirrors the structure of a
// XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
// in the shape. For array shapes, a ShapeTree trivially holds a single value of
@@ -88,14 +80,22 @@
//
// Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
// it's helpful not to copy a Shape just to make a ShapeTree. In these cases,
-// you can pass a Shape* instead of a Shape& to the ShapeTree constructor. It's
-// then up to you to ensure that the pointed-to Shape doesn't die or mutate
-// before its ShapeTree goes away.
+// you can pass a Shape* instead of a Shape to the ShapeTree constructor. It's
+// then up to you to ensure that the pointed-to Shape isn't freed, moved or
+// modified before its ShapeTree goes away.
template <typename T>
class ShapeTree {
+ template <typename U>
+ friend class ShapeTree;
+
public:
- using Node = internal::ShapeTreeNode<T>;
- using Index = internal::IndexTableEntry;
+ // TODO(cjfj): Don't store ShapeIndex with data. Generate it or cache it?
+ using Node = std::pair<ShapeIndex, T>;
+ using Nodes = absl::InlinedVector<Node, 1>;
+ using IndexTable = internal::IndexTable;
+
+ template <typename Iterator, typename ValueType>
+ class LeafIterator;
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
@@ -106,19 +106,23 @@
// The version that takes a pointer may be cheaper because it doesn't require
// any Shape copies, but then it's up to you to ensure that the pointer stays
// alive longer than this ShapeTree.
- explicit ShapeTree(Shape shape);
- explicit ShapeTree(const Shape* shape);
- explicit ShapeTree(const std::shared_ptr<Shape>& shape);
+ explicit ShapeTree(Shape shape)
+ : ShapeTree(std::make_shared<Shape>(std::move(shape))) {}
+
+ explicit ShapeTree(const Shape* shape)
+ : ShapeTree(shape, CreateNodes(*shape)) {}
// Create ShapeTree with the given shape, and init_value for all nodes.
- ShapeTree(Shape shape, const T& init_value);
- ShapeTree(const Shape* shape, const T& init_value);
- ShapeTree(const std::shared_ptr<Shape>& shape, const T& init_value);
+ ShapeTree(Shape shape, const T& init_value)
+ : ShapeTree(std::make_shared<Shape>(std::move(shape)), init_value) {}
+
+ ShapeTree(const Shape* shape, const T& init_value)
+ : ShapeTree(shape, CreateNodes(*shape, [&] { return init_value; })) {}
// Returns the data element associated with the array in the shape at the
// given index (see ShapeUtil::GetSubshape for how indexes are defined).
- const T& element(ShapeIndexView index) const;
- T* mutable_element(ShapeIndexView index);
+ const T& element(ShapeIndexView index) const { return find(index)->second; }
+ T* mutable_element(ShapeIndexView index) { return &find(index)->second; }
// Return the shape represented with this ShapeTree.
const Shape& shape() const { return *shape_; }
@@ -128,76 +132,57 @@
// This API replaces the underlying Shape object to the one supplied by the
// caller, whom must ensure the object remain valid for the whole lifetime of
// this ShapeTree object, and also that the Shape is consistent with it.
- void replace_shape_ptr(const Shape* shape) {
+ void replace_shape_ptr(const Shape& shape) {
if (shape_storage_ != nullptr) {
- DCHECK_EQ(*shape, *shape_storage_);
+ DCHECK_EQ(shape, *shape_storage_);
shape_storage_ = nullptr;
}
- shape_ = shape;
+ shape_ = &shape;
}
// Returns true if the node at the given index is a leaf node (an array
// shape).
- bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
+ bool IsLeaf(ShapeIndexView index) const {
+ return Lookup(index).children_start_id == -1;
+ }
- ShapeTree(const ShapeTree&) = default;
- ShapeTree& operator=(const ShapeTree&) = default;
- ShapeTree(ShapeTree&&) = default;
- ShapeTree& operator=(ShapeTree&& other) = default;
+ using iterator = typename Nodes::iterator;
+ using const_iterator = typename Nodes::const_iterator;
+ using reverse_iterator = typename Nodes::reverse_iterator;
+ using const_reverse_iterator = typename Nodes::const_reverse_iterator;
- // iterator implements a bidirectional_iterator with
- // value_type = std::pair<ShapeIndex, T>.
- //
- // The iteration order is guaranteed to be a pre-order walk of the ShapeTree.
- using iterator =
- ShapeTreeIterator<std::vector<Node>, typename std::vector<Node>::iterator,
- std::pair<ShapeIndex, T>>;
- using const_iterator =
- ShapeTreeIterator<const std::vector<Node>,
- typename std::vector<Node>::const_iterator,
- const std::pair<ShapeIndex, T>>;
- using reverse_iterator = std::reverse_iterator<iterator>;
- using const_reverse_iterator = std::reverse_iterator<const_iterator>;
-
- using leaf_iterator =
- ShapeTreeLeafIterator<std::vector<Node>,
- typename std::vector<Node>::iterator,
- std::pair<ShapeIndex, T>>;
- using const_leaf_iterator =
- ShapeTreeLeafIterator<const std::vector<Node>,
- typename std::vector<Node>::const_iterator,
- const std::pair<ShapeIndex, T>>;
+ using leaf_iterator = LeafIterator<iterator, Node>;
+ using const_leaf_iterator = LeafIterator<const_iterator, const Node>;
using reverse_leaf_iterator = std::reverse_iterator<leaf_iterator>;
using const_reverse_leaf_iterator =
std::reverse_iterator<const_leaf_iterator>;
- // begin/end for iterating over all nodes.
- iterator begin() { return iterator(&nodes_, nodes_.begin()); }
- iterator end() { return iterator(&nodes_, nodes_.end()); }
- const_iterator begin() const {
- return const_iterator(&nodes_, nodes_.begin());
- }
- const_iterator end() const { return const_iterator(&nodes_, nodes_.end()); }
+ iterator begin() { return nodes_.begin(); }
+ iterator end() { return nodes_.end(); }
+ const_iterator begin() const { return nodes_.begin(); }
+ const_iterator end() const { return nodes_.end(); }
- // rbegin/rend for iterating over all nodes in reverse.
- reverse_iterator rbegin() { return reverse_iterator(end()); }
- reverse_iterator rend() { return reverse_iterator(begin()); }
- const_reverse_iterator rbegin() const {
- return const_reverse_iterator(end());
- }
- const_reverse_iterator rend() const {
- return const_reverse_iterator(begin());
- }
+ reverse_iterator rbegin() { return nodes_.rbegin(); }
+ reverse_iterator rend() { return nodes_.rend(); }
+ const_reverse_iterator rbegin() const { return nodes_.rbegin(); }
+ const_reverse_iterator rend() const { return nodes_.rend(); }
// leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
- // children).
- leaf_iterator leaf_begin() { return leaf_iterator(&nodes_, nodes_.begin()); }
- leaf_iterator leaf_end() { return leaf_iterator(&nodes_, nodes_.end()); }
+ // children). We check for empty tuples here, so the leaf iterator's `IsLeaf`
+ // check is simpler.
+ leaf_iterator leaf_begin() {
+ return ShapeUtil::IsEmptyTuple(*shape_)
+ ? leaf_end()
+ : leaf_iterator(nodes_, nodes_.begin());
+ }
+ leaf_iterator leaf_end() { return leaf_iterator(nodes_, nodes_.end()); }
const_leaf_iterator leaf_begin() const {
- return const_leaf_iterator(&nodes_, nodes_.begin());
+ return ShapeUtil::IsEmptyTuple(*shape_)
+ ? leaf_end()
+ : const_leaf_iterator(nodes_, nodes_.begin());
}
const_leaf_iterator leaf_end() const {
- return const_leaf_iterator(&nodes_, nodes_.end());
+ return const_leaf_iterator(nodes_, nodes_.end());
}
// range-based iterator for leaf_begin()/leaf_end().
tensorflow::gtl::iterator_range<leaf_iterator> leaves() {
@@ -223,108 +208,152 @@
// Returns an iterator pointing to the given ShapeIndex.
// REQUIRES: index must exist in the ShapeTree.
iterator find(ShapeIndexView index) {
- Node* element = Lookup(index);
- auto element_iter = nodes_.begin() + (element - &nodes_[0]);
- return iterator(&nodes_, element_iter);
+ return nodes_.begin() + Lookup(index).node_id;
}
const_iterator find(ShapeIndexView index) const {
- const Node* element = Lookup(index);
- auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
- return const_iterator(&nodes_, element_iter);
+ return nodes_.begin() + Lookup(index).node_id;
}
// Returns the number of leaf nodes in the tree.
int64_t leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }
+ // TODO(cjfj): Remove the `ForEach...` methods. They are redundant.
// Recursively traverses the shape and calls the given function at each
- // element. The function has the following arguments:
- //
- // Fn : A callable of type void(const ShapeIndex& index, const T& data)
- // (or compatible).
- // index : the index of the element in the shape. See ShapeUtil::GetSubshape
- // for definition of index.
- // data : The data value at this element.
- template <typename Fn>
- void ForEachElement(const Fn& func) const;
+ // element.
+ void ForEachElement(
+ absl::FunctionRef<void(const ShapeIndex&, const T&)> func) const {
+ for (const Node& node : nodes_) {
+ func(node.first, node.second);
+ }
+ }
- // Like ForEachElement, but the callable has type
- //
- // void (const ShapeIndex& index, T* data).
- //
- template <typename Fn>
- void ForEachMutableElement(const Fn& func);
+ void ForEachMutableElement(
+ absl::FunctionRef<void(const ShapeIndex&, T*)> func) {
+ for (Node& node : nodes_) {
+ func(node.first, &node.second);
+ }
+ }
// Like ForEach(Mutable)Element, but the callable returns a Status instead of
// void. The first non-OK return value is returned by the ForEach* function.
- template <typename Fn>
- Status ForEachElementWithStatus(const Fn& func) const;
- template <typename Fn>
- Status ForEachMutableElementWithStatus(const Fn& func);
+ Status ForEachElementWithStatus(
+ absl::FunctionRef<Status(const ShapeIndex&, const T&)> func) const {
+ for (const Node& node : nodes_) {
+ TF_RETURN_IF_ERROR(func(node.first, node.second));
+ }
+ return Status::OK();
+ }
+
+ Status ForEachMutableElementWithStatus(
+ absl::FunctionRef<Status(const ShapeIndex&, T*)> func) {
+ for (Node& node : nodes_) {
+ TF_RETURN_IF_ERROR(func(node.first, &node.second));
+ }
+ return Status::OK();
+ }
// Maps each element to generate a new tree with the same shape.
template <typename U>
- ShapeTree<U> Map(const std::function<U(const T&)>& func) {
- ShapeTree<U> result(shape_storage_);
- ForEachElement([&](const ShapeIndex& index, const T& t) {
- *result.mutable_element(index) = func(t);
- });
+ ShapeTree<U> Map(absl::FunctionRef<U(const T&)> func) {
+ typename ShapeTree<U>::Nodes result_nodes;
+ result_nodes.reserve(nodes_.size());
+ for (const Node& node : nodes_) {
+ result_nodes.push_back({node.first, func(node.second)});
+ }
+
+ ShapeTree<U> result(shape_, std::move(result_nodes));
+ result.index_table_ = index_table_;
+ result.shape_storage_ = shape_storage_;
return result;
}
- template <typename U>
- ShapeTree<U> Map(const std::function<U(T*)>& func) {
- ShapeTree<U> result(shape_storage_);
- ForEachMutableElement([&](const ShapeIndex& index, T* t) {
- *result.mutable_element(index) = func(t);
- });
- return result;
- }
-
- // Copy the subtree of values from 'other' rooted at ShapeIndex
- // 'source_base_index' into the subtree of value in this ShapeTree rooted at
- // 'target_base_index'.
+ // Copy the subtree of values from 'other' rooted at ShapeIndex 'src_index'
+ // into the subtree of value in this ShapeTree rooted at 'dst_index'.
//
- // Precondition: The subshape of other.shape() at index source_base_index must
- // be compatible with the subshape of shape() at index target_base_index.
- void CopySubtreeFrom(const ShapeTree<T>& other,
- const ShapeIndex& source_base_index,
- const ShapeIndex& target_base_index);
+ // Precondition: The subshape of other.shape() at index src_index must be
+ // compatible with the subshape of shape() at index dst_index.
+ void CopySubtreeFrom(const ShapeTree<T>& other, const ShapeIndex& src_index,
+ const ShapeIndex& dst_index) {
+ const Shape& src_shape = ShapeUtil::GetSubshape(other.shape(), src_index);
+ const Shape& dst_shape = ShapeUtil::GetSubshape(shape(), dst_index);
+ CHECK(ShapeUtil::Compatible(src_shape, dst_shape))
+ << src_shape << ", " << dst_shape;
- StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const;
+ // Replace the prefix `src_index` with `dst_index`.
+ auto replace_shape_index_prefix = [&](const ShapeIndex& index) {
+ auto without_prefix = ShapeIndexView(index).subspan(src_index.size());
+ ShapeIndex result;
+ result.reserve(dst_index.size() + without_prefix.size());
+ result.insert(result.end(), dst_index.begin(), dst_index.end());
+ result.insert(result.end(), without_prefix.begin(), without_prefix.end());
+ return result;
+ };
- bool operator==(const ShapeTree<T>& other) const;
+ auto first = other.find(src_index);
+ auto last = first + ShapeUtil::SubshapeCount(src_shape);
+
+ std::transform(first, last, find(dst_index), [&](const Node& node) -> Node {
+ return {replace_shape_index_prefix(node.first), node.second};
+ });
+ }
+
+ StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const {
+ TF_ASSIGN_OR_RETURN(const Shape* sub_shape,
+ ShapeUtil::TryGetSubshape(shape(), index));
+ size_t count = ShapeUtil::SubshapeCount(*sub_shape);
+ Nodes sub_tree_nodes;
+ sub_tree_nodes.reserve(count);
+ for (auto it = find(index), end = it + count; it != end; ++it) {
+ // For each shape index, remove the prefix `index`.
+ auto without_prefix = ShapeIndexView(it->first).subspan(index.size());
+ sub_tree_nodes.push_back(Node{without_prefix, it->second});
+ }
+ return ShapeTree(sub_shape, std::move(sub_tree_nodes));
+ }
+
+ bool operator==(const ShapeTree<T>& other) const {
+ return nodes_ == other.nodes_;
+ }
bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
private:
- // Initialize node->children based on 'shape'. All children are assigned the
- // the given 'init_value'.
- void InitChildren(const Shape& shape, const T& init_value, Node* node,
- Index* index);
+ explicit ShapeTree(std::shared_ptr<Shape> shape) : ShapeTree(shape.get()) {
+ shape_storage_.swap(shape);
+ }
- // Initialize node->children based on 'shape'. All children have
- // default-constructed data values.
- void InitChildren(const Shape& shape, Node* node, Index* index);
+ ShapeTree(std::shared_ptr<Shape> shape, const T& init_value)
+ : ShapeTree(shape.get(), init_value) {
+ shape_storage_.swap(shape);
+ }
- // Returns the number of subshapes, including interior nodes, in shape.
- int64_t CountSubshapes(const Shape& shape);
+ ShapeTree(const Shape* shape, Nodes nodes)
+ : nodes_(std::move(nodes)), shape_(shape) {
+ DCHECK_EQ(nodes_.size(), ShapeUtil::SubshapeCount(*shape));
+ }
- // Helpers for traversing the shape via ForEachElement. The helpers
- // recursively traverse the subtree rooted at "index" (defined as in
- // ShapeUtil::GetSubshape).
- template <typename Fn>
- static Status ForEachHelper(const Fn& func, const std::vector<Node>& nodes);
- template <typename Fn>
- static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
+ static Nodes CreateNodes(
+ const Shape& shape, absl::FunctionRef<T()> gen = [] { return T(); }) {
+ Nodes nodes;
+ ShapeUtil::ForEachSubshape(shape,
+ [&](const Shape&, const ShapeIndex& index) {
+ nodes.push_back({index, gen()});
+ });
+ return nodes;
+ }
- // Return the tree node at the given index.
- Node* Lookup(ShapeIndexView index);
- const Node* Lookup(ShapeIndexView index) const;
+ // Returns the index table entry for the given shape index.
+ const IndexTable::Entry& Lookup(ShapeIndexView index) const {
+ // The index table is evaluated lazily.
+ if (index_table_.empty()) index_table_ = IndexTable(*shape_);
+ return index_table_[index];
+ }
// The nodes in this shape tree.
- std::vector<Node> nodes_;
+ Nodes nodes_;
- // Index table for node lookups.
- std::vector<Index> index_table_;
+ // Index table for node lookups. Each entry contains the index of the first
+ // child of the node at that index, or -1 for leaf nodes. Evaluated lazily.
+ mutable IndexTable index_table_;
// If we own our Shape, this field contains it, and shape_ is a pointer into
// here. Otherwise if we don't own our shape, this is nullptr.
@@ -335,399 +364,60 @@
const Shape* shape_;
};
-// Internal iterator that performs a pre-order walk. This is cheap to copy.
-// The iterator value_type is equivalent to a
-// std::pair<ShapeIndex,T>&, similar to std::map.
-template <typename ContainerType, typename IteratorType, typename ValueType>
-class ShapeTreeIterator
- : public std::iterator<std::bidirectional_iterator_tag, ValueType> {
- public:
- ShapeTreeIterator(ContainerType* nodes, IteratorType node)
- : nodes_(nodes), node_(std::move(node)) {}
-
- ShapeTreeIterator& operator++() {
- ++node_;
- return *this;
- }
- ShapeTreeIterator operator++(int) {
- auto i = *this;
- ++(*this);
- return i;
- }
-
- ShapeTreeIterator& operator--() {
- --node_;
- return *this;
- }
- ShapeTreeIterator operator--(int) {
- auto i = *this;
- --(*this);
- return i;
- }
-
- bool operator==(const ShapeTreeIterator& other) const {
- return node_ == other.node_;
- }
- bool operator!=(const ShapeTreeIterator& other) const {
- return node_ != other.node_;
- }
- ValueType& operator*() const { return node_->data; }
- ValueType* operator->() const { return &node_->data; }
-
- private:
- ContainerType* nodes_;
- IteratorType node_;
-};
-
// Internal iterator that performs a pre-order walk of the leaves. This is cheap
// to copy. The iterator value_type is equivalent to a std::pair<ShapeIndex,T>&,
// similar to std::map.
-template <typename ContainerType, typename IteratorType, typename ValueType>
-class ShapeTreeLeafIterator
+template <typename T>
+template <typename Iterator, typename ValueType>
+class ShapeTree<T>::LeafIterator
: public std::iterator<std::bidirectional_iterator_tag, ValueType> {
public:
- ShapeTreeLeafIterator(ContainerType* nodes, IteratorType node)
- : nodes_(nodes), node_(std::move(node)) {
- while (node_ != nodes_->end() && !node_->is_leaf) {
- ++node_;
- }
+ LeafIterator(const Nodes& nodes, Iterator it) : nodes_(nodes), it_(it) {
+ while ((it_ != nodes.end()) && !IsLeaf()) ++it_;
}
- ShapeTreeLeafIterator& operator++() {
- ++node_;
- while (node_ != nodes_->end() && !node_->is_leaf) {
- ++node_;
- }
+ LeafIterator& operator++() {
+ do {
+ ++it_;
+ } while ((it_ != nodes_.end()) && !IsLeaf());
return *this;
}
- ShapeTreeLeafIterator operator++(int) {
- auto i = *this;
+
+ LeafIterator operator++(int) {
+ auto prev = *this;
++(*this);
- return i;
+ return prev;
}
- ShapeTreeLeafIterator& operator--() {
- --node_;
- while (node_ > nodes_->begin() && !node_->is_leaf) {
- --node_;
- }
+ LeafIterator& operator--() {
+ do {
+ --it_;
+ } while ((it_ != nodes_.begin()) && !IsLeaf());
return *this;
}
- ShapeTreeLeafIterator operator--(int) {
- auto i = *this;
+
+ LeafIterator operator--(int) {
+ auto prev = *this;
--(*this);
- return i;
+ return prev;
}
- bool operator==(const ShapeTreeLeafIterator& other) const {
- return node_ == other.node_;
- }
- bool operator!=(const ShapeTreeLeafIterator& other) const {
- return node_ != other.node_;
- }
- ValueType& operator*() const { return node_->data; }
- ValueType* operator->() const { return &node_->data; }
+ bool operator==(const LeafIterator& other) const { return it_ == other.it_; }
+ bool operator!=(const LeafIterator& other) const { return !(*this == other); }
+ ValueType& operator*() const { return *it_; }
+ ValueType* operator->() const { return &*it_; }
private:
- ContainerType* nodes_;
- IteratorType node_;
+ bool IsLeaf() const {
+ auto next = it_ + 1;
+ // If the node is not a leaf, the next node will have a longer shape index.
+ return (next == nodes_.end()) || (it_->first.size() >= next->first.size());
+ }
+
+ const Nodes& nodes_;
+ Iterator it_;
};
-template <typename T>
-int64_t ShapeTree<T>::CountSubshapes(const Shape& shape) {
- int64_t current_count = 1;
- if (shape.IsTuple()) {
- int64_t count = ShapeUtil::TupleElementCount(shape);
- for (int i = 0; i < count; ++i) {
- current_count += CountSubshapes(shape.tuple_shapes(i));
- }
- }
- return current_count;
-}
-
-template <typename T>
-void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
- Node* node, Index* index) {
- if (shape.IsTuple()) {
- const int64_t size = ShapeUtil::TupleElementCount(shape);
-#ifndef NDEBUG
- index->children_count = size;
-#endif
- node->is_leaf = false;
- ShapeIndex shape_index = node->data.first;
- shape_index.push_back(0);
-
- // At the end of the index_table, reserve a continuous space to hold the
- // children of current node. In order to enforce the invariant that all
- // children of a given node are placed together, we need to do the
- // reservation before we recurse into any of its children.
- int64_t children_start_position = index_table_.size();
- index_table_.resize(index_table_.size() + size);
-
- for (int i = 0; i < size; ++i) {
- shape_index[shape_index.size() - 1] = i;
- index_table_[children_start_position + i].index = nodes_.size();
- // The first child of the node in the index table is placed at the end of
- // the table.
- index_table_[children_start_position + i].children_start =
- index_table_.size();
- nodes_.emplace_back(shape_index, init_value);
- InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
- &index_table_[children_start_position + i]);
- }
- } else {
-#ifndef NDEBUG
- index->children_count = 0;
-#endif
- }
-}
-
-template <typename T>
-void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
- if (shape.IsTuple()) {
- const int64_t size = ShapeUtil::TupleElementCount(shape);
-#ifndef NDEBUG
- index->children_count = size;
-#endif
- node->is_leaf = false;
- ShapeIndex shape_index = node->data.first;
- shape_index.push_back(0);
-
- // At the end of the index_table, reserve a continuous space to hold the
- // children of current node. In order to enforce the invariant that all
- // children of a given node are placed together, we need to do the
- // reservation before we recurse into any of its children.
- int64_t children_start_position = index_table_.size();
- index_table_.resize(index_table_.size() + size);
-
- for (int i = 0; i < size; ++i) {
- shape_index[shape_index.size() - 1] = i;
- index_table_[children_start_position + i].index = nodes_.size();
- // The first child of the node in the index table is placed at the end of
- // the table.
- index_table_[children_start_position + i].children_start =
- index_table_.size();
- nodes_.emplace_back(shape_index);
- InitChildren(shape.tuple_shapes(i), &nodes_.back(),
- &index_table_[children_start_position + i]);
- }
- } else {
-#ifndef NDEBUG
- index->children_count = 0;
-#endif
- }
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(Shape shape)
- : shape_storage_(std::make_shared<Shape>(std::move(shape))),
- shape_(shape_storage_.get()) {
- const int64_t count = CountSubshapes(*shape_);
- nodes_.reserve(count);
- nodes_.emplace_back(ShapeIndex{});
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
- const int64_t count = CountSubshapes(*shape_);
- nodes_.reserve(count);
- nodes_.emplace_back(ShapeIndex{});
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
- : shape_storage_(shape), shape_(shape_storage_.get()) {
- const int64_t count = CountSubshapes(*shape_);
- nodes_.reserve(count);
- nodes_.emplace_back(ShapeIndex{});
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
- : shape_storage_(std::make_shared<Shape>(std::move(shape))),
- shape_(shape_storage_.get()) {
- const int64_t count = CountSubshapes(*shape_);
- nodes_.reserve(count);
- nodes_.emplace_back(ShapeIndex{}, init_value);
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
- : shape_(shape) {
- const int64_t count = CountSubshapes(*shape_);
- nodes_.reserve(count);
- nodes_.emplace_back(ShapeIndex{}, init_value);
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
- const T& init_value)
- : shape_storage_(shape), shape_(shape_storage_.get()) {
- const int64_t count = CountSubshapes(*shape_);
- nodes_.reserve(count);
- nodes_.emplace_back(ShapeIndex{}, init_value);
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
-}
-
-template <typename T>
-const T& ShapeTree<T>::element(ShapeIndexView index) const {
- return Lookup(index)->data.second;
-}
-
-template <typename T>
-T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
- return &Lookup(index)->data.second;
-}
-
-template <typename T>
-internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
- Index* iter = &index_table_[0];
- for (const int64_t i : index) {
- CHECK_GE(i, 0);
-#ifndef NDEBUG
- CHECK_LT(i, iter->children_count);
-#endif
- iter = &index_table_[iter->children_start + i];
- }
-
- return &nodes_[iter->index];
-}
-
-template <typename T>
-const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
- ShapeIndexView index) const {
- return const_cast<ShapeTree*>(this)->Lookup(index);
-}
-
-/* static */
-template <typename T>
-template <typename Fn>
-Status ShapeTree<T>::ForEachHelper(const Fn& func,
- const std::vector<Node>& nodes) {
- for (const auto& node : nodes) {
- TF_RETURN_IF_ERROR(func(node.data.first, node.data.second));
- }
- return Status::OK();
-}
-
-/* static */
-template <typename T>
-template <typename Fn>
-Status ShapeTree<T>::ForEachMutableHelper(const Fn& func,
- std::vector<Node>* nodes) {
- for (auto& node : *nodes) {
- TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second));
- }
- return Status::OK();
-}
-
-template <typename T>
-template <typename Fn>
-Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const {
- return ForEachHelper(func, nodes_);
-}
-
-template <typename T>
-template <typename Fn>
-Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) {
- return ForEachMutableHelper(func, &nodes_);
-}
-
-template <typename T>
-template <typename Fn>
-void ShapeTree<T>::ForEachElement(const Fn& func) const {
- return ForEachHelper(
- [&func](const ShapeIndex& index, const T& data) {
- func(index, data);
- return Status::OK();
- },
- nodes_)
- .IgnoreError();
-}
-
-template <typename T>
-template <typename Fn>
-void ShapeTree<T>::ForEachMutableElement(const Fn& func) {
- return ForEachMutableHelper(
- [&func](const ShapeIndex& index, T* data) {
- func(index, data);
- return Status::OK();
- },
- &nodes_)
- .IgnoreError();
-}
-
-template <typename T>
-void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
- const ShapeIndex& source_base_index,
- const ShapeIndex& target_base_index) {
- CHECK(ShapeUtil::Compatible(
- ShapeUtil::GetSubshape(shape(), target_base_index),
- ShapeUtil::GetSubshape(other.shape(), source_base_index)))
- << ShapeUtil::GetSubshape(shape(), target_base_index) << " vs "
- << ShapeUtil::GetSubshape(other.shape(), source_base_index);
- ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
- const ShapeIndex& index, T* data) {
- // Copy the data element only if index is in the
- // subtree rooted at target_base_index.
- for (int i = 0; i < target_base_index.size(); ++i) {
- if (i >= index.size() || index[i] != target_base_index[i]) {
- return;
- }
- }
- // Construct source element index to copy from.
- ShapeIndex source_index = source_base_index;
- for (int i = target_base_index.size(); i < index.size(); ++i) {
- source_index.push_back(index[i]);
- }
- *data = other.element(source_index);
- });
-}
-
-template <typename T>
-StatusOr<ShapeTree<T>> ShapeTree<T>::SubShapeTree(
- const ShapeIndex& index) const {
- TF_ASSIGN_OR_RETURN(const Shape* sub_shape,
- ShapeUtil::TryGetSubshape(shape(), index));
- ShapeTree<T> sub_shape_tree(*sub_shape);
- sub_shape_tree.CopySubtreeFrom(*this, index, {});
- return std::move(sub_shape_tree);
-}
-
-template <typename T>
-bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
- bool equal = true;
- ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) {
- if (data != other.element(index)) {
- equal = false;
- }
- });
- return equal;
-}
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
index 4f3610e..fec042b 100644
--- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
@@ -174,9 +174,8 @@
xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
device_ordinal);
shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
- [](xla::MaybeOwningDeviceMemory* buffer) {
- CHECK(buffer);
- return buffer->AsDeviceMemoryBase();
+ [](const xla::MaybeOwningDeviceMemory& buffer) {
+ return buffer.AsDeviceMemoryBase();
}));
return shaped_buffer;
}
diff --git a/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.cc b/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.cc
index 6b1d657..dd5ae6f 100644
--- a/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.cc
@@ -187,9 +187,8 @@
xla::ShapedBuffer shaped_buffer(std::move(host_shape), input_buffers.shape(),
device_ordinal);
shaped_buffer.set_buffers(input_buffers.Map<se::DeviceMemoryBase>(
- [](xla::MaybeOwningDeviceMemory* buffer) {
- CHECK(buffer);
- return buffer->AsDeviceMemoryBase();
+ [](const xla::MaybeOwningDeviceMemory& buffer) {
+ return buffer.AsDeviceMemoryBase();
}));
// Write input root tuple.