Rename xla::PjRtExecutable to xla::PjRtLoadedExecutable
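
Within the files below this is a mechanical rename: the class definition and
forward declaration in pjrt_client.h, the stream-executor, TFRT CPU, TPU, and
C API implementations, and the Python bindings all switch from
xla::PjRtExecutable to xla::PjRtLoadedExecutable, together with every call
site that spells out the type. No behavior is intended to change.

Callers only need the one-word type update. A minimal sketch of the
caller-side change (illustrative only; it assumes an already-constructed
client, computation, and compile_options, matching the updated call sites in
gpu_multistream_test.cc and outfeed_receiver.cc below):

  // Before:
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
  //                       client->Compile(computation, std::move(compile_options)));
  // After (same call, renamed result type):
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
                      client->Compile(computation, std::move(compile_options)));
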
PiperOrigin-RevId: 463460929
diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h
index a8f2e80..f531866 100644
--- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h
+++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h
@@ -49,7 +49,7 @@
};
struct PJRT_Executable {
- std::unique_ptr<xla::PjRtExecutable> executable;
+ std::unique_ptr<xla::PjRtLoadedExecutable> executable;
PJRT_Client* client;
// These pointers are a subset of `client`'s `addressable_devices`, i.e. those
// addressed by the compiled executable program. `client` owns the objects
diff --git a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc
index ca26344..d72e2ae 100644
--- a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc
+++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc
@@ -56,7 +56,7 @@
compile_options.executable_build_options.set_device_assignment(
device_assignment);
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<PjRtExecutable> executable,
+ std::unique_ptr<PjRtLoadedExecutable> executable,
client->Compile(computation, std::move(compile_options)));
int64_t dummy_size = 1 << 20;
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc
index ba1d837..98b0eea 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc
@@ -152,7 +152,7 @@
}
StatusOr<std::optional<std::string>> PjRtCApiClient::ExecutableFingerprint(
- const PjRtExecutable& executable) const {
+ const PjRtLoadedExecutable& executable) const {
return wrapped_->ExecutableFingerprint(
*PjRtCApiExecutable::GetWrapped(&executable));
}
@@ -168,13 +168,14 @@
}
StatusOr<std::string> PjRtCApiClient::SerializeExecutable(
- const PjRtExecutable& executable) const {
+ const PjRtLoadedExecutable& executable) const {
return wrapped_->SerializeExecutable(
*PjRtCApiExecutable::GetWrapped(&executable));
}
-StatusOr<std::unique_ptr<PjRtExecutable>> PjRtCApiClient::DeserializeExecutable(
- absl::string_view serialized, CompileOptions options) {
+StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
+PjRtCApiClient::DeserializeExecutable(absl::string_view serialized,
+ CompileOptions options) {
return WrapExecutable(wrapped_->DeserializeExecutable(serialized, options));
}
@@ -183,11 +184,11 @@
return wrapped_->UnsafeBufferPointer(PjRtCApiBuffer::GetWrapped(buffer));
}
-StatusOr<std::unique_ptr<PjRtExecutable>> PjRtCApiClient::WrapExecutable(
- StatusOr<std::unique_ptr<PjRtExecutable>> to_wrap) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
+StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::WrapExecutable(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> to_wrap) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
std::move(to_wrap));
- return std::unique_ptr<PjRtExecutable>(
+ return std::unique_ptr<PjRtLoadedExecutable>(
std::make_unique<PjRtCApiExecutable>(this, std::move(executable)));
}
@@ -304,8 +305,8 @@
// ------------------------------- Executables ---------------------------------
-PjRtCApiExecutable::PjRtCApiExecutable(PjRtCApiClient* client,
- std::unique_ptr<PjRtExecutable> wrapped)
+PjRtCApiExecutable::PjRtCApiExecutable(
+ PjRtCApiClient* client, std::unique_ptr<PjRtLoadedExecutable> wrapped)
: client_(client),
executable_(
new PJRT_Executable{std::move(wrapped), client->pjrt_c_client()}) {
@@ -483,7 +484,7 @@
return out;
}
-PjRtExecutable* PjRtCApiExecutable::wrapped() const {
+PjRtLoadedExecutable* PjRtCApiExecutable::wrapped() const {
return executable_->executable.get();
}
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h
index 8c53ae6..6afa024 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h
@@ -134,23 +134,23 @@
return wrapped_->GetHloCostAnalysis();
}
- StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) override {
return WrapExecutable(wrapped_->Compile(computation, options));
}
- StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
mlir::ModuleOp module, CompileOptions options) override {
return WrapExecutable(wrapped_->Compile(module, options));
}
StatusOr<std::optional<std::string>> ExecutableFingerprint(
- const PjRtExecutable& executable) const override;
+ const PjRtLoadedExecutable& executable) const override;
StatusOr<std::string> SerializeExecutable(
- const PjRtExecutable& executable) const override;
+ const PjRtLoadedExecutable& executable) const override;
- StatusOr<std::unique_ptr<PjRtExecutable>> DeserializeExecutable(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
absl::string_view serialized, CompileOptions options) override;
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
@@ -231,8 +231,8 @@
return it->second;
}
- StatusOr<std::unique_ptr<PjRtExecutable>> WrapExecutable(
- StatusOr<std::unique_ptr<PjRtExecutable>> to_wrap);
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> WrapExecutable(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> to_wrap);
StatusOr<std::unique_ptr<PjRtBuffer>> WrapBuffer(
StatusOr<std::unique_ptr<PjRtBuffer>> to_wrap);
@@ -371,10 +371,10 @@
void set_shape();
};
-class PjRtCApiExecutable : public PjRtExecutable {
+class PjRtCApiExecutable : public PjRtLoadedExecutable {
public:
PjRtCApiExecutable(PjRtCApiClient* client,
- std::unique_ptr<PjRtExecutable> wrapped);
+ std::unique_ptr<PjRtLoadedExecutable> wrapped);
~PjRtCApiExecutable() override;
@@ -426,9 +426,10 @@
void Delete() override;
bool IsDeleted() override;
- PjRtExecutable* wrapped() const;
+ PjRtLoadedExecutable* wrapped() const;
- static PjRtExecutable* GetWrapped(const PjRtExecutable* c_api_executable) {
+ static PjRtLoadedExecutable* GetWrapped(
+ const PjRtLoadedExecutable* c_api_executable) {
return tensorflow::down_cast<const PjRtCApiExecutable*>(c_api_executable)
->wrapped();
}
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h
index 530c141..1602955 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h
@@ -371,7 +371,7 @@
size_t dst_size, const Shape& dst_shape) = 0;
};
-class PjRtExecutable;
+class PjRtLoadedExecutable;
// Encapsulates the state of Python session with XLA.
//
@@ -493,27 +493,27 @@
virtual StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() = 0;
// Compile `computation` with given `options`.
- virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+ virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) = 0;
// Variant of `Compile` that accepts an MLIR module.
- virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+ virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
mlir::ModuleOp module, CompileOptions options) = 0;
// Generates a unique fingerprint for `executable`, may be std::nullopt.
virtual StatusOr<std::optional<std::string>> ExecutableFingerprint(
- const PjRtExecutable& executable) const = 0;
+ const PjRtLoadedExecutable& executable) const = 0;
// Returns a platform-specific serialization of `executable`. The
// serialization is not guaranteed to be stable over time. `executable` must
// have been produced by this client.
virtual StatusOr<std::string> SerializeExecutable(
- const PjRtExecutable& executable) const = 0;
+ const PjRtLoadedExecutable& executable) const = 0;
// Deserializes a serialized executable as produced by
// SerializeExecutable(). `serialized` must have been produced by a client of
// the same platform and version as this one.
- virtual StatusOr<std::unique_ptr<PjRtExecutable>> DeserializeExecutable(
+ virtual StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
absl::string_view serialized, CompileOptions options) = 0;
// Creates a buffer on the device without initializing or copying any data.
@@ -1081,9 +1081,9 @@
// device-allocated literals. If any input/output alias has been specified in
// the computation, the parameter containing the input buffer will be donated
// when passed to the execution.
-class PjRtExecutable {
+class PjRtLoadedExecutable {
public:
- virtual ~PjRtExecutable() = default;
+ virtual ~PjRtLoadedExecutable() = default;
virtual PjRtClient* client() const = 0;
@@ -1139,8 +1139,8 @@
// else:
// *returned_futures is undefined.
//
- // The caller is *NOT* required to ensure that PjRtExecutable stays alive
- // until futures are ready.
+ // The caller is *NOT* required to ensure that PjRtLoadedExecutable stays
+ // alive until futures are ready.
virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options,
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
index 77976d0..b87f97c 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
@@ -2087,7 +2087,8 @@
return outputs;
}
-StatusOr<PjRtExecutable::Result> PjRtStreamExecutorExecutable::ExecuteHelper(
+StatusOr<PjRtLoadedExecutable::Result>
+PjRtStreamExecutorExecutable::ExecuteHelper(
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
const RunId& run_id, const ExecuteOptions& options, bool fill_future,
PjRtDevice* device) const {
@@ -2422,7 +2423,7 @@
VLOG(3) << "Non-local device: " << device_id;
continue;
}
- PjRtExecutable::LogicalDeviceIds logica_device_ids;
+ PjRtLoadedExecutable::LogicalDeviceIds logica_device_ids;
logica_device_ids.replica = replica;
logica_device_ids.partition = partition;
addressable_device_logical_ids.push_back(std::move(logica_device_ids));
@@ -2443,8 +2444,9 @@
return extras;
}
-StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
- const XlaComputation& computation, CompileOptions options) {
+StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
+PjRtStreamExecutorClient::Compile(const XlaComputation& computation,
+ CompileOptions options) {
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
VLOG(1) << "PjRtStreamExecutorClient::Compile";
@@ -2477,11 +2479,12 @@
std::move(addressable_devices), this);
TF_RETURN_IF_ERROR(
executable->SetUpDonation(options.parameter_is_tupled_arguments));
- return std::unique_ptr<PjRtExecutable>(std::move(executable));
+ return std::unique_ptr<PjRtLoadedExecutable>(std::move(executable));
}
-StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
- mlir::ModuleOp module, CompileOptions options) {
+StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
+PjRtStreamExecutorClient::Compile(mlir::ModuleOp module,
+ CompileOptions options) {
XlaComputation xla_computation;
TF_RETURN_IF_ERROR(MlirToXlaComputation(
module, xla_computation,
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
index 36e0b31..fbf6e02 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
@@ -184,23 +184,23 @@
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
- StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) override;
- StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
mlir::ModuleOp mlir_module, CompileOptions options) override;
StatusOr<std::optional<std::string>> ExecutableFingerprint(
- const PjRtExecutable& executable) const override {
+ const PjRtLoadedExecutable& executable) const override {
return std::optional<std::string>();
}
StatusOr<std::string> SerializeExecutable(
- const PjRtExecutable& executable) const override {
+ const PjRtLoadedExecutable& executable) const override {
return Unimplemented("SerializeExecutable not implemented on %s",
platform_name());
}
- StatusOr<std::unique_ptr<PjRtExecutable>> DeserializeExecutable(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
absl::string_view serialized, CompileOptions options) override {
return Unimplemented("DeserializeExecutable not implemented on %s",
platform_name());
@@ -324,7 +324,7 @@
// `options` in-place.
struct ExecutableExtras {
std::shared_ptr<DeviceAssignment> device_assignment;
- std::vector<PjRtExecutable::LogicalDeviceIds>
+ std::vector<PjRtLoadedExecutable::LogicalDeviceIds>
addressable_device_logical_ids;
std::vector<PjRtDevice*> addressable_devices;
};
@@ -701,7 +701,7 @@
// Wraps one or more XLA LocalExecutables (one per partition, as specified by
// the build options).
-class PjRtStreamExecutorExecutable : public PjRtExecutable {
+class PjRtStreamExecutorExecutable : public PjRtLoadedExecutable {
public:
PjRtStreamExecutorExecutable(
std::vector<std::unique_ptr<LocalExecutable>> executables,
@@ -750,21 +750,21 @@
StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
const override;
- using PjRtExecutable::Execute;
+ using PjRtLoadedExecutable::Execute;
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options,
std::optional<std::vector<PjRtFuture<Status>>>& returned_futures)
override;
- using PjRtExecutable::ExecuteSharded;
+ using PjRtLoadedExecutable::ExecuteSharded;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options,
std::optional<PjRtFuture<Status>>& returned_future,
bool fill_future) override;
- using PjRtExecutable::ExecutePortable;
+ using PjRtLoadedExecutable::ExecutePortable;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options,
diff --git a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc
index 4bf2db7..5f33d92 100644
--- a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc
+++ b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc
@@ -213,7 +213,7 @@
}
StatusOr<std::optional<std::string>> TfrtCpuClient::ExecutableFingerprint(
- const PjRtExecutable& executable) const {
+ const PjRtLoadedExecutable& executable) const {
return std::optional<std::string>();
}
@@ -303,7 +303,7 @@
return {std::move(buffer_indices)};
}
-StatusOr<std::unique_ptr<PjRtExecutable>> TfrtCpuClient::Compile(
+StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
const XlaComputation& computation, CompileOptions options) {
tensorflow::profiler::TraceMe traceme("TfrtCpuClient::Compile");
ExecutableBuildOptions& build_options = options.executable_build_options;
@@ -323,7 +323,8 @@
computation, &LayoutUtil::GetWithDefaultLayout, options.argument_layouts,
&options.executable_build_options, &argument_layout_pointers));
- std::vector<PjRtExecutable::LogicalDeviceIds> addressable_device_logical_ids;
+ std::vector<PjRtLoadedExecutable::LogicalDeviceIds>
+ addressable_device_logical_ids;
std::vector<PjRtDevice*> addressable_devices;
if (device_assignment != nullptr) {
addressable_device_logical_ids.reserve(num_replicas * num_partitions);
@@ -336,7 +337,7 @@
VLOG(3) << "Non-local device: " << device_id;
continue;
}
- PjRtExecutable::LogicalDeviceIds logica_device_ids;
+ PjRtLoadedExecutable::LogicalDeviceIds logica_device_ids;
logica_device_ids.replica = replica;
logica_device_ids.partition = partition;
addressable_device_logical_ids.push_back(std::move(logica_device_ids));
@@ -388,10 +389,10 @@
TF_RETURN_IF_ERROR(
executable->SetUpDonation(options.parameter_is_tupled_arguments));
- return std::unique_ptr<PjRtExecutable>(std::move(executable));
+ return std::unique_ptr<PjRtLoadedExecutable>(std::move(executable));
}
-StatusOr<std::unique_ptr<PjRtExecutable>> TfrtCpuClient::Compile(
+StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
mlir::ModuleOp module, CompileOptions options) {
XlaComputation xla_computation;
TF_RETURN_IF_ERROR(MlirToXlaComputation(
@@ -1453,7 +1454,7 @@
return OkStatus();
}
-StatusOr<PjRtExecutable::Result> TfrtCpuExecutable::ExecuteHelper(
+StatusOr<PjRtLoadedExecutable::Result> TfrtCpuExecutable::ExecuteHelper(
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
const RunId& run_id, const ExecuteOptions& options,
tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event,
diff --git a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h
index 58efa44..d0bd692 100644
--- a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h
+++ b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h
@@ -149,21 +149,21 @@
StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override;
- StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) override;
- StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
mlir::ModuleOp module, CompileOptions options) override;
StatusOr<std::optional<std::string>> ExecutableFingerprint(
- const PjRtExecutable& executable) const override;
+ const PjRtLoadedExecutable& executable) const override;
StatusOr<std::string> SerializeExecutable(
- const PjRtExecutable& executable) const override {
+ const PjRtLoadedExecutable& executable) const override {
return Unimplemented("SerializeExecutable not implemented on %s",
platform_name());
}
- StatusOr<std::unique_ptr<PjRtExecutable>> DeserializeExecutable(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
absl::string_view serialized, CompileOptions options) override {
return Unimplemented("DeserializeExecutable not implemented on %s",
platform_name());
@@ -540,7 +540,7 @@
tfrt::AsyncValueRef<Status> definition_event_ ABSL_GUARDED_BY(mu_);
};
-class TfrtCpuExecutable final : public PjRtExecutable {
+class TfrtCpuExecutable final : public PjRtLoadedExecutable {
public:
TfrtCpuExecutable(
int num_replicas, int num_partitions,
@@ -587,21 +587,21 @@
cpu_executable_->shared_module()};
}
- using PjRtExecutable::Execute;
+ using PjRtLoadedExecutable::Execute;
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options,
std::optional<std::vector<PjRtFuture<Status>>>& returned_futures)
override;
- using PjRtExecutable::ExecuteSharded;
+ using PjRtLoadedExecutable::ExecuteSharded;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options,
std::optional<PjRtFuture<Status>>& returned_future,
bool fill_future) override;
- using PjRtExecutable::ExecutePortable;
+ using PjRtLoadedExecutable::ExecutePortable;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options,
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc
index 0699e48..3aa31ff 100644
--- a/tensorflow/compiler/xla/pjrt/tpu_client.cc
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc
@@ -133,7 +133,7 @@
}
StatusOr<std::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
- const PjRtExecutable& executable) const {
+ const PjRtLoadedExecutable& executable) const {
if (executable.client() != this) {
return InvalidArgument(
"Passed executable from different client (platform '%s') to "
@@ -154,7 +154,7 @@
}
StatusOr<std::string> PjRtTpuClient::SerializeExecutable(
- const PjRtExecutable& executable) const {
+ const PjRtLoadedExecutable& executable) const {
const PjRtStreamExecutorExecutable* se_executable =
tensorflow::down_cast<const PjRtStreamExecutorExecutable*>(&executable);
if (se_executable->executables().size() > 1) {
@@ -168,8 +168,9 @@
return tpu_executable->Serialize();
}
-StatusOr<std::unique_ptr<PjRtExecutable>> PjRtTpuClient::DeserializeExecutable(
- absl::string_view serialized, CompileOptions options) {
+StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
+PjRtTpuClient::DeserializeExecutable(absl::string_view serialized,
+ CompileOptions options) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuExecutable> tpu_executable,
TpuExecutable::Deserialize(serialized));
@@ -203,7 +204,7 @@
std::move(extras.addressable_devices), this);
TF_RETURN_IF_ERROR(
pjrt_executable->SetUpDonation(options.parameter_is_tupled_arguments));
- return std::unique_ptr<PjRtExecutable>(std::move(pjrt_executable));
+ return std::unique_ptr<PjRtLoadedExecutable>(std::move(pjrt_executable));
}
static StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>>
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h
index 1cfe9d3..b3b01b8 100644
--- a/tensorflow/compiler/xla/pjrt/tpu_client.h
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.h
@@ -84,12 +84,12 @@
bool EnqueueD2DTransfersOnSrcStream() const override { return false; }
StatusOr<std::optional<std::string>> ExecutableFingerprint(
- const PjRtExecutable& executable) const override;
+ const PjRtLoadedExecutable& executable) const override;
StatusOr<std::string> SerializeExecutable(
- const PjRtExecutable& executable) const override;
+ const PjRtLoadedExecutable& executable) const override;
- StatusOr<std::unique_ptr<PjRtExecutable>> DeserializeExecutable(
+ StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
absl::string_view serialized, CompileOptions options) override;
private:
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc
index 82d1118..44e354d 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc
@@ -411,7 +411,7 @@
compile_options.executable_build_options.set_device_assignment(
device_assignment);
- TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
devices_[device_idx]->client()->Compile(
computation, std::move(compile_options)));
ExecuteOptions execute_options;
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
index 2fde029..6084270 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
@@ -43,7 +43,7 @@
compile_options.executable_build_options.set_device_assignment(
device_assignment);
- TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
client->Compile(computation, std::move(compile_options)));
ExecuteOptions execute_options;
TF_ASSIGN_OR_RETURN(
diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc
index 2a0d77d..9f54bca 100644
--- a/tensorflow/compiler/xla/python/py_client.cc
+++ b/tensorflow/compiler/xla/python/py_client.cc
@@ -318,7 +318,7 @@
StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
const XlaComputation& computation, CompileOptions options,
std::vector<pybind11::capsule> host_callbacks) {
- std::unique_ptr<PjRtExecutable> executable;
+ std::unique_ptr<PjRtLoadedExecutable> executable;
std::optional<std::string> fingerprint;
{
py::gil_scoped_release gil_release;
@@ -336,7 +336,7 @@
StatusOr<std::shared_ptr<PyExecutable>> PyClient::CompileMlir(
std::string mlir_module, CompileOptions options,
std::vector<pybind11::capsule> host_callbacks) {
- std::unique_ptr<PjRtExecutable> executable;
+ std::unique_ptr<PjRtLoadedExecutable> executable;
std::optional<std::string> fingerprint;
{
py::gil_scoped_release gil_release;
@@ -362,7 +362,7 @@
StatusOr<std::shared_ptr<PyExecutable>> PyClient::DeserializeExecutable(
const std::string& serialized, CompileOptions options,
std::vector<pybind11::capsule> host_callbacks) {
- std::unique_ptr<PjRtExecutable> executable;
+ std::unique_ptr<PjRtLoadedExecutable> executable;
std::optional<std::string> fingerprint;
{
py::gil_scoped_release gil_release;
diff --git a/tensorflow/compiler/xla/python/py_executable.cc b/tensorflow/compiler/xla/python/py_executable.cc
index 178388f..18b5e1a 100644
--- a/tensorflow/compiler/xla/python/py_executable.cc
+++ b/tensorflow/compiler/xla/python/py_executable.cc
@@ -27,7 +27,7 @@
namespace py = pybind11;
PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
- std::unique_ptr<PjRtExecutable> executable,
+ std::unique_ptr<PjRtLoadedExecutable> executable,
std::shared_ptr<Traceback> traceback,
std::optional<std::string> fingerprint,
std::vector<pybind11::capsule> host_callbacks)
diff --git a/tensorflow/compiler/xla/python/py_executable.h b/tensorflow/compiler/xla/python/py_executable.h
index 4729b86..2994339 100644
--- a/tensorflow/compiler/xla/python/py_executable.h
+++ b/tensorflow/compiler/xla/python/py_executable.h
@@ -37,16 +37,18 @@
class PyExecutable : public std::enable_shared_from_this<PyExecutable> {
public:
PyExecutable(std::shared_ptr<PyClient> client,
- std::unique_ptr<PjRtExecutable> executable,
+ std::unique_ptr<PjRtLoadedExecutable> executable,
std::shared_ptr<Traceback> traceback,
std::optional<std::string> fingerprint,
std::vector<pybind11::capsule> host_callbacks);
~PyExecutable();
std::shared_ptr<PyClient> client() const { return client_; }
- std::shared_ptr<PjRtExecutable> executable() const { return executable_; }
+ std::shared_ptr<PjRtLoadedExecutable> executable() const {
+ return executable_;
+ }
- absl::Span<const PjRtExecutable::LogicalDeviceIds>
+ absl::Span<const PjRtLoadedExecutable::LogicalDeviceIds>
addressable_device_logical_ids() const {
return executable_->addressable_device_logical_ids();
}
@@ -80,9 +82,11 @@
Traceback* traceback() { return traceback_.get(); }
- const PjRtExecutable& pjrt_executable() const { return *executable_; }
+ const PjRtLoadedExecutable& pjrt_executable() const { return *executable_; }
- PjRtExecutable* mutable_pjrt_executable() const { return executable_.get(); }
+ PjRtLoadedExecutable* mutable_pjrt_executable() const {
+ return executable_.get();
+ }
const ExecuteOptions& options() const { return options_; }
const std::optional<std::string>& fingerprint() const { return fingerprint_; }
@@ -93,7 +97,7 @@
friend class PyClient;
std::shared_ptr<PyClient> client_;
- std::shared_ptr<PjRtExecutable> executable_;
+ std::shared_ptr<PjRtLoadedExecutable> executable_;
std::shared_ptr<Traceback> traceback_;
// Identical executables (i.e. representing the same program) will have the