Make FlattenInto return py::handles not py::objects.

PiperOrigin-RevId: 326446701
Change-Id: I3f35a3075dda8c4aac1db9be76ac22ba1218804d
diff --git a/tensorflow/compiler/xla/python/pytree.cc b/tensorflow/compiler/xla/python/pytree.cc
index 401b7bc..58d6a58 100644
--- a/tensorflow/compiler/xla/python/pytree.cc
+++ b/tensorflow/compiler/xla/python/pytree.cc
@@ -107,7 +107,7 @@
 }
 
 void PyTreeDef::FlattenInto(py::handle handle,
-                            std::vector<py::object>& leaves) {
+                            std::vector<py::handle>& leaves) {
   Node node;
   int start_num_nodes = traversal_.size();
   int start_num_leaves = leaves.size();
@@ -158,19 +158,23 @@
     }
   } else {
     assert(node.kind == Kind::kLeaf);
-    leaves.push_back(py::reinterpret_borrow<py::object>(handle));
+    leaves.push_back(handle);
   }
   node.num_nodes = traversal_.size() - start_num_nodes + 1;
   node.num_leaves = leaves.size() - start_num_leaves;
   traversal_.push_back(std::move(node));
 }
 
-/*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
-PyTreeDef::Flatten(py::handle x) {
-  std::vector<py::object> leaves;
+/*static*/ std::pair<py::list, std::unique_ptr<PyTreeDef>> PyTreeDef::Flatten(
+    py::handle x) {
+  std::vector<py::handle> leaves;
   auto tree = absl::make_unique<PyTreeDef>();
   tree->FlattenInto(x, leaves);
-  return std::make_pair(std::move(leaves), std::move(tree));
+  py::list outputs(leaves.size());
+  for (int i = 0; i < leaves.size(); ++i) {
+    outputs[i] = py::reinterpret_borrow<py::object>(leaves[i]);
+  }
+  return std::make_pair(std::move(outputs), std::move(tree));
 }
 
 /*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) {
diff --git a/tensorflow/compiler/xla/python/pytree.h b/tensorflow/compiler/xla/python/pytree.h
index 69cd93a..76fd76f 100644
--- a/tensorflow/compiler/xla/python/pytree.h
+++ b/tensorflow/compiler/xla/python/pytree.h
@@ -84,12 +84,12 @@
   PyTreeDef() = default;
 
   // Flattens a Pytree into a list of leaves and a PyTreeDef.
-  static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
-  Flatten(pybind11::handle x);
+  static std::pair<pybind11::list, std::unique_ptr<PyTreeDef>> Flatten(
+      pybind11::handle x);
 
   // Recursive helper used to implement Flatten().
   void FlattenInto(pybind11::handle handle,
-                   std::vector<pybind11::object>& leaves);
+                   std::vector<pybind11::handle>& leaves);
 
   // Tests whether the given list is a flat list of leaves.
   static bool AllLeaves(const pybind11::iterable& x);