[JAX] Perform casts from Python objects to PyBuffer* only once.
These casts are slow enough to show up in profiling.
I believe a significant factor is that constructing a pybind11 type caster involves a `std::unordered_map<>` lookup.
A better fix might be to avoid pybind11's casting facilities, but that is a larger change.
PiperOrigin-RevId: 363176446
Change-Id: I645af27b1239b0add438e0fa004ee0b2c3295e96
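For illustration, a minimal, hypothetical sketch (not part of this change; `Widget`, `Bump`, `ProcessArgs`, and the `cast_once_example` module are invented names) of the pattern applied here: cast a Python handle to the underlying C++ object once, cache the raw pointer, and hand that pointer to downstream helpers instead of re-casting the same handle in later loops. Each `py::cast<T*>` goes through pybind11's type-caster machinery, which is the cost being avoided.

#include <vector>
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Widget {
  int value = 0;
};

// Downstream helper takes the already-cast pointer, so it never re-enters
// pybind11's casting machinery.
void Bump(Widget* w) { ++w->value; }

void ProcessArgs(const std::vector<py::handle>& args) {
  // First pass: cast each handle exactly once and remember the result.
  std::vector<Widget*> widgets;
  widgets.reserve(args.size());
  for (py::handle arg : args) {
    widgets.push_back(py::cast<Widget*>(arg));
  }
  // Later passes reuse the cached pointers instead of casting again.
  for (Widget* w : widgets) {
    Bump(w);
  }
}

PYBIND11_MODULE(cast_once_example, m) {
  py::class_<Widget>(m, "Widget").def(py::init<>());
  m.def("process_args", [](const py::list& args) {
    std::vector<py::handle> handles;
    handles.reserve(args.size());
    for (py::handle arg : args) {
      handles.push_back(arg);
    }
    ProcessArgs(handles);
  });
}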
diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc
index d82a1d2..98b874a 100644
--- a/tensorflow/compiler/xla/python/jax_jit.cc
+++ b/tensorflow/compiler/xla/python/jax_jit.cc
@@ -604,6 +604,9 @@
arg_buffers.reserve(num_flat_dynamic_args);
arguments.signature.dynamic_args_signatures.reserve(num_flat_dynamic_args);
+ absl::InlinedVector<xla::PyBuffer*, 4> py_buffers;
+ py_buffers.resize(num_flat_dynamic_args, nullptr);
+
struct PythonTypes {
py::object device_array;
py::object py_buffer_type;
@@ -626,12 +629,14 @@
if (is_committed) {
data_device = default_device;
} else {
- for (py::handle arg : arguments.flat_dynamic_args) {
+ for (int i = 0; i < num_flat_dynamic_args; ++i) {
+ py::handle arg = arguments.flat_dynamic_args[i];
// We specifically only deal with DeviceArray (not ShardedDeviceArray).
// (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
xla::PjRtDevice* device = nullptr;
if (arg.get_type().ptr() == types.py_buffer_type.ptr()) {
xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(arg);
+ py_buffers[i] = buffer;
if (!buffer->sticky_device()) {
continue;
}
@@ -681,9 +686,10 @@
// TODO(phawkins): consider allowing forces here.
options.force_lazy_arrays = false;
options.allow_zero_copy = true;
- for (py::handle arg : arguments.flat_dynamic_args) {
+ for (int i = 0; i < num_flat_dynamic_args; ++i) {
+ py::handle arg = arguments.flat_dynamic_args[i];
TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device,
- DevicePut(arg, data_device, options));
+ DevicePut(arg, data_device, options, py_buffers[i]));
xla::PjRtBuffer* buffer = on_device.buffer;
arg_buffers.push_back(buffer);
diff --git a/tensorflow/compiler/xla/python/py_values.cc b/tensorflow/compiler/xla/python/py_values.cc
index cecac80..3850a03 100644
--- a/tensorflow/compiler/xla/python/py_values.cc
+++ b/tensorflow/compiler/xla/python/py_values.cc
@@ -152,8 +152,8 @@
}
StatusOr<DevicePutResult> PyBufferHelper(py::handle obj, py::handle py_buffer,
+ PyBuffer* buffer,
PjRtDevice* to_device) {
- PyBuffer* buffer = py::cast<PyBuffer*>(py_buffer);
bool weak_type = buffer->weak_type()
? *buffer->weak_type()
: py::cast<bool>(obj.attr("aval").attr("weak_type"));
@@ -170,7 +170,7 @@
StatusOr<DevicePutResult> HandlePyBuffer(py::handle obj, PjRtDevice* to_device,
const DevicePutOptions& options) {
- return PyBufferHelper(obj, obj, to_device);
+ return PyBufferHelper(obj, obj, py::cast<PyBuffer*>(obj), to_device);
}
StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
@@ -201,14 +201,18 @@
obj = forced;
}
- return PyBufferHelper(obj, buffer, to_device);
+ return PyBufferHelper(obj, buffer, py::cast<PyBuffer*>(buffer), to_device);
}
} // namespace
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
- const DevicePutOptions& options) {
+ const DevicePutOptions& options,
+ PyBuffer* py_buffer) {
tensorflow::profiler::TraceMe traceme("DevicePut");
+ if (py_buffer) {
+ return PyBufferHelper(arg, arg, py_buffer, to_device);
+ }
static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
[] {
auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
diff --git a/tensorflow/compiler/xla/python/py_values.h b/tensorflow/compiler/xla/python/py_values.h
index ceffe74..61d6939 100644
--- a/tensorflow/compiler/xla/python/py_values.h
+++ b/tensorflow/compiler/xla/python/py_values.h
@@ -53,6 +53,8 @@
// If `arg` is not convertible to a `PjRtBuffer` from C++, an error will be
// returned; float0s and `_DeviceArray`s with non-trivial LazyExprs are not
// supported yet.
+// If the value is known to be a PyBuffer object, py_buffer can be passed as
+// an optimization to avoid a Python->C++ cast.
//
// May throw exceptions from pybind11 in addition to failing via an error
// Status. (We could catch these if needed, but there seems little point.)
@@ -62,7 +64,8 @@
bool force_lazy_arrays = true;
};
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
- const DevicePutOptions& options);
+ const DevicePutOptions& options,
+ PyBuffer* py_buffer = nullptr);
} // namespace xla
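As a follow-up note on the header change above: giving `DevicePut` a trailing `py_buffer` parameter that defaults to `nullptr` keeps every existing call site source-compatible, while callers that have already performed the cast (as in jax_jit.cc) can opt into the fast path. A minimal, self-contained sketch of that parameter shape, with invented names (`Record`, `ExpensiveLookup`, `Resolve`) standing in for the real types:

#include <cstdio>
#include <string>

struct Record {
  std::string payload;
};

// Stands in for the expensive path (in the real change, the pybind11 cast
// plus the per-type handler lookup inside DevicePut).
Record ExpensiveLookup(const std::string& key) {
  return Record{"resolved:" + key};
}

// `cached` mirrors the new `py_buffer` parameter: optional, nullptr by default.
Record Resolve(const std::string& key, const Record* cached = nullptr) {
  if (cached != nullptr) {
    return *cached;  // Fast path: the caller already did the work.
  }
  return ExpensiveLookup(key);  // Slow path: unchanged for existing callers.
}

int main() {
  Record cached = ExpensiveLookup("arg0");
  std::printf("%s\n", Resolve("arg0").payload.c_str());           // old call site
  std::printf("%s\n", Resolve("arg0", &cached).payload.c_str());  // new fast path
  return 0;
}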