Support multi-core execution.

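TpuProgramGroup now stores per-core executable info, host transfer info, and
HLO metadata (one entry per XLA_TpuProgram) instead of a single proto taken
from program 0, and the TPU execute op looks these up by the entry's core
index.

As a minimal, hypothetical sketch of the intended call pattern (the names
`entry` and `tpu_program_group` simply stand in for the objects already used
in tpu_execute_op.cc), a per-core lookup reads roughly:

  // Illustrative only: pick the protos that belong to this entry's core.
  const int core_index = entry.core_index();
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);
  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);
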
PiperOrigin-RevId: 323932743
Change-Id: I3235ddec06018eca323c974fabfcb094cceda10a
diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
index 0f451e5..3522ace 100644
--- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
@@ -649,8 +649,9 @@
       tensorflow::down_cast<const tpu::TpuProgramGroup*>(
           entry.tpu_program_group());
   CHECK_NE(tpu_program_group, nullptr);
+  const int core_index = entry.core_index();
   const TPUExecutableInfoProto& executable =
-      tpu_program_group->executable_info();
+      tpu_program_group->executable_info(core_index);
 
   xla::Backend* const backend = node_context->backend();
   xla::TransferManager* const transfer_manager = backend->transfer_manager();
@@ -749,8 +750,7 @@
   // all subsequent writes to the program that could possibly clobber the memory
   // will depend on the program to finish.
   const TPUHostTransferInfoProto& host_transfer_info =
-      tpu_program_group->host_transfer_info();
-  const int core_index = entry.core_index();
+      tpu_program_group->host_transfer_info(core_index);
   TF_ASSIGN_OR_RETURN(
       xla::ExecutionOutput output,
       TPUExecute(executable, host_transfer_info,
diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.cc b/tensorflow/core/tpu/kernels/tpu_program_group.cc
index 27b699e..2ee926f 100644
--- a/tensorflow/core/tpu/kernels/tpu_program_group.cc
+++ b/tensorflow/core/tpu/kernels/tpu_program_group.cc
@@ -98,55 +98,62 @@
       compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes,
       per_core_variable_indices, device_assignment);
 }
+}  // namespace
 
-Status CreateTpuProgramGroup(
-    absl::Span<XLA_TpuProgram* const> xla_tpu_programs,
-    TpuProgramGroupInterface* tpu_program_group_interface) {
+void TpuProgramGroup::Initialize(
+    absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
   CHECK_GT(xla_tpu_programs.size(), 0);
-  TpuProgramGroup* tpu_program_group =
-      tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
-  CHECK_NE(tpu_program_group, nullptr);
-  tpu_program_group->set_tpu_programs(xla_tpu_programs);
+  set_tpu_programs(xla_tpu_programs);
 
-  // TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1.
-  bool may_modify_variables;
-  TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_programs[0],
-                                                        &may_modify_variables);
-  tpu_program_group->set_may_modify_variables(
-      std::vector<bool>(1, may_modify_variables));
+  std::vector<bool> may_modify_variables_array(xla_tpu_programs.size(), false);
+  std::vector<TPUExecutableInfoProto> executable_infos(xla_tpu_programs.size());
+  std::vector<TPUHostTransferInfoProto> host_transfer_infos(
+      xla_tpu_programs.size());
+  std::vector<xla::HloProto> hlo_metadatas(xla_tpu_programs.size());
+  for (size_t i = 0; i < xla_tpu_programs.size(); ++i) {
+    const XLA_TpuProgram* xla_tpu_program = xla_tpu_programs[i];
+    bool may_modify_variables;
+    TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(
+        xla_tpu_program, &may_modify_variables);
+    may_modify_variables_array[i] = may_modify_variables;
 
-  TpuSerializedProto serialized_executable_info;
-  TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
-      xla_tpu_programs[0], &serialized_executable_info);
-  TPUExecutableInfoProto executable_info =
-      se_tpu::DeserializeProto<TPUExecutableInfoProto>(
-          serialized_executable_info);
-  tpu_program_group->set_executable_info(executable_info);
-  StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
+    TpuSerializedProto serialized_executable_info;
+    TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
+        xla_tpu_program, &serialized_executable_info);
+    TPUExecutableInfoProto executable_info =
+        se_tpu::DeserializeProto<TPUExecutableInfoProto>(
+            serialized_executable_info);
+    executable_infos[i] = executable_info;
+    StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
 
-  TPUHostTransferInfoProto host_transfer_info;
-  TpuSerializedProto serialized_host_transfer_info;
-  TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
-      xla_tpu_programs[0], &serialized_host_transfer_info);
-  if (serialized_host_transfer_info.size > 0) {
-    host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
-        serialized_host_transfer_info);
-    StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
+    TPUHostTransferInfoProto host_transfer_info;
+    TpuSerializedProto serialized_host_transfer_info;
+    TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
+        xla_tpu_program, &serialized_host_transfer_info);
+    if (serialized_host_transfer_info.size > 0) {
+      host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
+          serialized_host_transfer_info);
+      StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
+    }
+    host_transfer_infos[i] = host_transfer_info;
+
+    TpuSerializedProto serialized_hlo_metadata;
+    TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_program,
+                                                   &serialized_hlo_metadata);
+    xla::HloProto hlo_metadata =
+        se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
+    hlo_metadatas[i] = hlo_metadata;
+    StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
   }
-  tpu_program_group->set_host_transfer_info(host_transfer_info);
 
-  TpuSerializedProto serialized_hlo_metadata;
-  TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_programs[0],
-                                                 &serialized_hlo_metadata);
-  xla::HloProto hlo_metadata =
-      se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
-  tpu_program_group->set_hlo_metadata(hlo_metadata);
-  StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
-
-  return Status::OK();
+  may_modify_variables_ = may_modify_variables_array;
+  executable_infos_ = executable_infos;
+  host_transfer_infos_ = host_transfer_infos;
+  hlo_metadatas_ = hlo_metadatas;
+  RefreshHloMetadatasPtrs();
 }
 
-}  // namespace
+size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
 
 int64_t TpuProgramGroup::program_size() const {
   int64_t total_size = 0;
@@ -201,12 +208,6 @@
   TF_RET_CHECK(per_core_output_shapes.size() ==
                per_core_variable_indices.size());
 
-  // TODO(henrytan): add an interface to TpuProgramGroupInterface to set
-  // may_modify_variables.
-  TpuProgramGroup* tpu_program_group =
-      tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
-  tpu_program_group->may_modify_variables_ = may_modify_variables;
-
   // With shardable input/output pairs, XLA could generate separate
   // sharding/unsharding programs along with the main program. The
   // sharding/unsharding programs will be in nested entries of the AOT
@@ -221,17 +222,20 @@
   TF_RET_CHECK(xla_tpu_programs.size() == 1 ||
                xla_tpu_programs.size() == metadata.num_cores_per_replica());
 
-  TF_RETURN_IF_ERROR(
-      CreateTpuProgramGroup(xla_tpu_programs, tpu_program_group));
+  // TODO(henrytan): add an interface to TpuProgramGroupInterface to set
+  // may_modify_variables.
+  TpuProgramGroup* tpu_program_group =
+      tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
+  tpu_program_group->Initialize(xla_tpu_programs);
+  tpu_program_group->may_modify_variables_ = may_modify_variables;
   return Status::OK();
 }
 
 TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
     : may_modify_variables_(std::move(other.may_modify_variables_)),
-      host_compute_metadata_(std::move(other.host_compute_metadata_)),
       tpu_programs_(std::move(other.tpu_programs_)),
-      executable_info_(std::move(other.executable_info_)),
-      host_transfer_info_(std::move(other.host_transfer_info_)),
+      executable_infos_(std::move(other.executable_infos_)),
+      host_transfer_infos_(std::move(other.host_transfer_infos_)),
       hlo_metadatas_(std::move(other.hlo_metadatas_)) {
   RefreshHloMetadatasPtrs();
 }
@@ -277,16 +281,6 @@
   may_modify_variables_ = may_modify_variables;
 }
 
-const tf2xla::HostComputeMetadata& TpuProgramGroup::host_compute_metadata()
-    const {
-  return host_compute_metadata_;
-}
-
-void TpuProgramGroup::set_host_compute_metadata(
-    const tf2xla::HostComputeMetadata& host_compute_metadata) {
-  host_compute_metadata_ = host_compute_metadata;
-}
-
 const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
   return tpu_programs_;
 }
@@ -305,22 +299,18 @@
   }
 }
 
-const TPUExecutableInfoProto& TpuProgramGroup::executable_info() const {
-  return executable_info_;
+const TPUExecutableInfoProto& TpuProgramGroup::executable_info(
+    int index) const {
+  CHECK_GE(index, 0);
+  CHECK_LT(index, executable_infos_.size());
+  return executable_infos_[index];
 }
 
-void TpuProgramGroup::set_executable_info(
-    const TPUExecutableInfoProto& executable_info) {
-  executable_info_ = executable_info;
-}
-
-const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info() const {
-  return host_transfer_info_;
-}
-
-void TpuProgramGroup::set_host_transfer_info(
-    const TPUHostTransferInfoProto& host_transfer_info) {
-  host_transfer_info_ = host_transfer_info;
+const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info(
+    int index) const {
+  CHECK_GE(index, 0);
+  CHECK_LT(index, host_transfer_infos_.size());
+  return host_transfer_infos_[index];
 }
 
 /*static*/
@@ -348,14 +338,13 @@
   TF_RET_CHECK(count == 1 ||
                count == compilation_request.metadata().num_cores_per_replica());
 
-  VLOG(1) << "CreateTpuProgramGroup";
-  Status serialize_status =
-      CreateTpuProgramGroup(absl::MakeConstSpan(&xla_tpu_programs[0], count),
-                            tpu_program_group_interface);
-  VLOG(1) << absl::StrCat("Run CreateTpuProgramGroup completed. StatusCode: ",
-                          serialize_status.code());
+  VLOG(1) << "Initialize TpuProgramGroup.";
+  TpuProgramGroup* tpu_program_group =
+      tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
+  tpu_program_group->Initialize(
+      absl::MakeConstSpan(&xla_tpu_programs[0], count));
   TpuProgramApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
-  return serialize_status;
+  return status.status();
 }
 
 }  // namespace tpu
diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.h b/tensorflow/core/tpu/kernels/tpu_program_group.h
index 5a36fa4..bceede5 100644
--- a/tensorflow/core/tpu/kernels/tpu_program_group.h
+++ b/tensorflow/core/tpu/kernels/tpu_program_group.h
@@ -102,11 +102,14 @@
       const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
       TpuProgramGroupInterface* tpu_program_group_interface);
 
+  // Initializes `TpuProgramGroup` object with `xla_tpu_programs`.
+  void Initialize(absl::Span<XLA_TpuProgram* const> xla_tpu_programs);
+
   TpuProgramGroup() = default;
   TpuProgramGroup(TpuProgramGroup&& other);
   TpuProgramGroup& operator=(TpuProgramGroup&&) = delete;
 
-  size_t program_count() const override { return tpu_programs_.size(); }
+  size_t program_count() const override;
 
   int64_t program_size() const override;
 
@@ -120,21 +123,13 @@
   const std::vector<bool>& may_modify_variables() const override;
   void set_may_modify_variables(const std::vector<bool>& may_modify_variables);
 
-  const tf2xla::HostComputeMetadata& host_compute_metadata() const;
-  void set_host_compute_metadata(
-      const tf2xla::HostComputeMetadata& host_compute_metadata);
-
   const std::vector<XLA_TpuProgram*>& tpu_programs() const;
   const XLA_TpuProgram* tpu_program(int index) const;
   void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs);
 
-  const TPUExecutableInfoProto& executable_info() const;
-  void set_executable_info(const TPUExecutableInfoProto& executable_info);
+  const TPUExecutableInfoProto& executable_info(int index) const;
 
-  const TPUHostTransferInfoProto& host_transfer_info() const;
-  void set_host_transfer_info(
-      const TPUHostTransferInfoProto& host_transfer_info);
-
+  const TPUHostTransferInfoProto& host_transfer_info(int index) const;
   void set_hlo_metadata(const xla::HloProto& hlo_metadata);
   const xla::HloProto* hlo_metadata(int index) const;
   absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
@@ -143,11 +138,10 @@
   void RefreshHloMetadatasPtrs();
 
   std::vector<bool> may_modify_variables_;
-  tf2xla::HostComputeMetadata host_compute_metadata_;
 
   std::vector<XLA_TpuProgram*> tpu_programs_;  // Not owned.
-  TPUExecutableInfoProto executable_info_;
-  TPUHostTransferInfoProto host_transfer_info_;
+  std::vector<TPUExecutableInfoProto> executable_infos_;
+  std::vector<TPUHostTransferInfoProto> host_transfer_infos_;
 
   // To be consistent with the TpuProgramGroupInterface::hlo_metadatas()
   // signature, we store HloProto values in hlo_metadatas_ when