[JAX] Only perform casts from Python objects to PyBuffer* once.

These casts are slow enough to show up in profiles.

I believe a significant factor is that constructing a pybind11 type caster involves a `std::unordered_map<>` lookup.

A better fix might be to avoid pybind11's casting facilities, but that is a larger change.
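
The change caches the result of the cast in the caller and threads the PyBuffer* through to DevicePut instead of re-casting the same handle in a later pass. As a rough illustration of the pattern (a minimal sketch with hypothetical stand-in types, not the real pybind11/XLA API):

    #include <cstddef>
    #include <vector>

    struct Buffer {};                               // stand-in for xla::PyBuffer
    struct Handle { Buffer* payload = nullptr; };   // stand-in for py::handle

    // Stand-in for py::cast<Buffer*>(handle); the real call constructs a
    // pybind11 type caster, which involves a hash-map lookup per cast.
    Buffer* CastToBuffer(const Handle& h) { return h.payload; }

    void ProcessArgs(const std::vector<Handle>& args) {
      // First pass: cast each argument exactly once and cache the pointer.
      std::vector<Buffer*> cached(args.size(), nullptr);
      for (std::size_t i = 0; i < args.size(); ++i) {
        cached[i] = CastToBuffer(args[i]);
      }
      // Later passes reuse cached[i] instead of casting args[i] again.
      for (std::size_t i = 0; i < args.size(); ++i) {
        if (Buffer* b = cached[i]) {
          (void)b;  // ... use the already-cast buffer ...
        }
      }
    }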

PiperOrigin-RevId: 363176446
Change-Id: I645af27b1239b0add438e0fa004ee0b2c3295e96
diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc
index d82a1d2..98b874a 100644
--- a/tensorflow/compiler/xla/python/jax_jit.cc
+++ b/tensorflow/compiler/xla/python/jax_jit.cc
@@ -604,6 +604,9 @@
   arg_buffers.reserve(num_flat_dynamic_args);
   arguments.signature.dynamic_args_signatures.reserve(num_flat_dynamic_args);
 
+  absl::InlinedVector<xla::PyBuffer*, 4> py_buffers;
+  py_buffers.resize(num_flat_dynamic_args, nullptr);
+
   struct PythonTypes {
     py::object device_array;
     py::object py_buffer_type;
@@ -626,12 +629,14 @@
   if (is_committed) {
     data_device = default_device;
   } else {
-    for (py::handle arg : arguments.flat_dynamic_args) {
+    for (int i = 0; i < num_flat_dynamic_args; ++i) {
+      py::handle arg = arguments.flat_dynamic_args[i];
       // We specifically only deal with DeviceArray (not ShardedDeviceArray).
       // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
       xla::PjRtDevice* device = nullptr;
       if (arg.get_type().ptr() == types.py_buffer_type.ptr()) {
         xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(arg);
+        py_buffers[i] = buffer;
         if (!buffer->sticky_device()) {
           continue;
         }
@@ -681,9 +686,10 @@
   // TODO(phawkins): consider allowing forces here.
   options.force_lazy_arrays = false;
   options.allow_zero_copy = true;
-  for (py::handle arg : arguments.flat_dynamic_args) {
+  for (int i = 0; i < num_flat_dynamic_args; ++i) {
+    py::handle arg = arguments.flat_dynamic_args[i];
     TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device,
-                        DevicePut(arg, data_device, options));
+                        DevicePut(arg, data_device, options, py_buffers[i]));
 
     xla::PjRtBuffer* buffer = on_device.buffer;
     arg_buffers.push_back(buffer);
diff --git a/tensorflow/compiler/xla/python/py_values.cc b/tensorflow/compiler/xla/python/py_values.cc
index cecac80..3850a03 100644
--- a/tensorflow/compiler/xla/python/py_values.cc
+++ b/tensorflow/compiler/xla/python/py_values.cc
@@ -152,8 +152,8 @@
 }
 
 StatusOr<DevicePutResult> PyBufferHelper(py::handle obj, py::handle py_buffer,
+                                         PyBuffer* buffer,
                                          PjRtDevice* to_device) {
-  PyBuffer* buffer = py::cast<PyBuffer*>(py_buffer);
   bool weak_type = buffer->weak_type()
                        ? *buffer->weak_type()
                        : py::cast<bool>(obj.attr("aval").attr("weak_type"));
@@ -170,7 +170,7 @@
 
 StatusOr<DevicePutResult> HandlePyBuffer(py::handle obj, PjRtDevice* to_device,
                                          const DevicePutOptions& options) {
-  return PyBufferHelper(obj, obj, to_device);
+  return PyBufferHelper(obj, obj, py::cast<PyBuffer*>(obj), to_device);
 }
 
 StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
@@ -201,14 +201,18 @@
     obj = forced;
   }
 
-  return PyBufferHelper(obj, buffer, to_device);
+  return PyBufferHelper(obj, buffer, py::cast<PyBuffer*>(buffer), to_device);
 }
 
 }  // namespace
 
 StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
-                                    const DevicePutOptions& options) {
+                                    const DevicePutOptions& options,
+                                    PyBuffer* py_buffer) {
   tensorflow::profiler::TraceMe traceme("DevicePut");
+  if (py_buffer) {
+    return PyBufferHelper(arg, arg, py_buffer, to_device);
+  }
   static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
       [] {
         auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
diff --git a/tensorflow/compiler/xla/python/py_values.h b/tensorflow/compiler/xla/python/py_values.h
index ceffe74..61d6939 100644
--- a/tensorflow/compiler/xla/python/py_values.h
+++ b/tensorflow/compiler/xla/python/py_values.h
@@ -53,6 +53,8 @@
 // If `arg` is not convertible to a `PjRtBuffer` from C++, an error will be
 // returned; float0s and `_DeviceArray`s with non-trivial LazyExprs are not
 // supported yet.
+// If the value is known to be a PyBuffer object, py_buffer can be passed as
+// an optimization to avoid a Python->C++ cast.
 //
 // May throw exceptions from pybind11 in addition to failing via an error
 // Status. (We could catch these if needed, but there seems little point.)
@@ -62,7 +64,8 @@
   bool force_lazy_arrays = true;
 };
 StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
-                                    const DevicePutOptions& options);
+                                    const DevicePutOptions& options,
+                                    PyBuffer* py_buffer = nullptr);
 
 }  // namespace xla