Support multi-core execution: keep per-core executable and host transfer info in TpuProgramGroup.
PiperOrigin-RevId: 323932743
Change-Id: I3235ddec06018eca323c974fabfcb094cceda10a
diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
index 0f451e5..3522ace 100644
--- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
@@ -649,8 +649,9 @@
tensorflow::down_cast<const tpu::TpuProgramGroup*>(
entry.tpu_program_group());
CHECK_NE(tpu_program_group, nullptr);
+ const int core_index = entry.core_index();
const TPUExecutableInfoProto& executable =
- tpu_program_group->executable_info();
+ tpu_program_group->executable_info(core_index);
xla::Backend* const backend = node_context->backend();
xla::TransferManager* const transfer_manager = backend->transfer_manager();
@@ -749,8 +750,7 @@
// all subsequent writes to the program that could possibly clobber the memory
// will depend on the program to finish.
const TPUHostTransferInfoProto& host_transfer_info =
- tpu_program_group->host_transfer_info();
- const int core_index = entry.core_index();
+ tpu_program_group->host_transfer_info(core_index);
TF_ASSIGN_OR_RETURN(
xla::ExecutionOutput output,
TPUExecute(executable, host_transfer_info,
diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.cc b/tensorflow/core/tpu/kernels/tpu_program_group.cc
index 27b699e..2ee926f 100644
--- a/tensorflow/core/tpu/kernels/tpu_program_group.cc
+++ b/tensorflow/core/tpu/kernels/tpu_program_group.cc
@@ -98,55 +98,62 @@
compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes,
per_core_variable_indices, device_assignment);
}
+} // namespace
-Status CreateTpuProgramGroup(
- absl::Span<XLA_TpuProgram* const> xla_tpu_programs,
- TpuProgramGroupInterface* tpu_program_group_interface) {
+void TpuProgramGroup::Initialize(
+ absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
CHECK_GT(xla_tpu_programs.size(), 0);
- TpuProgramGroup* tpu_program_group =
- tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
- CHECK_NE(tpu_program_group, nullptr);
- tpu_program_group->set_tpu_programs(xla_tpu_programs);
+ set_tpu_programs(xla_tpu_programs);
- // TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1.
- bool may_modify_variables;
- TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_programs[0],
- &may_modify_variables);
- tpu_program_group->set_may_modify_variables(
- std::vector<bool>(1, may_modify_variables));
+ std::vector<bool> may_modify_variables_array(xla_tpu_programs.size(), false);
+ std::vector<TPUExecutableInfoProto> executable_infos(xla_tpu_programs.size());
+ std::vector<TPUHostTransferInfoProto> host_transfer_infos(
+ xla_tpu_programs.size());
+ std::vector<xla::HloProto> hlo_metadatas(xla_tpu_programs.size());
+ for (size_t i = 0; i < xla_tpu_programs.size(); ++i) {
+ const XLA_TpuProgram* xla_tpu_program = xla_tpu_programs[i];
+ bool may_modify_variables;
+ TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(
+ xla_tpu_program, &may_modify_variables);
+ may_modify_variables_array[i] = may_modify_variables;
- TpuSerializedProto serialized_executable_info;
- TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
- xla_tpu_programs[0], &serialized_executable_info);
- TPUExecutableInfoProto executable_info =
- se_tpu::DeserializeProto<TPUExecutableInfoProto>(
- serialized_executable_info);
- tpu_program_group->set_executable_info(executable_info);
- StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
+ TpuSerializedProto serialized_executable_info;
+ TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
+ xla_tpu_program, &serialized_executable_info);
+ TPUExecutableInfoProto executable_info =
+ se_tpu::DeserializeProto<TPUExecutableInfoProto>(
+ serialized_executable_info);
+ executable_infos[i] = executable_info;
+ StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
- TPUHostTransferInfoProto host_transfer_info;
- TpuSerializedProto serialized_host_transfer_info;
- TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
- xla_tpu_programs[0], &serialized_host_transfer_info);
- if (serialized_host_transfer_info.size > 0) {
- host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
- serialized_host_transfer_info);
- StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
+ TPUHostTransferInfoProto host_transfer_info;
+ TpuSerializedProto serialized_host_transfer_info;
+ TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
+ xla_tpu_program, &serialized_host_transfer_info);
+ if (serialized_host_transfer_info.size > 0) {
+ host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
+ serialized_host_transfer_info);
+ StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
+ }
+ host_transfer_infos[i] = host_transfer_info;
+
+ TpuSerializedProto serialized_hlo_metadata;
+ TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_program,
+ &serialized_hlo_metadata);
+ xla::HloProto hlo_metadata =
+ se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
+ hlo_metadatas[i] = hlo_metadata;
+ StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
}
- tpu_program_group->set_host_transfer_info(host_transfer_info);
- TpuSerializedProto serialized_hlo_metadata;
- TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_programs[0],
- &serialized_hlo_metadata);
- xla::HloProto hlo_metadata =
- se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
- tpu_program_group->set_hlo_metadata(hlo_metadata);
- StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
-
- return Status::OK();
+ may_modify_variables_ = may_modify_variables_array;
+ executable_infos_ = executable_infos;
+ host_transfer_infos_ = host_transfer_infos;
+ hlo_metadatas_ = hlo_metadatas;
+ RefreshHloMetadatasPtrs();
}
-} // namespace
+size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
int64_t TpuProgramGroup::program_size() const {
int64_t total_size = 0;
@@ -201,12 +208,6 @@
TF_RET_CHECK(per_core_output_shapes.size() ==
per_core_variable_indices.size());
- // TODO(henrytan): add an interface to TpuProgramGroupInterface to set
- // may_modify_variables.
- TpuProgramGroup* tpu_program_group =
- tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
- tpu_program_group->may_modify_variables_ = may_modify_variables;
-
// With shardable input/output pairs, XLA could generate separate
// sharding/unsharding programs along with the main program. The
// sharding/unsharding programs will be in nested entries of the AOT
@@ -221,17 +222,20 @@
TF_RET_CHECK(xla_tpu_programs.size() == 1 ||
xla_tpu_programs.size() == metadata.num_cores_per_replica());
- TF_RETURN_IF_ERROR(
- CreateTpuProgramGroup(xla_tpu_programs, tpu_program_group));
+ // TODO(henrytan): add an interface to TpuProgramGroupInterface to set
+ // may_modify_variables.
+ TpuProgramGroup* tpu_program_group =
+ tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
+ tpu_program_group->Initialize(xla_tpu_programs);
+ tpu_program_group->may_modify_variables_ = may_modify_variables;
return Status::OK();
}
TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
: may_modify_variables_(std::move(other.may_modify_variables_)),
- host_compute_metadata_(std::move(other.host_compute_metadata_)),
tpu_programs_(std::move(other.tpu_programs_)),
- executable_info_(std::move(other.executable_info_)),
- host_transfer_info_(std::move(other.host_transfer_info_)),
+ executable_infos_(std::move(other.executable_infos_)),
+ host_transfer_infos_(std::move(other.host_transfer_infos_)),
hlo_metadatas_(std::move(other.hlo_metadatas_)) {
RefreshHloMetadatasPtrs();
}
@@ -277,16 +281,6 @@
may_modify_variables_ = may_modify_variables;
}
-const tf2xla::HostComputeMetadata& TpuProgramGroup::host_compute_metadata()
- const {
- return host_compute_metadata_;
-}
-
-void TpuProgramGroup::set_host_compute_metadata(
- const tf2xla::HostComputeMetadata& host_compute_metadata) {
- host_compute_metadata_ = host_compute_metadata;
-}
-
const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
return tpu_programs_;
}
@@ -305,22 +299,18 @@
}
}
-const TPUExecutableInfoProto& TpuProgramGroup::executable_info() const {
- return executable_info_;
+const TPUExecutableInfoProto& TpuProgramGroup::executable_info(
+ int index) const {
+ CHECK_GE(index, 0);
+ CHECK_LT(index, executable_infos_.size());
+ return executable_infos_[index];
}
-void TpuProgramGroup::set_executable_info(
- const TPUExecutableInfoProto& executable_info) {
- executable_info_ = executable_info;
-}
-
-const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info() const {
- return host_transfer_info_;
-}
-
-void TpuProgramGroup::set_host_transfer_info(
- const TPUHostTransferInfoProto& host_transfer_info) {
- host_transfer_info_ = host_transfer_info;
+const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info(
+ int index) const {
+ CHECK_GE(index, 0);
+ CHECK_LT(index, host_transfer_infos_.size());
+ return host_transfer_infos_[index];
}
/*static*/
@@ -348,14 +338,13 @@
TF_RET_CHECK(count == 1 ||
count == compilation_request.metadata().num_cores_per_replica());
- VLOG(1) << "CreateTpuProgramGroup";
- Status serialize_status =
- CreateTpuProgramGroup(absl::MakeConstSpan(&xla_tpu_programs[0], count),
- tpu_program_group_interface);
- VLOG(1) << absl::StrCat("Run CreateTpuProgramGroup completed. StatusCode: ",
- serialize_status.code());
+ VLOG(1) << "Initialize TpuProgramGroup.";
+ TpuProgramGroup* tpu_program_group =
+ tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
+ tpu_program_group->Initialize(
+ absl::MakeConstSpan(&xla_tpu_programs[0], count));
TpuProgramApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
- return serialize_status;
+ return status.status();
}
} // namespace tpu
diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.h b/tensorflow/core/tpu/kernels/tpu_program_group.h
index 5a36fa4..bceede5 100644
--- a/tensorflow/core/tpu/kernels/tpu_program_group.h
+++ b/tensorflow/core/tpu/kernels/tpu_program_group.h
@@ -102,11 +102,14 @@
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
TpuProgramGroupInterface* tpu_program_group_interface);
+ // Initializes this object from `xla_tpu_programs`; CHECK-fails if it is empty.
+ void Initialize(absl::Span<XLA_TpuProgram* const> xla_tpu_programs);
+
TpuProgramGroup() = default;
TpuProgramGroup(TpuProgramGroup&& other);
TpuProgramGroup& operator=(TpuProgramGroup&&) = delete;
- size_t program_count() const override { return tpu_programs_.size(); }
+ size_t program_count() const override;
int64_t program_size() const override;
@@ -120,21 +123,13 @@
const std::vector<bool>& may_modify_variables() const override;
void set_may_modify_variables(const std::vector<bool>& may_modify_variables);
- const tf2xla::HostComputeMetadata& host_compute_metadata() const;
- void set_host_compute_metadata(
- const tf2xla::HostComputeMetadata& host_compute_metadata);
-
const std::vector<XLA_TpuProgram*>& tpu_programs() const;
const XLA_TpuProgram* tpu_program(int index) const;
void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs);
- const TPUExecutableInfoProto& executable_info() const;
- void set_executable_info(const TPUExecutableInfoProto& executable_info);
+ const TPUExecutableInfoProto& executable_info(int index) const;
- const TPUHostTransferInfoProto& host_transfer_info() const;
- void set_host_transfer_info(
- const TPUHostTransferInfoProto& host_transfer_info);
-
+ const TPUHostTransferInfoProto& host_transfer_info(int index) const;
void set_hlo_metadata(const xla::HloProto& hlo_metadata);
const xla::HloProto* hlo_metadata(int index) const;
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
@@ -143,11 +138,10 @@
void RefreshHloMetadatasPtrs();
std::vector<bool> may_modify_variables_;
- tf2xla::HostComputeMetadata host_compute_metadata_;
std::vector<XLA_TpuProgram*> tpu_programs_; // Not owned.
- TPUExecutableInfoProto executable_info_;
- TPUHostTransferInfoProto host_transfer_info_;
+ std::vector<TPUExecutableInfoProto> executable_infos_;
+ std::vector<TPUHostTransferInfoProto> host_transfer_infos_;
// To be consistent with the TpuProgramGroupInterface::hlo_metadatas()
// signature, we store HloProto values in hlo_metadatas_ when