// blob: 76fd76fad6a1631930ad2ba2b1d8d6dd1f679b7a [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
// See https://jax.readthedocs.io/en/latest/pytrees.html for documentation on
// pytrees.
// Caution: this code uses exceptions. The exception use is local to the
// binding code and is the idiomatic way to emit Python exceptions.
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/hash/hash.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
namespace xla {
// Registry of custom node types.
class CustomNodeRegistry {
public:
struct Registration {
// The Python type object, used to identify the type.
pybind11::object type;
// A function with signature: object -> (iterable, aux_data)
pybind11::function to_iterable;
// A function with signature: (aux_data, iterable) -> object
pybind11::function from_iterable;
};
// Registers a new custom type. Objects of `type` will be treated as container
// node types in PyTrees.
static void Register(pybind11::object type, pybind11::function to_iterable,
pybind11::function from_iterable);
// Finds the custom type registration for `type`. Returns nullptr if none
// exists.
static const Registration* Lookup(pybind11::handle type);
private:
static CustomNodeRegistry* Singleton();
struct TypeHash {
size_t operator()(const pybind11::object& t) const {
return pybind11::hash(t);
}
};
struct TypeEq {
bool operator()(const pybind11::object& a,
const pybind11::object& b) const {
return a.equal(b);
}
};
absl::flat_hash_map<pybind11::object, std::unique_ptr<Registration>, TypeHash,
TypeEq>
registrations_;
};
// A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of
// Python values, where the interior nodes are tuples, lists, dictionaries, or
// user-defined containers, and the leaves are other objects.
//
// The structure is stored as a flat vector of nodes in post-order (see
// `traversal_`), so the last element of that vector is the root.
class PyTreeDef {
 public:
  PyTreeDef() = default;

  // Flattens a Pytree into a list of leaves and a PyTreeDef.
  static std::pair<pybind11::list, std::unique_ptr<PyTreeDef>> Flatten(
      pybind11::handle x);

  // Recursive helper used to implement Flatten(). `leaves` is an output
  // parameter.
  void FlattenInto(pybind11::handle handle,
                   std::vector<pybind11::handle>& leaves);

  // Tests whether the given list is a flat list of leaves.
  static bool AllLeaves(const pybind11::iterable& x);

  // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of
  // the tree-structure of 'x'. For example, if we flatten a value
  // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the
  // list of leaves [1, (2, 3), {"foo": 4}].
  pybind11::list FlattenUpTo(pybind11::handle x) const;

  // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef.
  pybind11::object Unflatten(pybind11::iterable leaves) const;

  // Composes two PyTreeDefs, replacing the leaves of this tree with copies of
  // `inner`.
  std::unique_ptr<PyTreeDef> Compose(const PyTreeDef& inner) const;

  // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs.
  static std::unique_ptr<PyTreeDef> Tuple(const std::vector<PyTreeDef>& defs);

  // NOTE(review): presumably returns one PyTreeDef per child subtree of the
  // root node; the definition is not visible in this header -- confirm.
  std::vector<std::unique_ptr<PyTreeDef>> Children() const;

  // Maps a function over a PyTree structure, applying f_leaf to each leaf, and
  // f_node to each container node.
  // TODO(phawkins): use flattening everywhere instead and delete this method.
  pybind11::object Walk(const pybind11::function& f_node,
                        pybind11::handle f_leaf,
                        pybind11::iterable leaves) const;

  // Given a tree of iterables with the same node/leaf structure as this PyTree,
  // build the corresponding PyTree.
  // TODO(phawkins): use flattening everywhere instead and delete this method.
  pybind11::object FromIterableTree(pybind11::handle xs) const;

  // Number of leaves in the tree: the leaf count recorded on the last node of
  // the post-order traversal, or 0 for an empty treedef.
  int num_leaves() const {
    return traversal_.empty() ? 0 : traversal_.back().num_leaves;
  }

  // Number of entries in the traversal. The size_t -> int conversion is made
  // explicit rather than relying on an implicit narrowing conversion.
  int num_nodes() const { return static_cast<int>(traversal_.size()); }

  // Hash of this treedef. Defined out of line.
  size_t Hash() const;

  // Equality between treedefs; operator!= is defined as its negation.
  bool operator==(const PyTreeDef& other) const;
  bool operator!=(const PyTreeDef& other) const { return !(*this == other); }

  // String form of this treedef. Defined out of line.
  std::string ToString() const;

 private:
  enum class Kind {
    kLeaf,        // An opaque leaf node
    kNone,        // None.
    kTuple,       // A tuple
    kNamedTuple,  // A collections.namedtuple
    kList,        // A list
    kDict,        // A dict
    kCustom,      // A custom type.
  };

  // One entry of the post-order traversal.
  struct Node {
    Kind kind = Kind::kLeaf;

    // Arity for non-kLeaf types.
    int arity = 0;

    // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type
    // object. For a kDict, contains a sorted list of keys. For a kCustom type,
    // contains the auxiliary data returned by the `to_iterable` function.
    pybind11::object node_data;

    // NOTE(review): presumably the registry entry for kCustom nodes and
    // nullptr otherwise -- confirm against the definitions.
    const CustomNodeRegistry::Registration* custom = nullptr;

    // Number of leaf nodes in the subtree rooted at this node.
    int num_leaves = 0;

    // Number of leaf and interior nodes in the subtree rooted at this node.
    int num_nodes = 0;
  };

  // Abseil hashing hooks; they need access to the private Node type and to
  // traversal_.
  template <typename H>
  friend H AbslHashValue(H h, const Node& n);
  template <typename H>
  friend H AbslHashValue(H h, const PyTreeDef& t);

  // Helper that manufactures an instance of a node given its children.
  static pybind11::object MakeNode(const Node& node,
                                   absl::Span<pybind11::object> children);

  // Recursive helper used to implement FromIterableTree()
  pybind11::object FromIterableTreeHelper(
      pybind11::handle xs,
      std::vector<PyTreeDef::Node>::const_reverse_iterator* it) const;

  // Computes the node kind of a given Python object. `custom` is an output
  // parameter. NOTE(review): presumably set to the matching registration for
  // custom types -- confirm against the definition.
  static Kind GetKind(const pybind11::handle& obj,
                      CustomNodeRegistry::Registration const** custom);

  // Nodes, in a post-order traversal. We use an ordered traversal to minimize
  // allocations, and post-order corresponds to the order we need to rebuild the
  // tree structure.
  std::vector<Node> traversal_;
};
// Abseil hash hook for a single traversal node. Mixes in the node's kind,
// arity and custom-registration pointer; `node_data` is not part of the hash.
template <typename H>
H AbslHashValue(H h, const PyTreeDef::Node& n) {
  return H::combine(std::move(h), n.kind, n.arity, n.custom);
}
// Abseil hash hook for a whole treedef: hashes the node traversal as one
// contiguous range.
template <typename H>
H AbslHashValue(H h, const PyTreeDef& t) {
  const auto& nodes = t.traversal_;
  return H::combine_contiguous(std::move(h), nodes.data(), nodes.size());
}
// Declared here; the definition is not part of this header.
// NOTE(review): judging by the name, this presumably registers the pytree
// bindings on the pybind11 module `m` -- confirm against the definition.
void BuildPytreeSubmodule(pybind11::module& m);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_