Make FlattenInto return py::handles not py::objects.
PiperOrigin-RevId: 326446701
Change-Id: I3f35a3075dda8c4aac1db9be76ac22ba1218804d
diff --git a/tensorflow/compiler/xla/python/pytree.cc b/tensorflow/compiler/xla/python/pytree.cc
index 401b7bc..58d6a58 100644
--- a/tensorflow/compiler/xla/python/pytree.cc
+++ b/tensorflow/compiler/xla/python/pytree.cc
@@ -107,7 +107,7 @@
}
void PyTreeDef::FlattenInto(py::handle handle,
- std::vector<py::object>& leaves) {
+ std::vector<py::handle>& leaves) {
Node node;
int start_num_nodes = traversal_.size();
int start_num_leaves = leaves.size();
@@ -158,19 +158,23 @@
}
} else {
assert(node.kind == Kind::kLeaf);
- leaves.push_back(py::reinterpret_borrow<py::object>(handle));
+ leaves.push_back(handle);
}
node.num_nodes = traversal_.size() - start_num_nodes + 1;
node.num_leaves = leaves.size() - start_num_leaves;
traversal_.push_back(std::move(node));
}
-/*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
-PyTreeDef::Flatten(py::handle x) {
- std::vector<py::object> leaves;
+/*static*/ std::pair<py::list, std::unique_ptr<PyTreeDef>> PyTreeDef::Flatten(
+ py::handle x) {
+ std::vector<py::handle> leaves;
auto tree = absl::make_unique<PyTreeDef>();
tree->FlattenInto(x, leaves);
- return std::make_pair(std::move(leaves), std::move(tree));
+ py::list outputs(leaves.size());
+ for (int i = 0; i < leaves.size(); ++i) {
+ outputs[i] = py::reinterpret_borrow<py::object>(leaves[i]);
+ }
+ return std::make_pair(std::move(outputs), std::move(tree));
}
/*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) {
diff --git a/tensorflow/compiler/xla/python/pytree.h b/tensorflow/compiler/xla/python/pytree.h
index 69cd93a..76fd76f 100644
--- a/tensorflow/compiler/xla/python/pytree.h
+++ b/tensorflow/compiler/xla/python/pytree.h
@@ -84,12 +84,12 @@
PyTreeDef() = default;
// Flattens a Pytree into a list of leaves and a PyTreeDef.
- static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
- Flatten(pybind11::handle x);
+ static std::pair<pybind11::list, std::unique_ptr<PyTreeDef>> Flatten(
+ pybind11::handle x);
// Recursive helper used to implement Flatten().
void FlattenInto(pybind11::handle handle,
- std::vector<pybind11::object>& leaves);
+ std::vector<pybind11::handle>& leaves);
// Tests whether the given list is a flat list of leaves.
static bool AllLeaves(const pybind11::iterable& x);