[JAX] Refactor jax_jit to avoid DevicePut on pruned args.
name old cpu/op new cpu/op delta
eager_unary_dispatch 35.7µs ± 2% 35.9µs ± 3% ~ (p=0.841 n=5+5)
eager_unary 36.4µs ± 2% 36.6µs ± 3% ~ (p=0.421 n=5+5)
eager_binary_dispatch 45.6µs ± 1% 46.1µs ± 2% ~ (p=0.421 n=5+5)
eager_binary 46.6µs ± 2% 47.0µs ± 5% ~ (p=1.000 n=5+5)
jit_trivial_dispatch 41.4µs ± 1% 41.4µs ± 0% ~ (p=0.690 n=5+5)
jit_trivial 42.4µs ± 1% 42.3µs ± 1% ~ (p=0.841 n=5+5)
jit_simple_dispatch 8.85µs ± 3% 9.15µs ± 3% ~ (p=0.095 n=5+5)
jit_simple 9.77µs ± 1% 9.82µs ± 2% ~ (p=0.548 n=5+5)
jit_simple_many_args_dispatch_10 13.4µs ± 1% 13.6µs ± 3% ~ (p=0.222 n=5+5)
jit_simple_many_args_10 14.0µs ± 2% 14.1µs ± 1% ~ (p=0.421 n=5+5)
jit_simple_pruned_args_dispatch_10 8.05µs ± 3% 8.07µs ± 4% ~ (p=0.841 n=5+5)
jit_simple_pruned_args_10 9.53µs ± 2% 9.43µs ± 2% ~ (p=0.222 n=5+5)
jit_simple_many_args_dispatch_100 55.2µs ± 1% 54.8µs ± 2% ~ (p=0.310 n=5+5)
jit_simple_many_args_100 55.8µs ± 1% 55.8µs ± 1% ~ (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_100 14.3µs ± 4% 12.6µs ± 1% -11.41% (p=0.016 n=5+4)
jit_simple_pruned_args_100 14.8µs ± 1% 13.3µs ± 2% -10.06% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000 489µs ± 1% 477µs ± 3% ~ (p=0.056 n=5+5)
jit_simple_many_args_1000 495µs ± 3% 493µs ± 3% ~ (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_1000 85.0µs ± 3% 65.3µs ± 3% -23.13% (p=0.008 n=5+5)
jit_simple_pruned_args_1000 86.0µs ± 3% 66.4µs ± 3% -22.78% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000 1.09ms ± 4% 1.03ms ± 3% -5.97% (p=0.016 n=5+5)
jit_simple_many_args_2000 1.07ms ± 3% 1.04ms ± 5% ~ (p=0.095 n=5+5)
jit_simple_pruned_args_dispatch_2000 190µs ± 3% 144µs ± 3% -23.96% (p=0.008 n=5+5)
jit_simple_pruned_args_2000 195µs ± 4% 147µs ± 3% -24.29% (p=0.008 n=5+5)
jit_dispatch_without_transfer 76.0µs ± 1% 77.2µs ± 6% ~ (p=0.310 n=5+5)
jit_dispatch_with_transfer 82.1µs ± 5% 81.3µs ± 2% ~ (p=0.421 n=5+5)
sda_index_1 8.83µs ± 1% 8.73µs ± 2% ~ (p=0.222 n=5+5)
name old time/op new time/op delta
eager_unary_dispatch 35.7µs ± 2% 35.9µs ± 3% ~ (p=0.841 n=5+5)
eager_unary 36.5µs ± 2% 37.1µs ± 4% ~ (p=0.222 n=5+5)
eager_binary_dispatch 45.6µs ± 1% 46.1µs ± 2% ~ (p=0.421 n=5+5)
eager_binary 46.8µs ± 3% 47.1µs ± 5% ~ (p=1.000 n=5+5)
jit_trivial_dispatch 41.4µs ± 1% 41.4µs ± 0% ~ (p=0.690 n=5+5)
jit_trivial 42.4µs ± 1% 42.3µs ± 1% ~ (p=0.841 n=5+5)
jit_simple_dispatch 8.86µs ± 3% 9.15µs ± 3% ~ (p=0.095 n=5+5)
jit_simple 9.82µs ± 1% 9.91µs ± 0% ~ (p=0.190 n=5+4)
jit_simple_many_args_dispatch_10 13.4µs ± 1% 13.6µs ± 4% ~ (p=0.310 n=5+5)
jit_simple_many_args_10 14.1µs ± 2% 14.2µs ± 1% ~ (p=0.421 n=5+5)
jit_simple_pruned_args_dispatch_10 8.07µs ± 4% 8.07µs ± 4% ~ (p=0.841 n=5+5)
jit_simple_pruned_args_10 9.59µs ± 2% 9.48µs ± 2% ~ (p=0.222 n=5+5)
jit_simple_many_args_dispatch_100 55.2µs ± 1% 54.8µs ± 2% ~ (p=0.310 n=5+5)
jit_simple_many_args_100 55.9µs ± 1% 55.9µs ± 1% ~ (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_100 14.3µs ± 5% 12.6µs ± 1% -11.75% (p=0.016 n=5+4)
jit_simple_pruned_args_100 14.8µs ± 2% 13.3µs ± 2% -10.19% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000 489µs ± 1% 477µs ± 3% ~ (p=0.056 n=5+5)
jit_simple_many_args_1000 495µs ± 3% 493µs ± 3% ~ (p=0.841 n=5+5)
jit_simple_pruned_args_dispatch_1000 85.0µs ± 3% 65.3µs ± 3% -23.13% (p=0.008 n=5+5)
jit_simple_pruned_args_1000 86.1µs ± 3% 66.5µs ± 2% -22.72% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000 1.09ms ± 4% 1.03ms ± 3% -5.96% (p=0.016 n=5+5)
jit_simple_many_args_2000 1.07ms ± 3% 1.04ms ± 5% ~ (p=0.095 n=5+5)
jit_simple_pruned_args_dispatch_2000 190µs ± 3% 144µs ± 3% -23.97% (p=0.008 n=5+5)
jit_simple_pruned_args_2000 195µs ± 4% 147µs ± 3% -24.31% (p=0.008 n=5+5)
jit_dispatch_without_transfer 1.41ms ± 1% 1.40ms ± 1% ~ (p=0.095 n=5+5)
jit_dispatch_with_transfer 1.40ms ± 2% 1.40ms ± 2% ~ (p=0.841 n=5+5)
sda_index_1 8.83µs ± 1% 8.73µs ± 2% ~ (p=0.222 n=5+5)
PiperOrigin-RevId: 374468578
Change-Id: I0a45af35b936a72f8271bd3e3a66e0d778619132
diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc
index f62b5d0..f9b3fbd 100644
--- a/tensorflow/compiler/xla/python/jax_jit.cc
+++ b/tensorflow/compiler/xla/python/jax_jit.cc
@@ -540,18 +540,14 @@
CompiledFunction::~CompiledFunction() = default;
-// Converts flattened arguments contained in ParsedArgumentsAsBuffers in
-// place. If arguments are `DeviceArray`, they must all be on the same `Device`.
+// Computes the signature of the flattened dynamic arguments.
//
// Returns `Status::OK()` on success. Returning an error should lead to
// calling the Python fallback.
-xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
- xla::PjRtDevice* default_device,
- bool is_committed,
- ParsedArgumentsAsBuffers& arguments) {
- tensorflow::profiler::TraceMe traceme("ConvertArgsToBuffers");
- std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
- auto& keep_alive = arguments.keep_alive;
+xla::Status ComputeSignature(bool jax_enable_x64, xla::PyClient& pyclient,
+ xla::PjRtDevice* default_device, bool is_committed,
+ ParsedArgumentsAsBuffers& arguments) {
+ tensorflow::profiler::TraceMe traceme("ComputeSignature");
int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
struct PythonTypes {
@@ -624,14 +620,38 @@
CHECK(data_device);
arguments.signature.device = data_device;
+ arguments.signature.dynamic_arg_signatures.reserve(num_flat_dynamic_args);
+ for (int i = 0; i < num_flat_dynamic_args; ++i) {
+ py::handle arg = arguments.flat_dynamic_args[i];
+ TF_ASSIGN_OR_RETURN(auto sig,
+ xla::PyArgSignatureOfValue(arg, jax_enable_x64));
+ arguments.signature.dynamic_arg_signatures.push_back(std::move(sig));
+ }
+ return xla::Status::OK();
+}
+
+// Copies buffers to the target device, skipping pruned arguments.
+// Returns `Status::OK()` on success. Returning an error should lead to
+// calling the Python fallback.
+xla::Status CopyBuffersToDevice(
+ bool jax_enable_x64, const absl::optional<std::vector<bool>>& kept_args,
+ ParsedArgumentsAsBuffers& arguments) {
+ std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
+ xla::PjRtDevice* data_device = arguments.signature.device;
+
+ int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
xla::DevicePutOptions options;
options.squash_64bit_types = !jax_enable_x64;
// TODO(phawkins): consider allowing forces here.
options.force_lazy_arrays = false;
options.allow_zero_copy = true;
arg_buffers.reserve(num_flat_dynamic_args);
- arguments.signature.dynamic_arg_signatures.reserve(num_flat_dynamic_args);
+ bool input_pruning_enabled = kept_args.has_value();
for (int i = 0; i < num_flat_dynamic_args; ++i) {
+ if (input_pruning_enabled && !kept_args.value()[i]) {
+ continue;
+ }
+
py::handle arg = arguments.flat_dynamic_args[i];
TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device,
DevicePut(arg, data_device, options));
@@ -639,16 +659,11 @@
xla::PjRtBuffer* buffer = on_device.buffer;
arg_buffers.push_back(buffer);
if (on_device.owned_buffer) {
- keep_alive.push_back(std::move(on_device.owned_buffer));
+ arguments.keep_alive.push_back(std::move(on_device.owned_buffer));
} else if (on_device.owning_pybuffer) {
arguments.keep_alive_objects.push_back(
std::move(on_device.owning_pybuffer));
}
-
- xla::PyArgSignature sig(buffer->on_device_shape().element_type(),
- buffer->on_device_shape().dimensions(),
- on_device.weak_type);
- arguments.signature.dynamic_arg_signatures.push_back(std::move(sig));
}
return xla::Status::OK();
}
@@ -784,8 +799,8 @@
arguments.signature.jax_enable_x64 = jax_enable_x64;
// The C++ jit do not support Tracers arguments inputs yet. The Python-based
// jit function will be called if any of the dynamic arguments is unsupported.
- if (!ConvertArgsToBuffers(jax_enable_x64, *default_pyclient_, default_device_,
- is_committed_, arguments)
+ if (!ComputeSignature(jax_enable_x64, *default_pyclient_, default_device_,
+ is_committed_, arguments)
.ok()) {
return py::object(
py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
@@ -842,35 +857,22 @@
**kwargs.value_or(py::kwargs())))[0]);
}
+ if (!CopyBuffersToDevice(jax_enable_x64, cache_entry->kept_var_bitvec,
+ arguments)
+ .ok()) {
+ return py::object(
+ py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
+ **kwargs.value_or(py::kwargs())))[0]);
+ }
+
// Executes the computation.
std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> output_buffers;
{
py::gil_scoped_release gil_release;
- // TODO(zhangqiaorjc): Refactor ConvertArgsToBuffers. Split out the part
- // that computes parts of the signature and tests for incompatible devices,
- // and move it either into ParseArguments or a new function. Move the part
- // that copies buffers around to here, and we can fuse this "argument
- // dropping" logic with that code
- if (cache_entry->kept_var_bitvec.has_value()) {
- // Input pruning enabled.
- std::vector<xla::PjRtBuffer*> kept_args;
- kept_args.reserve(arguments.arg_buffers.size());
- for (int i = 0; i < arguments.arg_buffers.size(); ++i) {
- if (cache_entry->kept_var_bitvec.value()[i]) {
- kept_args.push_back(arguments.arg_buffers[i]);
- }
- }
- TF_ASSIGN_OR_RETURN(
- output_buffers,
- cache_entry->executable->mutable_pjrt_executable()->Execute(
- {kept_args}, cache_entry->executable->options()));
- } else {
- // Input pruning not enabled.
- TF_ASSIGN_OR_RETURN(
- output_buffers,
- cache_entry->executable->mutable_pjrt_executable()->Execute(
- {arguments.arg_buffers}, cache_entry->executable->options()));
- }
+ TF_ASSIGN_OR_RETURN(
+ output_buffers,
+ cache_entry->executable->mutable_pjrt_executable()->Execute(
+ {arguments.arg_buffers}, cache_entry->executable->options()));
}
auto traceback = xla::Traceback::Get();
diff --git a/tensorflow/compiler/xla/python/jax_jit.h b/tensorflow/compiler/xla/python/jax_jit.h
index fe010cb..fe195a3 100644
--- a/tensorflow/compiler/xla/python/jax_jit.h
+++ b/tensorflow/compiler/xla/python/jax_jit.h
@@ -95,7 +95,7 @@
std::vector<xla::PjRtBuffer*> arg_buffers;
// We may need to keep these objects around, because:
// (a) we need to extend the lifetime of objects created within
- // `ConvertArgsToBuffers`
+ // `CopyBuffersToDevice`
// (b) `arg_buffers` do not maintain ownership
std::vector<std::unique_ptr<xla::PjRtBuffer>> keep_alive;
};
diff --git a/tensorflow/compiler/xla/python/py_values.cc b/tensorflow/compiler/xla/python/py_values.cc
index 53b293f..65eeed8 100644
--- a/tensorflow/compiler/xla/python/py_values.cc
+++ b/tensorflow/compiler/xla/python/py_values.cc
@@ -267,8 +267,6 @@
// Generic subclasses of DeviceArray, e.g., ShardedDeviceArray.
(*p)[PyBuffer::base_type()] = HandleDeviceArray;
- // The C++ PyBuffer class is handled specially.
- (*p)[PyBuffer::type()] = HandlePyBuffer;
try {
py::object xla_module = py::module::import("jax.interpreters.xla");
@@ -323,6 +321,11 @@
return p;
}();
+ // Fast path for the most common case: the argument is a C++ PyBuffer.
+ if (arg.get_type().ptr() == PyBuffer::type()) {
+ return HandlePyBuffer(arg, to_device, options);
+ }
+
auto res = handlers->find(arg.get_type().ptr());
if (res == handlers->end()) {
for (auto base_class : arg.get_type().attr("mro")()) {
@@ -337,9 +340,8 @@
"DeviceArray, Numpy arrays scalars of supported types "
"(see implementation), or Python scalars. Got type ",
py::cast<std::string>(py::str(arg.get_type()))));
- } else {
- return res->second(arg, to_device, options);
}
+ return res->second(arg, to_device, options);
}
bool IsFloat0(py::array arg) {
@@ -369,9 +371,6 @@
PyObject*, ToPyArgSignatureHandler>* const handlers = [] {
auto p = new absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>();
- const auto xla_module = py::module::import("jax.interpreters.xla");
- const auto& device_array = xla_module.attr("_DeviceArray");
-
const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
// The 4 Python native types.
@@ -418,19 +417,7 @@
(*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = float_handler;
(*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] = complex_handler;
- // The Buffer types
- // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
- ToPyArgSignatureHandler buffer_handler =
- [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
- TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(h));
- bool weak_type = buffer->weak_type().has_value()
- ? *buffer->weak_type()
- : py::cast<bool>(h.attr("aval").attr("weak_type"));
- return PyArgSignature(buffer->buffer()->on_device_shape().element_type(),
- buffer->buffer()->on_device_shape().dimensions(),
- weak_type);
- };
- (*p)[PyBuffer::base_type()] = buffer_handler;
+ // The Buffer types; the exact C++ PyBuffer type is handled by a fast path before the handler lookup.
ToPyArgSignatureHandler device_array_handler =
[](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
py::handle aval = h.attr("aval");
@@ -439,17 +426,33 @@
py::cast<std::vector<int64>>(aval.attr("shape")),
py::cast<py::bool_>(aval.attr("weak_type")));
};
- // ShardedDeviceArray is covered by the MRO fallback on _DeviceArray.
- (*p)[device_array.ptr()] = device_array_handler;
+ (*p)[PyBuffer::base_type()] = device_array_handler;
+
+ try {
+ py::object xla_module = py::module::import("jax.interpreters.xla");
+ py::object device_array =
+ py::getattr(xla_module, "_DeviceArray", py::none());
+ if (!device_array.is_none()) {
+ (*p)[device_array.ptr()] = device_array_handler;
+ }
+ } catch (const py::error_already_set& e) {
+ // Ignore; jax may not be present.
+ }
+
+ try {
+ py::object pxla_module = py::module::import("jax.interpreters.pxla");
+ py::object sda =
+ py::getattr(pxla_module, "ShardedDeviceArray", py::none());
+ if (!sda.is_none()) {
+ (*p)[sda.ptr()] = device_array_handler;
+ }
+ } catch (const py::error_already_set& e) {
+ // Ignore; jax may not be present.
+ }
ToPyArgSignatureHandler numpy_handler =
[](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
py::array numpy_array = py::cast<py::array>(h);
- if (IsFloat0(numpy_array)) {
- return InvalidArgument(
- "float0 numpy arrays not supported in C++. "
- "Falling back to Python.");
- }
TF_ASSIGN_OR_RETURN(PrimitiveType dtype,
DtypeToPrimitiveType(numpy_array.dtype()));
if (!jax_enable_x64) {
@@ -468,8 +471,7 @@
/*weak_type=*/false);
};
const auto numpy = py::module::import("numpy");
- const auto& ndarray = numpy.attr("ndarray");
- (*p)[ndarray.ptr()] = numpy_handler;
+ (*p)[numpy.attr("ndarray").ptr()] = numpy_handler;
ToPyArgSignatureHandler np_uint64_handler =
[](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
@@ -520,6 +522,18 @@
return p;
}();
+ // Fast path for the most common case: the argument is a C++ PyBuffer.
+ if (arg.get_type().ptr() == PyBuffer::type()) {
+ // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
+ TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(arg));
+ bool weak_type = buffer->weak_type().has_value()
+ ? *buffer->weak_type()
+ : py::cast<bool>(arg.attr("aval").attr("weak_type"));
+ return PyArgSignature(buffer->buffer()->on_device_shape().element_type(),
+ buffer->buffer()->on_device_shape().dimensions(),
+ weak_type);
+ }
+
auto res = handlers->find(arg.get_type().ptr());
if (res == handlers->end()) {
// We attempt to look at the MRO classes
@@ -536,9 +550,8 @@
"arrays scalars of supported types "
"(see implementation), or Python scalars. Got type ",
py::cast<std::string>(py::str(arg.get_type()))));
- } else {
- return res->second(arg, jax_enable_x64);
}
+ return res->second(arg, jax_enable_x64);
}
} // namespace xla