[JAX] Refactor jax_jit to avoid DevicePut on pruned args.
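
Previously ConvertArgsToBuffers ran DevicePut on every flattened dynamic
argument and derived the call signature from the resulting buffers, so
arguments pruned from the compiled executable were still transferred. This
change splits it into ComputeSignature, which computes per-argument
PyArgSignatures (via PyArgSignatureOfValue) before the executable cache
lookup, and CopyBuffersToDevice, which runs DevicePut only for arguments kept
by kept_var_bitvec. DevicePut and PyArgSignatureOfValue also gain a fast path
for the common PyBuffer case.

The improvement shows up for jitted functions with many unused ("pruned")
arguments, as in the jit_simple_pruned_args_* benchmarks below. A minimal
illustrative sketch of that kind of workload (names are made up, not taken
from the benchmark code):

  import jax

  @jax.jit
  def f(x, *unused):
    # Only `x` feeds the output, so the remaining parameters can be pruned
    # from the compiled executable; with this change they are no longer
    # passed through DevicePut on each call.
    return x + 1.0

  args = [float(i) for i in range(1000)]
  f(*args).block_until_ready()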

name                                  old cpu/op  new cpu/op  delta
eager_unary_dispatch                  35.7µs ± 2%  35.9µs ± 3%     ~     (p=0.841 n=5+5)
eager_unary                           36.4µs ± 2%  36.6µs ± 3%     ~     (p=0.421 n=5+5)
eager_binary_dispatch                 45.6µs ± 1%  46.1µs ± 2%     ~     (p=0.421 n=5+5)
eager_binary                          46.6µs ± 2%  47.0µs ± 5%     ~     (p=1.000 n=5+5)
jit_trivial_dispatch                  41.4µs ± 1%  41.4µs ± 0%     ~     (p=0.690 n=5+5)
jit_trivial                           42.4µs ± 1%  42.3µs ± 1%     ~     (p=0.841 n=5+5)
jit_simple_dispatch                   8.85µs ± 3%  9.15µs ± 3%     ~     (p=0.095 n=5+5)
jit_simple                            9.77µs ± 1%  9.82µs ± 2%     ~     (p=0.548 n=5+5)
jit_simple_many_args_dispatch_10      13.4µs ± 1%  13.6µs ± 3%     ~     (p=0.222 n=5+5)
jit_simple_many_args_10               14.0µs ± 2%  14.1µs ± 1%     ~     (p=0.421 n=5+5)
jit_simple_pruned_args_dispatch_10    8.05µs ± 3%  8.07µs ± 4%     ~     (p=0.841 n=5+5)
jit_simple_pruned_args_10             9.53µs ± 2%  9.43µs ± 2%     ~     (p=0.222 n=5+5)
jit_simple_many_args_dispatch_100     55.2µs ± 1%  54.8µs ± 2%     ~     (p=0.310 n=5+5)
jit_simple_many_args_100              55.8µs ± 1%  55.8µs ± 1%     ~     (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_100   14.3µs ± 4%  12.6µs ± 1%  -11.41%  (p=0.016 n=5+4)
jit_simple_pruned_args_100            14.8µs ± 1%  13.3µs ± 2%  -10.06%  (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000     489µs ± 1%   477µs ± 3%     ~     (p=0.056 n=5+5)
jit_simple_many_args_1000              495µs ± 3%   493µs ± 3%     ~     (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_1000  85.0µs ± 3%  65.3µs ± 3%  -23.13%  (p=0.008 n=5+5)
jit_simple_pruned_args_1000           86.0µs ± 3%  66.4µs ± 3%  -22.78%  (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000    1.09ms ± 4%  1.03ms ± 3%   -5.97%  (p=0.016 n=5+5)
jit_simple_many_args_2000             1.07ms ± 3%  1.04ms ± 5%     ~     (p=0.095 n=5+5)
jit_simple_pruned_args_dispatch_2000   190µs ± 3%   144µs ± 3%  -23.96%  (p=0.008 n=5+5)
jit_simple_pruned_args_2000            195µs ± 4%   147µs ± 3%  -24.29%  (p=0.008 n=5+5)
jit_dispatch_without_transfer         76.0µs ± 1%  77.2µs ± 6%     ~     (p=0.310 n=5+5)
jit_dispatch_with_transfer            82.1µs ± 5%  81.3µs ± 2%     ~     (p=0.421 n=5+5)
sda_index_1                           8.83µs ± 1%  8.73µs ± 2%     ~     (p=0.222 n=5+5)

name                                  old time/op             new time/op             delta
eager_unary_dispatch                  35.7µs ± 2%             35.9µs ± 3%     ~             (p=0.841 n=5+5)
eager_unary                           36.5µs ± 2%             37.1µs ± 4%     ~             (p=0.222 n=5+5)
eager_binary_dispatch                 45.6µs ± 1%             46.1µs ± 2%     ~             (p=0.421 n=5+5)
eager_binary                          46.8µs ± 3%             47.1µs ± 5%     ~             (p=1.000 n=5+5)
jit_trivial_dispatch                  41.4µs ± 1%             41.4µs ± 0%     ~             (p=0.690 n=5+5)
jit_trivial                           42.4µs ± 1%             42.3µs ± 1%     ~             (p=0.841 n=5+5)
jit_simple_dispatch                   8.86µs ± 3%             9.15µs ± 3%     ~             (p=0.095 n=5+5)
jit_simple                            9.82µs ± 1%             9.91µs ± 0%     ~             (p=0.190 n=5+4)
jit_simple_many_args_dispatch_10      13.4µs ± 1%             13.6µs ± 4%     ~             (p=0.310 n=5+5)
jit_simple_many_args_10               14.1µs ± 2%             14.2µs ± 1%     ~             (p=0.421 n=5+5)
jit_simple_pruned_args_dispatch_10    8.07µs ± 4%             8.07µs ± 4%     ~             (p=0.841 n=5+5)
jit_simple_pruned_args_10             9.59µs ± 2%             9.48µs ± 2%     ~             (p=0.222 n=5+5)
jit_simple_many_args_dispatch_100     55.2µs ± 1%             54.8µs ± 2%     ~             (p=0.310 n=5+5)
jit_simple_many_args_100              55.9µs ± 1%             55.9µs ± 1%     ~             (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_100   14.3µs ± 5%             12.6µs ± 1%  -11.75%          (p=0.016 n=5+4)
jit_simple_pruned_args_100            14.8µs ± 2%             13.3µs ± 2%  -10.19%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000     489µs ± 1%              477µs ± 3%     ~             (p=0.056 n=5+5)
jit_simple_many_args_1000              495µs ± 3%              493µs ± 3%     ~             (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_1000  85.0µs ± 3%             65.3µs ± 3%  -23.13%          (p=0.008 n=5+5)
jit_simple_pruned_args_1000           86.1µs ± 3%             66.5µs ± 2%  -22.72%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000    1.09ms ± 4%             1.03ms ± 3%   -5.96%          (p=0.016 n=5+5)
jit_simple_many_args_2000             1.07ms ± 3%             1.04ms ± 5%     ~             (p=0.095 n=5+5)
jit_simple_pruned_args_dispatch_2000   190µs ± 3%              144µs ± 3%  -23.97%          (p=0.008 n=5+5)
jit_simple_pruned_args_2000            195µs ± 4%              147µs ± 3%  -24.31%          (p=0.008 n=5+5)
jit_dispatch_without_transfer         1.41ms ± 1%             1.40ms ± 1%     ~             (p=0.095 n=5+5)
jit_dispatch_with_transfer            1.40ms ± 2%             1.40ms ± 2%     ~             (p=0.841 n=5+5)
sda_index_1                           8.83µs ± 1%             8.73µs ± 2%     ~             (p=0.222 n=5+5)

PiperOrigin-RevId: 374468578
Change-Id: I0a45af35b936a72f8271bd3e3a66e0d778619132
diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc
index f62b5d0..f9b3fbd 100644
--- a/tensorflow/compiler/xla/python/jax_jit.cc
+++ b/tensorflow/compiler/xla/python/jax_jit.cc
@@ -540,18 +540,14 @@
 
 CompiledFunction::~CompiledFunction() = default;
 
-// Converts flattened arguments contained in ParsedArgumentsAsBuffers in
-// place. If arguments are `DeviceArray`, they must all be on the same `Device`.
+// Computes the signature for the flattened dynamic arguments.
 //
 // Returns `Status::OK()` on success. Returning an error should lead to
 // calling the Python fallback.
-xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
-                                 xla::PjRtDevice* default_device,
-                                 bool is_committed,
-                                 ParsedArgumentsAsBuffers& arguments) {
-  tensorflow::profiler::TraceMe traceme("ConvertArgsToBuffers");
-  std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
-  auto& keep_alive = arguments.keep_alive;
+xla::Status ComputeSignature(bool jax_enable_x64, xla::PyClient& pyclient,
+                             xla::PjRtDevice* default_device, bool is_committed,
+                             ParsedArgumentsAsBuffers& arguments) {
+  tensorflow::profiler::TraceMe traceme("ComputeSignature");
 
   int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
   struct PythonTypes {
@@ -624,14 +620,38 @@
   CHECK(data_device);
   arguments.signature.device = data_device;
 
+  arguments.signature.dynamic_arg_signatures.reserve(num_flat_dynamic_args);
+  for (int i = 0; i < num_flat_dynamic_args; ++i) {
+    py::handle arg = arguments.flat_dynamic_args[i];
+    TF_ASSIGN_OR_RETURN(auto sig,
+                        xla::PyArgSignatureOfValue(arg, jax_enable_x64));
+    arguments.signature.dynamic_arg_signatures.push_back(std::move(sig));
+  }
+  return xla::Status::OK();
+}
+
+// Copies buffers to the device, skipping pruned arguments.
+// Returns `Status::OK()` on success. Returning an error should lead to
+// calling the Python fallback.
+xla::Status CopyBuffersToDevice(
+    bool jax_enable_x64, const absl::optional<std::vector<bool>>& kept_args,
+    ParsedArgumentsAsBuffers& arguments) {
+  std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
+  xla::PjRtDevice* data_device = arguments.signature.device;
+
+  int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
   xla::DevicePutOptions options;
   options.squash_64bit_types = !jax_enable_x64;
   // TODO(phawkins): consider allowing forces here.
   options.force_lazy_arrays = false;
   options.allow_zero_copy = true;
   arg_buffers.reserve(num_flat_dynamic_args);
-  arguments.signature.dynamic_arg_signatures.reserve(num_flat_dynamic_args);
+  bool input_pruning_enabled = kept_args.has_value();
   for (int i = 0; i < num_flat_dynamic_args; ++i) {
+    if (input_pruning_enabled && !kept_args.value()[i]) {
+      continue;
+    }
+
     py::handle arg = arguments.flat_dynamic_args[i];
     TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device,
                         DevicePut(arg, data_device, options));
@@ -639,16 +659,11 @@
     xla::PjRtBuffer* buffer = on_device.buffer;
     arg_buffers.push_back(buffer);
     if (on_device.owned_buffer) {
-      keep_alive.push_back(std::move(on_device.owned_buffer));
+      arguments.keep_alive.push_back(std::move(on_device.owned_buffer));
     } else if (on_device.owning_pybuffer) {
       arguments.keep_alive_objects.push_back(
           std::move(on_device.owning_pybuffer));
     }
-
-    xla::PyArgSignature sig(buffer->on_device_shape().element_type(),
-                            buffer->on_device_shape().dimensions(),
-                            on_device.weak_type);
-    arguments.signature.dynamic_arg_signatures.push_back(std::move(sig));
   }
   return xla::Status::OK();
 }
@@ -784,8 +799,8 @@
   arguments.signature.jax_enable_x64 = jax_enable_x64;
   // The C++ jit do not support Tracers arguments inputs yet. The Python-based
   // jit function will be called if any of the dynamic arguments is unsupported.
-  if (!ConvertArgsToBuffers(jax_enable_x64, *default_pyclient_, default_device_,
-                            is_committed_, arguments)
+  if (!ComputeSignature(jax_enable_x64, *default_pyclient_, default_device_,
+                        is_committed_, arguments)
            .ok()) {
     return py::object(
         py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
@@ -842,35 +857,22 @@
                                         **kwargs.value_or(py::kwargs())))[0]);
   }
 
+  if (!CopyBuffersToDevice(jax_enable_x64, cache_entry->kept_var_bitvec,
+                           arguments)
+           .ok()) {
+    return py::object(
+        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
+                                        **kwargs.value_or(py::kwargs())))[0]);
+  }
+
   // Executes the computation.
   std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> output_buffers;
   {
     py::gil_scoped_release gil_release;
-    // TODO(zhangqiaorjc): Refactor ConvertArgsToBuffers. Split out the part
-    // that computes parts of the signature and tests for incompatible devices,
-    // and move it either into ParseArguments or a new function. Move the part
-    // that copies buffers around to here, and we can fuse this "argument
-    // dropping" logic with that code
-    if (cache_entry->kept_var_bitvec.has_value()) {
-      // Input pruning enabled.
-      std::vector<xla::PjRtBuffer*> kept_args;
-      kept_args.reserve(arguments.arg_buffers.size());
-      for (int i = 0; i < arguments.arg_buffers.size(); ++i) {
-        if (cache_entry->kept_var_bitvec.value()[i]) {
-          kept_args.push_back(arguments.arg_buffers[i]);
-        }
-      }
-      TF_ASSIGN_OR_RETURN(
-          output_buffers,
-          cache_entry->executable->mutable_pjrt_executable()->Execute(
-              {kept_args}, cache_entry->executable->options()));
-    } else {
-      // Input pruning not enabled.
-      TF_ASSIGN_OR_RETURN(
-          output_buffers,
-          cache_entry->executable->mutable_pjrt_executable()->Execute(
-              {arguments.arg_buffers}, cache_entry->executable->options()));
-    }
+    TF_ASSIGN_OR_RETURN(
+        output_buffers,
+        cache_entry->executable->mutable_pjrt_executable()->Execute(
+            {arguments.arg_buffers}, cache_entry->executable->options()));
   }
   auto traceback = xla::Traceback::Get();
 
diff --git a/tensorflow/compiler/xla/python/jax_jit.h b/tensorflow/compiler/xla/python/jax_jit.h
index fe010cb..fe195a3 100644
--- a/tensorflow/compiler/xla/python/jax_jit.h
+++ b/tensorflow/compiler/xla/python/jax_jit.h
@@ -95,7 +95,7 @@
   std::vector<xla::PjRtBuffer*> arg_buffers;
   // We may need to keep these objects around, because:
   // (a) we need to extend the lifetime of objects created within
-  //    `ConvertArgsToBuffers`
+  //    `CopyBuffersToDevice`
   // (b) `arg_buffers` do not maintain ownership
   std::vector<std::unique_ptr<xla::PjRtBuffer>> keep_alive;
 };
diff --git a/tensorflow/compiler/xla/python/py_values.cc b/tensorflow/compiler/xla/python/py_values.cc
index 53b293f..65eeed8 100644
--- a/tensorflow/compiler/xla/python/py_values.cc
+++ b/tensorflow/compiler/xla/python/py_values.cc
@@ -267,8 +267,6 @@
 
         // Generic subclasses of DeviceArray, e.g., ShardedDeviceArray.
         (*p)[PyBuffer::base_type()] = HandleDeviceArray;
-        // The C++ PyBuffer class is handled specially.
-        (*p)[PyBuffer::type()] = HandlePyBuffer;
 
         try {
           py::object xla_module = py::module::import("jax.interpreters.xla");
@@ -323,6 +321,11 @@
         return p;
       }();
 
+  // Fast path for PyBuffer, the most common case.
+  if (arg.get_type().ptr() == PyBuffer::type()) {
+    return HandlePyBuffer(arg, to_device, options);
+  }
+
   auto res = handlers->find(arg.get_type().ptr());
   if (res == handlers->end()) {
     for (auto base_class : arg.get_type().attr("mro")()) {
@@ -337,9 +340,8 @@
                   "DeviceArray, Numpy arrays scalars of supported types "
                   "(see implementation), or Python scalars. Got type ",
                   py::cast<std::string>(py::str(arg.get_type()))));
-  } else {
-    return res->second(arg, to_device, options);
   }
+  return res->second(arg, to_device, options);
 }
 
 bool IsFloat0(py::array arg) {
@@ -369,9 +371,6 @@
       PyObject*, ToPyArgSignatureHandler>* const handlers = [] {
     auto p = new absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>();
 
-    const auto xla_module = py::module::import("jax.interpreters.xla");
-    const auto& device_array = xla_module.attr("_DeviceArray");
-
     const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
 
     // The 4 Python native types.
@@ -418,19 +417,7 @@
     (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = float_handler;
     (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] = complex_handler;
 
-    // The Buffer types
-    // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
-    ToPyArgSignatureHandler buffer_handler =
-        [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
-      TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(h));
-      bool weak_type = buffer->weak_type().has_value()
-                           ? *buffer->weak_type()
-                           : py::cast<bool>(h.attr("aval").attr("weak_type"));
-      return PyArgSignature(buffer->buffer()->on_device_shape().element_type(),
-                            buffer->buffer()->on_device_shape().dimensions(),
-                            weak_type);
-    };
-    (*p)[PyBuffer::base_type()] = buffer_handler;
+    // The Buffer types; PyBuffer itself is handled by the fast path below.
     ToPyArgSignatureHandler device_array_handler =
         [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
       py::handle aval = h.attr("aval");
@@ -439,17 +426,33 @@
                             py::cast<std::vector<int64>>(aval.attr("shape")),
                             py::cast<py::bool_>(aval.attr("weak_type")));
     };
-    // ShardedDeviceArray is covered by the MRO fallback on _DeviceArray.
-    (*p)[device_array.ptr()] = device_array_handler;
+    (*p)[PyBuffer::base_type()] = device_array_handler;
+
+    try {
+      py::object xla_module = py::module::import("jax.interpreters.xla");
+      py::object device_array =
+          py::getattr(xla_module, "_DeviceArray", py::none());
+      if (!device_array.is_none()) {
+        (*p)[device_array.ptr()] = device_array_handler;
+      }
+    } catch (const py::error_already_set& e) {
+      // Ignore; jax may not be present.
+    }
+
+    try {
+      py::object pxla_module = py::module::import("jax.interpreters.pxla");
+      py::object sda =
+          py::getattr(pxla_module, "ShardedDeviceArray", py::none());
+      if (!sda.is_none()) {
+        (*p)[sda.ptr()] = device_array_handler;
+      }
+    } catch (const py::error_already_set& e) {
+      // Ignore; jax may not be present.
+    }
 
     ToPyArgSignatureHandler numpy_handler =
         [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
       py::array numpy_array = py::cast<py::array>(h);
-      if (IsFloat0(numpy_array)) {
-        return InvalidArgument(
-            "float0 numpy arrays not supported in C++. "
-            "Falling back to Python.");
-      }
       TF_ASSIGN_OR_RETURN(PrimitiveType dtype,
                           DtypeToPrimitiveType(numpy_array.dtype()));
       if (!jax_enable_x64) {
@@ -468,8 +471,7 @@
           /*weak_type=*/false);
     };
     const auto numpy = py::module::import("numpy");
-    const auto& ndarray = numpy.attr("ndarray");
-    (*p)[ndarray.ptr()] = numpy_handler;
+    (*p)[numpy.attr("ndarray").ptr()] = numpy_handler;
 
     ToPyArgSignatureHandler np_uint64_handler =
         [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
@@ -520,6 +522,18 @@
     return p;
   }();
 
+  // Fast path for PyBuffer, the most common case.
+  if (arg.get_type().ptr() == PyBuffer::type()) {
+    // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
+    TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(arg));
+    bool weak_type = buffer->weak_type().has_value()
+                         ? *buffer->weak_type()
+                         : py::cast<bool>(arg.attr("aval").attr("weak_type"));
+    return PyArgSignature(buffer->buffer()->on_device_shape().element_type(),
+                          buffer->buffer()->on_device_shape().dimensions(),
+                          weak_type);
+  }
+
   auto res = handlers->find(arg.get_type().ptr());
   if (res == handlers->end()) {
     // We attempt to look at the MRO classes
@@ -536,9 +550,8 @@
                      "arrays scalars of supported types "
                      "(see implementation), or Python scalars. Got type ",
                      py::cast<std::string>(py::str(arg.get_type()))));
-  } else {
-    return res->second(arg, jax_enable_x64);
   }
+  return res->second(arg, jax_enable_x64);
 }
 
 }  // namespace xla
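
A rough way to observe the dispatch-time effect outside the benchmark harness,
assuming the jit_simple_pruned_args_* benchmarks measure roughly this pattern
(the actual benchmark setup may differ):

  import timeit
  import jax

  @jax.jit
  def f(x, *unused):
    return x + 1.0  # the unused arguments are pruned from the executable

  args = [float(i) for i in range(1000)]
  f(*args).block_until_ready()  # warm up: trace and compile once
  print(timeit.timeit(lambda: f(*args).block_until_ready(), number=1000))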