blob: 03ccbd2c40c743a05347b24edb66675f418b0706 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace internal {
// Lookup table from a ShapeIndex to the position of the corresponding node in
// a ShapeTree's node vector. Member functions other than the trivial ones are
// defined out of line.
class IndexTable {
 public:
  // Use indices, rather than pointers, so index table can be copied between
  // ShapeTrees.
  struct Entry {
    // Index of the node in the nodes vector. Initialized so that a
    // default-constructed Entry never holds an indeterminate value.
    size_t node_id = 0;
    // Index of the first child of this node in the index table (-1 for leaves).
    std::make_signed_t<size_t> children_start_id = -1;
  };

  // Creates an empty table (see empty()).
  IndexTable() = default;

  // Builds the table for `shape` (defined out of line).
  explicit IndexTable(const Shape& shape);

  // Returns true if the table holds no entries, e.g. when default-constructed.
  bool empty() const { return entries_.empty(); }

  // Returns the entry for the node at `index` (defined out of line).
  const Entry& operator[](ShapeIndexView index) const;

 private:
  // Helper used while building the table; `next_node_id` is an in/out counter
  // (defined out of line).
  void CreateEntry(Entry& entry, const Shape& shape, size_t& next_node_id);

  absl::InlinedVector<Entry, 1> entries_;
};
} // namespace internal
// A ShapeTree<T> is a recursive data structure which mirrors the structure of
// an XLA shape and holds a value of type T for each subshape (i.e. tuple or
// array) in the shape. For array shapes, a ShapeTree trivially holds a single
// value of type T.
//
// For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
// ShapeTree is an identically structured tree with data elements of type T at
// every node. I.e. the root is a tuple by definition, all interior nodes are
// also tuples, and all leaves are arrays.
//
// Like the Shape data structure, this is a tree and tuple elements cannot be
// duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
// object.
//
// Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
// it's helpful not to copy a Shape just to make a ShapeTree. In these cases,
// you can pass a Shape* instead of a Shape to the ShapeTree constructor. It's
// then up to you to ensure that the pointed-to Shape isn't freed, moved or
// modified before its ShapeTree goes away.
template <typename T>
class ShapeTree {
  template <typename U>
  friend class ShapeTree;

 public:
  // TODO(cjfj): Don't store ShapeIndex with data. Generate it or cache it?
  // Each node pairs the ShapeIndex of a subshape with the value stored for it.
  // Nodes are kept in the order ShapeUtil::ForEachSubshape visits subshapes.
  using Node = std::pair<ShapeIndex, T>;
  using Nodes = absl::InlinedVector<Node, 1>;
  using IndexTable = internal::IndexTable;

  template <typename Iterator, typename ValueType>
  class LeafIterator;

  // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
  ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}

  // Create ShapeTree with the given shape, and default-constructed T values for
  // all nodes.
  //
  // The version that takes a pointer may be cheaper because it doesn't require
  // any Shape copies, but then it's up to you to ensure that the pointer stays
  // alive longer than this ShapeTree.
  explicit ShapeTree(Shape shape)
      : ShapeTree(std::make_shared<Shape>(std::move(shape))) {}

  explicit ShapeTree(const Shape* shape)
      : ShapeTree(shape, CreateNodes(*shape)) {}

  // Create ShapeTree with the given shape, and init_value for all nodes.
  ShapeTree(Shape shape, const T& init_value)
      : ShapeTree(std::make_shared<Shape>(std::move(shape)), init_value) {}

  ShapeTree(const Shape* shape, const T& init_value)
      : ShapeTree(shape, CreateNodes(*shape, [&] { return init_value; })) {}

  // Returns the data element associated with the subshape at the given index
  // (see ShapeUtil::GetSubshape for how indexes are defined).
  const T& element(ShapeIndexView index) const { return find(index)->second; }
  T* mutable_element(ShapeIndexView index) { return &find(index)->second; }

  // Return the shape represented with this ShapeTree.
  const Shape& shape() const { return *shape_; }

  // A ShapeTree object can own the underlying Shape pointer (via the
  // shape_storage_ member), or can point to a Shape object owned by the caller.
  // This API replaces the underlying Shape object with the one supplied by the
  // caller, who must ensure the object remains valid for the whole lifetime of
  // this ShapeTree object, and also that the Shape is consistent with it.
  void replace_shape_ptr(const Shape& shape) {
    if (shape_storage_ != nullptr) {
      DCHECK_EQ(shape, *shape_storage_);
      shape_storage_ = nullptr;
    }
    shape_ = &shape;
  }

  // Returns true if the node at the given index is a leaf node (an array
  // shape).
  bool IsLeaf(ShapeIndexView index) const {
    return Lookup(index).children_start_id == -1;
  }

  using iterator = typename Nodes::iterator;
  using const_iterator = typename Nodes::const_iterator;
  using reverse_iterator = typename Nodes::reverse_iterator;
  using const_reverse_iterator = typename Nodes::const_reverse_iterator;

  using leaf_iterator = LeafIterator<iterator, Node>;
  using const_leaf_iterator = LeafIterator<const_iterator, const Node>;
  using reverse_leaf_iterator = std::reverse_iterator<leaf_iterator>;
  using const_reverse_leaf_iterator =
      std::reverse_iterator<const_leaf_iterator>;

  iterator begin() { return nodes_.begin(); }
  iterator end() { return nodes_.end(); }
  const_iterator begin() const { return nodes_.begin(); }
  const_iterator end() const { return nodes_.end(); }

  reverse_iterator rbegin() { return nodes_.rbegin(); }
  reverse_iterator rend() { return nodes_.rend(); }
  const_reverse_iterator rbegin() const { return nodes_.rbegin(); }
  const_reverse_iterator rend() const { return nodes_.rend(); }

  // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
  // children). We check for empty tuples here, so the leaf iterator's `IsLeaf`
  // check is simpler.
  leaf_iterator leaf_begin() {
    return ShapeUtil::IsEmptyTuple(*shape_)
               ? leaf_end()
               : leaf_iterator(nodes_, nodes_.begin());
  }
  leaf_iterator leaf_end() { return leaf_iterator(nodes_, nodes_.end()); }
  const_leaf_iterator leaf_begin() const {
    return ShapeUtil::IsEmptyTuple(*shape_)
               ? leaf_end()
               : const_leaf_iterator(nodes_, nodes_.begin());
  }
  const_leaf_iterator leaf_end() const {
    return const_leaf_iterator(nodes_, nodes_.end());
  }
  // range-based iterator for leaf_begin()/leaf_end().
  tensorflow::gtl::iterator_range<leaf_iterator> leaves() {
    return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
  }
  tensorflow::gtl::iterator_range<const_leaf_iterator> leaves() const {
    return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
  }

  reverse_leaf_iterator leaf_rbegin() {
    return reverse_leaf_iterator(leaf_end());
  }
  reverse_leaf_iterator leaf_rend() {
    return reverse_leaf_iterator(leaf_begin());
  }
  const_reverse_leaf_iterator leaf_rbegin() const {
    return const_reverse_leaf_iterator(leaf_end());
  }
  const_reverse_leaf_iterator leaf_rend() const {
    return const_reverse_leaf_iterator(leaf_begin());
  }

  // Returns an iterator pointing to the given ShapeIndex.
  // REQUIRES: index must exist in the ShapeTree.
  iterator find(ShapeIndexView index) {
    return nodes_.begin() + Lookup(index).node_id;
  }
  const_iterator find(ShapeIndexView index) const {
    return nodes_.begin() + Lookup(index).node_id;
  }

  // Returns the number of leaf nodes in the tree.
  int64_t leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }

  // TODO(cjfj): Remove the `ForEach...` methods. They are redundant.
  // Recursively traverses the shape and calls the given function at each
  // element.
  void ForEachElement(
      absl::FunctionRef<void(const ShapeIndex&, const T&)> func) const {
    for (const Node& node : nodes_) {
      func(node.first, node.second);
    }
  }

  void ForEachMutableElement(
      absl::FunctionRef<void(const ShapeIndex&, T*)> func) {
    for (Node& node : nodes_) {
      func(node.first, &node.second);
    }
  }

  // Like ForEach(Mutable)Element, but the callable returns a Status instead of
  // void. The first non-OK return value is returned by the ForEach* function.
  Status ForEachElementWithStatus(
      absl::FunctionRef<Status(const ShapeIndex&, const T&)> func) const {
    for (const Node& node : nodes_) {
      TF_RETURN_IF_ERROR(func(node.first, node.second));
    }
    return Status::OK();
  }

  Status ForEachMutableElementWithStatus(
      absl::FunctionRef<Status(const ShapeIndex&, T*)> func) {
    for (Node& node : nodes_) {
      TF_RETURN_IF_ERROR(func(node.first, &node.second));
    }
    return Status::OK();
  }

  // Maps each element to generate a new tree with the same shape.
  //
  // This only reads from `*this`, so it is const and may be called on const
  // trees. The result shares the shape (and shape ownership, if any) and the
  // index table with this tree.
  template <typename U>
  ShapeTree<U> Map(absl::FunctionRef<U(const T&)> func) const {
    typename ShapeTree<U>::Nodes result_nodes;
    result_nodes.reserve(nodes_.size());
    for (const Node& node : nodes_) {
      result_nodes.push_back({node.first, func(node.second)});
    }

    ShapeTree<U> result(shape_, std::move(result_nodes));
    result.index_table_ = index_table_;
    result.shape_storage_ = shape_storage_;
    return result;
  }

  // Copy the subtree of values from 'other' rooted at ShapeIndex 'src_index'
  // into the subtree of value in this ShapeTree rooted at 'dst_index'.
  //
  // Precondition: The subshape of other.shape() at index src_index must be
  // compatible with the subshape of shape() at index dst_index.
  void CopySubtreeFrom(const ShapeTree<T>& other, const ShapeIndex& src_index,
                       const ShapeIndex& dst_index) {
    const Shape& src_shape = ShapeUtil::GetSubshape(other.shape(), src_index);
    const Shape& dst_shape = ShapeUtil::GetSubshape(shape(), dst_index);
    CHECK(ShapeUtil::Compatible(src_shape, dst_shape))
        << src_shape << ", " << dst_shape;

    // Replace the prefix `src_index` with `dst_index`.
    auto replace_shape_index_prefix = [&](const ShapeIndex& index) {
      auto without_prefix = ShapeIndexView(index).subspan(src_index.size());
      ShapeIndex result;
      result.reserve(dst_index.size() + without_prefix.size());
      result.insert(result.end(), dst_index.begin(), dst_index.end());
      result.insert(result.end(), without_prefix.begin(), without_prefix.end());
      return result;
    };

    // The subtree rooted at `src_index` occupies a contiguous run of
    // SubshapeCount(src_shape) nodes starting at find(src_index).
    auto first = other.find(src_index);
    auto last = first + ShapeUtil::SubshapeCount(src_shape);

    std::transform(first, last, find(dst_index), [&](const Node& node) -> Node {
      return {replace_shape_index_prefix(node.first), node.second};
    });
  }

  // Returns a new tree mirroring the subshape at `index`, with copies of the
  // corresponding values and indices rebased so `index` becomes the root.
  // Returns an error status if `index` is not a valid subshape index.
  //
  // If this tree owns its Shape, the returned tree shares ownership of that
  // storage, so it remains valid even if this tree is destroyed first. If this
  // tree only borrows its Shape, the returned tree borrows it too, under the
  // same lifetime contract.
  StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const {
    TF_ASSIGN_OR_RETURN(const Shape* sub_shape,
                        ShapeUtil::TryGetSubshape(shape(), index));
    size_t count = ShapeUtil::SubshapeCount(*sub_shape);
    Nodes sub_tree_nodes;
    sub_tree_nodes.reserve(count);
    for (auto it = find(index), end = it + count; it != end; ++it) {
      // For each shape index, remove the prefix `index`.
      auto without_prefix = ShapeIndexView(it->first).subspan(index.size());
      sub_tree_nodes.push_back(Node{without_prefix, it->second});
    }

    ShapeTree<T> sub_tree(sub_shape, std::move(sub_tree_nodes));
    if (shape_storage_ != nullptr) {
      // `sub_shape` points inside `*shape_storage_`. Use the shared_ptr
      // aliasing constructor so the sub-tree keeps that storage alive without
      // copying the Shape. The const_cast is safe: the pointee is part of a
      // non-const Shape owned by `shape_storage_`, and it is never mutated
      // through this pointer.
      sub_tree.shape_storage_ = std::shared_ptr<Shape>(
          shape_storage_, const_cast<Shape*>(sub_shape));
    }
    return sub_tree;
  }

  bool operator==(const ShapeTree<T>& other) const {
    return nodes_ == other.nodes_;
  }
  bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }

 private:
  // Takes ownership of `shape`; nodes are default-constructed.
  explicit ShapeTree(std::shared_ptr<Shape> shape) : ShapeTree(shape.get()) {
    shape_storage_.swap(shape);
  }

  // Takes ownership of `shape`; every node is initialized to `init_value`.
  ShapeTree(std::shared_ptr<Shape> shape, const T& init_value)
      : ShapeTree(shape.get(), init_value) {
    shape_storage_.swap(shape);
  }

  // Borrows `shape` and adopts `nodes`, which must have one entry per subshape.
  ShapeTree(const Shape* shape, Nodes nodes)
      : nodes_(std::move(nodes)), shape_(shape) {
    DCHECK_EQ(nodes_.size(), ShapeUtil::SubshapeCount(*shape));
  }

  // Builds one node per subshape of `shape`, in ForEachSubshape order, with the
  // value produced by `gen`.
  static Nodes CreateNodes(
      const Shape& shape, absl::FunctionRef<T()> gen = [] { return T(); }) {
    Nodes nodes;
    ShapeUtil::ForEachSubshape(shape,
                               [&](const Shape&, const ShapeIndex& index) {
                                 nodes.push_back({index, gen()});
                               });
    return nodes;
  }

  // Returns the index table entry for the given shape index.
  //
  // NOTE: the table is built lazily by writing the `mutable` index_table_ from
  // a const method without synchronization, so concurrent const access to a
  // tree whose table has not been built yet is not safe.
  const IndexTable::Entry& Lookup(ShapeIndexView index) const {
    // The index table is evaluated lazily.
    if (index_table_.empty()) index_table_ = IndexTable(*shape_);
    return index_table_[index];
  }

  // The nodes in this shape tree.
  Nodes nodes_;

  // Index table for node lookups. Each entry contains the index of the first
  // child of the node at that index, or -1 for leaf nodes. Evaluated lazily.
  mutable IndexTable index_table_;

  // If we own our Shape, this field contains it, and shape_ is a pointer into
  // here. Otherwise if we don't own our shape, this is nullptr.
  std::shared_ptr<Shape> shape_storage_;

  // The XLA shape mirrored in this ShapeTree. This is either
  // shape_storage_.get() or the Shape pointer passed to our constructor.
  const Shape* shape_;
};
// Internal iterator that performs a pre-order walk of the leaves. This is cheap
// to copy. Dereferencing it yields a reference to a std::pair<ShapeIndex, T>
// (const-qualified for const_leaf_iterator), similar to std::map iterators.
template <typename T>
template <typename Iterator, typename ValueType>
class ShapeTree<T>::LeafIterator
: public std::iterator<std::bidirectional_iterator_tag, ValueType> {
public:
LeafIterator(const Nodes& nodes, Iterator it) : nodes_(nodes), it_(it) {
while ((it_ != nodes.end()) && !IsLeaf()) ++it_;
}
LeafIterator& operator++() {
do {
++it_;
} while ((it_ != nodes_.end()) && !IsLeaf());
return *this;
}
LeafIterator operator++(int) {
auto prev = *this;
++(*this);
return prev;
}
LeafIterator& operator--() {
do {
--it_;
} while ((it_ != nodes_.begin()) && !IsLeaf());
return *this;
}
LeafIterator operator--(int) {
auto prev = *this;
--(*this);
return prev;
}
bool operator==(const LeafIterator& other) const { return it_ == other.it_; }
bool operator!=(const LeafIterator& other) const { return !(*this == other); }
ValueType& operator*() const { return *it_; }
ValueType* operator->() const { return &*it_; }
private:
bool IsLeaf() const {
auto next = it_ + 1;
// If the node is not a leaf, the next node will have a longer shape index.
return (next == nodes_.end()) || (it_->first.size() >= next->first.size());
}
const Nodes& nodes_;
Iterator it_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_