| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ |
| #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ |
| |
| #include <algorithm> |
| #include <functional> |
| #include <iterator> |
| #include <memory> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/container/inlined_vector.h" |
| #include "absl/functional/function_ref.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/gtl/iterator_range.h" |
| #include "tensorflow/core/platform/logging.h" |
| |
| namespace xla { |
| |
| namespace internal { |
| |
| class IndexTable { |
| public: |
| // Use indices, rather than pointers, so index table can be copied between |
| // ShapeTrees. |
| struct Entry { |
| // Index of the node in the nodes vector. |
| size_t node_id; |
| // Index of the first child of this node in the index table (-1 for leaves). |
| std::make_signed_t<size_t> children_start_id = -1; |
| }; |
| |
| IndexTable() = default; |
| explicit IndexTable(const Shape& shape); |
| |
| bool empty() const { return entries_.empty(); } |
| |
| const Entry& operator[](ShapeIndexView index) const; |
| |
| private: |
| void CreateEntry(Entry& entry, const Shape& shape, size_t& next_node_id); |
| |
| absl::InlinedVector<Entry, 1> entries_; |
| }; |
| |
| } // namespace internal |
| |
| // A ShapeTree<T> is a recursive data structure which mirrors the structure of a |
| // XLA shape and holds a value of type T for each subshape (i.e. tuple or array) |
| // in the shape. For array shapes, a ShapeTree trivially holds a single value of |
| // type T. |
| // |
| // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a |
| // ShapeTree is an identically structured tree with data elements of type T at |
| // every node. I.e. the root is a tuple by definition, all interior nodes are |
| // also tuples, and all leaves are arrays. |
| // |
| // Like the Shape data structure, this is a tree and tuple elements cannot be |
| // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T |
| // object. |
| // |
| // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes |
| // it's helpful not to copy a Shape just to make a ShapeTree. In these cases, |
| // you can pass a Shape* instead of a Shape to the ShapeTree constructor. It's |
| // then up to you to ensure that the pointed-to Shape isn't freed, moved or |
| // modified before its ShapeTree goes away. |
| template <typename T> |
| class ShapeTree { |
| template <typename U> |
| friend class ShapeTree; |
| |
| public: |
| // TODO(cjfj): Don't store ShapeIndex with data. Generate it or cache it? |
| using Node = std::pair<ShapeIndex, T>; |
| using Nodes = absl::InlinedVector<Node, 1>; |
| using IndexTable = internal::IndexTable; |
| |
| template <typename Iterator, typename ValueType> |
| class LeafIterator; |
| |
| // Default constructor creates a tree with a nil shape (i.e. an empty tuple). |
| ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} |
| |
| // Create ShapeTree with the given shape, and default-constructed T values for |
| // all nodes. |
| // |
| // The version that takes a pointer may be cheaper because it doesn't require |
| // any Shape copies, but then it's up to you to ensure that the pointer stays |
| // alive longer than this ShapeTree. |
| explicit ShapeTree(Shape shape) |
| : ShapeTree(std::make_shared<Shape>(std::move(shape))) {} |
| |
| explicit ShapeTree(const Shape* shape) |
| : ShapeTree(shape, CreateNodes(*shape)) {} |
| |
| // Create ShapeTree with the given shape, and init_value for all nodes. |
| ShapeTree(Shape shape, const T& init_value) |
| : ShapeTree(std::make_shared<Shape>(std::move(shape)), init_value) {} |
| |
| ShapeTree(const Shape* shape, const T& init_value) |
| : ShapeTree(shape, CreateNodes(*shape, [&] { return init_value; })) {} |
| |
| // Returns the data element associated with the array in the shape at the |
| // given index (see ShapeUtil::GetSubshape for how indexes are defined). |
| const T& element(ShapeIndexView index) const { return find(index)->second; } |
| T* mutable_element(ShapeIndexView index) { return &find(index)->second; } |
| |
| // Return the shape represented with this ShapeTree. |
| const Shape& shape() const { return *shape_; } |
| |
| // A ShapeTree object can own the underlying Shape pointer (via the |
| // shape_storage_ member), or can point to a Shape object owned by the caller. |
| // This API replaces the underlying Shape object to the one supplied by the |
| // caller, whom must ensure the object remain valid for the whole lifetime of |
| // this ShapeTree object, and also that the Shape is consistent with it. |
| void replace_shape_ptr(const Shape& shape) { |
| if (shape_storage_ != nullptr) { |
| DCHECK_EQ(shape, *shape_storage_); |
| shape_storage_ = nullptr; |
| } |
| shape_ = &shape; |
| } |
| |
| // Returns true if the node at the given index is a leaf node (an array |
| // shape). |
| bool IsLeaf(ShapeIndexView index) const { |
| return Lookup(index).children_start_id == -1; |
| } |
| |
| using iterator = typename Nodes::iterator; |
| using const_iterator = typename Nodes::const_iterator; |
| using reverse_iterator = typename Nodes::reverse_iterator; |
| using const_reverse_iterator = typename Nodes::const_reverse_iterator; |
| |
| using leaf_iterator = LeafIterator<iterator, Node>; |
| using const_leaf_iterator = LeafIterator<const_iterator, const Node>; |
| using reverse_leaf_iterator = std::reverse_iterator<leaf_iterator>; |
| using const_reverse_leaf_iterator = |
| std::reverse_iterator<const_leaf_iterator>; |
| |
| iterator begin() { return nodes_.begin(); } |
| iterator end() { return nodes_.end(); } |
| const_iterator begin() const { return nodes_.begin(); } |
| const_iterator end() const { return nodes_.end(); } |
| |
| reverse_iterator rbegin() { return nodes_.rbegin(); } |
| reverse_iterator rend() { return nodes_.rend(); } |
| const_reverse_iterator rbegin() const { return nodes_.rbegin(); } |
| const_reverse_iterator rend() const { return nodes_.rend(); } |
| |
| // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no |
| // children). We check for empty tuples here, so the leaf iterator's `IsLeaf` |
| // check is simpler. |
| leaf_iterator leaf_begin() { |
| return ShapeUtil::IsEmptyTuple(*shape_) |
| ? leaf_end() |
| : leaf_iterator(nodes_, nodes_.begin()); |
| } |
| leaf_iterator leaf_end() { return leaf_iterator(nodes_, nodes_.end()); } |
| const_leaf_iterator leaf_begin() const { |
| return ShapeUtil::IsEmptyTuple(*shape_) |
| ? leaf_end() |
| : const_leaf_iterator(nodes_, nodes_.begin()); |
| } |
| const_leaf_iterator leaf_end() const { |
| return const_leaf_iterator(nodes_, nodes_.end()); |
| } |
| // range-based iterator for leaf_begin()/leaf_end(). |
| tensorflow::gtl::iterator_range<leaf_iterator> leaves() { |
| return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); |
| } |
| tensorflow::gtl::iterator_range<const_leaf_iterator> leaves() const { |
| return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); |
| } |
| |
| reverse_leaf_iterator leaf_rbegin() { |
| return reverse_leaf_iterator(leaf_end()); |
| } |
| reverse_leaf_iterator leaf_rend() { |
| return reverse_leaf_iterator(leaf_begin()); |
| } |
| const_reverse_leaf_iterator leaf_rbegin() const { |
| return const_reverse_leaf_iterator(leaf_end()); |
| } |
| const_reverse_leaf_iterator leaf_rend() const { |
| return const_reverse_leaf_iterator(leaf_begin()); |
| } |
| |
| // Returns an iterator pointing to the given ShapeIndex. |
| // REQUIRES: index must exist in the ShapeTree. |
| iterator find(ShapeIndexView index) { |
| return nodes_.begin() + Lookup(index).node_id; |
| } |
| const_iterator find(ShapeIndexView index) const { |
| return nodes_.begin() + Lookup(index).node_id; |
| } |
| |
| // Returns the number of leaf nodes in the tree. |
| int64_t leaf_count() const { return std::distance(leaf_begin(), leaf_end()); } |
| |
| // TODO(cjfj): Remove the `ForEach...` methods. They are redundant. |
| // Recursively traverses the shape and calls the given function at each |
| // element. |
| void ForEachElement( |
| absl::FunctionRef<void(const ShapeIndex&, const T&)> func) const { |
| for (const Node& node : nodes_) { |
| func(node.first, node.second); |
| } |
| } |
| |
| void ForEachMutableElement( |
| absl::FunctionRef<void(const ShapeIndex&, T*)> func) { |
| for (Node& node : nodes_) { |
| func(node.first, &node.second); |
| } |
| } |
| |
| // Like ForEach(Mutable)Element, but the callable returns a Status instead of |
| // void. The first non-OK return value is returned by the ForEach* function. |
| Status ForEachElementWithStatus( |
| absl::FunctionRef<Status(const ShapeIndex&, const T&)> func) const { |
| for (const Node& node : nodes_) { |
| TF_RETURN_IF_ERROR(func(node.first, node.second)); |
| } |
| return Status::OK(); |
| } |
| |
| Status ForEachMutableElementWithStatus( |
| absl::FunctionRef<Status(const ShapeIndex&, T*)> func) { |
| for (Node& node : nodes_) { |
| TF_RETURN_IF_ERROR(func(node.first, &node.second)); |
| } |
| return Status::OK(); |
| } |
| |
| // Maps each element to generate a new tree with the same shape. |
| template <typename U> |
| ShapeTree<U> Map(absl::FunctionRef<U(const T&)> func) { |
| typename ShapeTree<U>::Nodes result_nodes; |
| result_nodes.reserve(nodes_.size()); |
| for (const Node& node : nodes_) { |
| result_nodes.push_back({node.first, func(node.second)}); |
| } |
| |
| ShapeTree<U> result(shape_, std::move(result_nodes)); |
| result.index_table_ = index_table_; |
| result.shape_storage_ = shape_storage_; |
| return result; |
| } |
| |
| // Copy the subtree of values from 'other' rooted at ShapeIndex 'src_index' |
| // into the subtree of value in this ShapeTree rooted at 'dst_index'. |
| // |
| // Precondition: The subshape of other.shape() at index src_index must be |
| // compatible with the subshape of shape() at index dst_index. |
| void CopySubtreeFrom(const ShapeTree<T>& other, const ShapeIndex& src_index, |
| const ShapeIndex& dst_index) { |
| const Shape& src_shape = ShapeUtil::GetSubshape(other.shape(), src_index); |
| const Shape& dst_shape = ShapeUtil::GetSubshape(shape(), dst_index); |
| CHECK(ShapeUtil::Compatible(src_shape, dst_shape)) |
| << src_shape << ", " << dst_shape; |
| |
| // Replace the prefix `src_index` with `dst_index`. |
| auto replace_shape_index_prefix = [&](const ShapeIndex& index) { |
| auto without_prefix = ShapeIndexView(index).subspan(src_index.size()); |
| ShapeIndex result; |
| result.reserve(dst_index.size() + without_prefix.size()); |
| result.insert(result.end(), dst_index.begin(), dst_index.end()); |
| result.insert(result.end(), without_prefix.begin(), without_prefix.end()); |
| return result; |
| }; |
| |
| auto first = other.find(src_index); |
| auto last = first + ShapeUtil::SubshapeCount(src_shape); |
| |
| std::transform(first, last, find(dst_index), [&](const Node& node) -> Node { |
| return {replace_shape_index_prefix(node.first), node.second}; |
| }); |
| } |
| |
| StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const { |
| TF_ASSIGN_OR_RETURN(const Shape* sub_shape, |
| ShapeUtil::TryGetSubshape(shape(), index)); |
| size_t count = ShapeUtil::SubshapeCount(*sub_shape); |
| Nodes sub_tree_nodes; |
| sub_tree_nodes.reserve(count); |
| for (auto it = find(index), end = it + count; it != end; ++it) { |
| // For each shape index, remove the prefix `index`. |
| auto without_prefix = ShapeIndexView(it->first).subspan(index.size()); |
| sub_tree_nodes.push_back(Node{without_prefix, it->second}); |
| } |
| return ShapeTree(sub_shape, std::move(sub_tree_nodes)); |
| } |
| |
| bool operator==(const ShapeTree<T>& other) const { |
| return nodes_ == other.nodes_; |
| } |
| bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); } |
| |
| private: |
| explicit ShapeTree(std::shared_ptr<Shape> shape) : ShapeTree(shape.get()) { |
| shape_storage_.swap(shape); |
| } |
| |
| ShapeTree(std::shared_ptr<Shape> shape, const T& init_value) |
| : ShapeTree(shape.get(), init_value) { |
| shape_storage_.swap(shape); |
| } |
| |
| ShapeTree(const Shape* shape, Nodes nodes) |
| : nodes_(std::move(nodes)), shape_(shape) { |
| DCHECK_EQ(nodes_.size(), ShapeUtil::SubshapeCount(*shape)); |
| } |
| |
| static Nodes CreateNodes( |
| const Shape& shape, absl::FunctionRef<T()> gen = [] { return T(); }) { |
| Nodes nodes; |
| ShapeUtil::ForEachSubshape(shape, |
| [&](const Shape&, const ShapeIndex& index) { |
| nodes.push_back({index, gen()}); |
| }); |
| return nodes; |
| } |
| |
| // Returns the index table entry for the given shape index. |
| const IndexTable::Entry& Lookup(ShapeIndexView index) const { |
| // The index table is evaluated lazily. |
| if (index_table_.empty()) index_table_ = IndexTable(*shape_); |
| return index_table_[index]; |
| } |
| |
| // The nodes in this shape tree. |
| Nodes nodes_; |
| |
| // Index table for node lookups. Each entry contains the index of the first |
| // child of the node at that index, or -1 for leaf nodes. Evaluated lazily. |
| mutable IndexTable index_table_; |
| |
| // If we own our Shape, this field contains it, and shape_ is a pointer into |
| // here. Otherwise if we don't own our shape, this is nullptr. |
| std::shared_ptr<Shape> shape_storage_; |
| |
| // The XLA shape mirrored in this ShapeTree. This is either |
| // shape_storage_.get() or the Shape pointer passed to our constructor. |
| const Shape* shape_; |
| }; |
| |
| // Internal iterator that performs a pre-order walk of the leaves. This is cheap |
| // to copy. The iterator value_type is equivalent to a std::pair<ShapeIndex,T>&, |
| // similar to std::map. |
| template <typename T> |
| template <typename Iterator, typename ValueType> |
| class ShapeTree<T>::LeafIterator |
| : public std::iterator<std::bidirectional_iterator_tag, ValueType> { |
| public: |
| LeafIterator(const Nodes& nodes, Iterator it) : nodes_(nodes), it_(it) { |
| while ((it_ != nodes.end()) && !IsLeaf()) ++it_; |
| } |
| |
| LeafIterator& operator++() { |
| do { |
| ++it_; |
| } while ((it_ != nodes_.end()) && !IsLeaf()); |
| return *this; |
| } |
| |
| LeafIterator operator++(int) { |
| auto prev = *this; |
| ++(*this); |
| return prev; |
| } |
| |
| LeafIterator& operator--() { |
| do { |
| --it_; |
| } while ((it_ != nodes_.begin()) && !IsLeaf()); |
| return *this; |
| } |
| |
| LeafIterator operator--(int) { |
| auto prev = *this; |
| --(*this); |
| return prev; |
| } |
| |
| bool operator==(const LeafIterator& other) const { return it_ == other.it_; } |
| bool operator!=(const LeafIterator& other) const { return !(*this == other); } |
| ValueType& operator*() const { return *it_; } |
| ValueType* operator->() const { return &*it_; } |
| |
| private: |
| bool IsLeaf() const { |
| auto next = it_ + 1; |
| // If the node is not a leaf, the next node will have a longer shape index. |
| return (next == nodes_.end()) || (it_->first.size() >= next->first.size()); |
| } |
| |
| const Nodes& nodes_; |
| Iterator it_; |
| }; |
| |
| } // namespace xla |
| |
| #endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ |