Open sourcing some TPU-related work
PiperOrigin-RevId: 315431095
Change-Id: I734632c0e5723dfca37acf53bbbd2b378b04c95d
diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD
new file mode 100644
index 0000000..ef9e4a0
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/BUILD
@@ -0,0 +1,55 @@
+# Contains graph rewrites for TPU runtimes and optimizations.
+
+package(
+ default_visibility = [
+ "//tensorflow/core/tpu:__subpackages__",
+ "//tensorflow/stream_executor/tpu:__subpackages__",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "distributed_tpu_configuration_rewrite_registration",
+ srcs = ["distributed_tpu_configuration_rewrite_registration.cc"],
+ deps = [
+ ":distributed_tpu_configuration_rewrite_pass",
+ "//tensorflow/core:core_cpu",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "distributed_tpu_configuration_rewrite_pass",
+ srcs = [
+ "distributed_tpu_configuration_rewrite_pass.cc",
+ ],
+ hdrs = [
+ "distributed_tpu_configuration_rewrite_pass.h",
+ ],
+ deps = [
+ ":distributed_tpu_rewrite_helpers",
+ "//tensorflow/cc:scope",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/protobuf/tpu:topology_proto_cc",
+ "//tensorflow/core/tpu:tpu_init_mode",
+ "//tensorflow/core/tpu/kernels:tpu_compile_op_options",
+ ],
+)
+
+cc_library(
+ name = "distributed_tpu_rewrite_helpers",
+ srcs = ["distributed_tpu_rewrite_helpers.cc"],
+ hdrs = ["distributed_tpu_rewrite_helpers.h"],
+ deps = [
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/tpu:tpu_defs",
+ ],
+)
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc
new file mode 100644
index 0000000..3b1e9d7
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc
@@ -0,0 +1,402 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Configuration for distributed TPU jobs
+
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
+
+#include <unordered_map>
+
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_constructor.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
+#include "tensorflow/core/tpu/tpu_init_mode.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/dump_graph.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr char kIdentityOp[] = "Identity";
+constexpr char kConfigureOp[] = "ConfigureDistributedTPU";
+constexpr char kInternalConfigureOp[] = "_ConfigureDistributedTPU";
+constexpr char kWaitOp[] = "_WaitForDistributedTPU";
+constexpr char kHostConfigureOp[] = "_InitializeHostForDistributedTPU";
+constexpr char kGlobalTPUArrayOp[] = "_SetGlobalTPUArray";
+constexpr char kShutdownOp[] = "ShutdownDistributedTPU";
+constexpr char kInternalShutdownOp[] = "_ShutdownDistributedTPU";
+constexpr char kHostDisconnectOp[] = "_DisconnectHostFromDistributedTPUSystem";
+constexpr char kEmbeddingConfigurationAttr[] = "embedding_config";
+constexpr int kDefaultStartupTimeout = 20;
+
+Status AddConfigurationNode(const string& configuration_device_name,
+ int number_of_hosts, Graph* graph,
+ bool enable_whole_mesh_compilations,
+ Node** configuration_node) {
+ NodeDef config_def;
+ config_def.set_name(graph->NewName("configure_distributed_tpu"));
+ config_def.set_op(kInternalConfigureOp);
+ config_def.set_device(configuration_device_name);
+ AddNodeAttr("N", number_of_hosts, &config_def);
+ AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
+ &config_def);
+ // TODO(shikharagarwal): Fill with appropriate original node debug info.
+
+ Status status;
+ *configuration_node = graph->AddNode(config_def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ (*configuration_node)->set_assigned_device_name(configuration_device_name);
+ return Status::OK();
+}
+
+Status AddHostConfigNode(const string& host_device_name,
+ Node* configuration_node, Graph* graph,
+ bool enable_whole_mesh_compilations,
+ Node** host_configuration_node) {
+ NodeDef host_config_def;
+ host_config_def.set_name(graph->NewName("configure_tpu_host"));
+ host_config_def.set_op(kHostConfigureOp);
+ host_config_def.set_device(host_device_name);
+ AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
+ &host_config_def);
+ MergeDebugInfo(NodeDebugInfo(configuration_node->def()), &host_config_def);
+
+ Status status;
+ *host_configuration_node = graph->AddNode(host_config_def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ (*host_configuration_node)->set_assigned_device_name(host_device_name);
+ graph->AddEdge(configuration_node, 0, *host_configuration_node, 0);
+ return Status::OK();
+}
+
+Status AddWaitNode(const string& configuration_device_name,
+ const std::vector<Node*>& host_configuration_nodes,
+ Graph* graph, Node** wait_node) {
+ NodeDef wait_def;
+ wait_def.set_name(graph->NewName("wait_for_distributed_tpu_system"));
+ wait_def.set_op(kWaitOp);
+ wait_def.set_device(configuration_device_name);
+ AddNodeAttr("N", static_cast<int32>(host_configuration_nodes.size()),
+ &wait_def);
+ AddNodeAttr("startup_timeout_sec", kDefaultStartupTimeout, &wait_def);
+ if (!host_configuration_nodes.empty()) {
+ MergeDebugInfo(NodeDebugInfo(host_configuration_nodes[0]->def()),
+ &wait_def);
+ }
+
+ Status status;
+ *wait_node = graph->AddNode(wait_def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ (*wait_node)->set_assigned_device_name(configuration_device_name);
+ // Get the inputs from the host configuration nodes.
+ for (int i = 0; i < host_configuration_nodes.size(); ++i) {
+ graph->AddEdge(host_configuration_nodes[i], 0, *wait_node, i);
+ }
+ return Status::OK();
+}
+
+Status AddGlobalTPUArrayNode(const string& host_device_name, Node* wait_node,
+ Graph* graph, Node** global_tpu_array_node) {
+ NodeDef global_tpu_array_def;
+ global_tpu_array_def.set_name(graph->NewName("set_global_tpu_array"));
+ global_tpu_array_def.set_op(kGlobalTPUArrayOp);
+ global_tpu_array_def.set_device(host_device_name);
+ MergeDebugInfo(NodeDebugInfo(wait_node->def()), &global_tpu_array_def);
+
+ Status status;
+ *global_tpu_array_node = graph->AddNode(global_tpu_array_def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ (*global_tpu_array_node)->set_assigned_device_name(host_device_name);
+ graph->AddEdge(wait_node, 0, *global_tpu_array_node, 0);
+ return Status::OK();
+}
+
+Status AddSynchronizationNode(
+ const NodeDef& sync_node_def, const string& device_name,
+ const std::vector<Node*>& global_array_id_nodes, Node* wait_node,
+ const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
+ output_dependencies,
+ Graph* graph) {
+ NodeDef sync_def;
+ sync_def.set_name(sync_node_def.name());
+ sync_def.set_op(kIdentityOp);
+ sync_def.set_device(device_name);
+ AddNodeAttr("T", DT_STRING, &sync_def);
+ MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def);
+
+ Status status;
+ Node* sync_node = graph->AddNode(sync_def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ sync_node->set_assigned_device_name(device_name);
+ // Add control edges from the global array id nodes.
+ for (auto node : global_array_id_nodes) {
+ graph->AddControlEdge(node, sync_node);
+ }
+ // Forward the data from the wait node.
+ graph->AddEdge(wait_node, 0, sync_node, 0);
+ // Replace the output edges.
+ for (const DistributedTPURewriteHelpers::OutputDependency& dep :
+ output_dependencies) {
+ if (dep.dst_input == Graph::kControlSlot) {
+ graph->AddControlEdge(sync_node, dep.dst);
+ } else {
+ graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input);
+ }
+ }
+ return Status::OK();
+}
+
+Status AddShutdownNode(
+ const NodeDef& shutdown_node_def, const string& shutdown_device_name,
+ const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
+ output_dependencies,
+ Graph* graph, Node** shutdown_node) {
+ NodeDef shutdown_def;
+ shutdown_def.set_name(shutdown_node_def.name());
+ shutdown_def.set_op(kInternalShutdownOp);
+ shutdown_def.set_device(shutdown_device_name);
+ MergeDebugInfo(NodeDebugInfo(shutdown_node_def), &shutdown_def);
+
+ Status status;
+ *shutdown_node = graph->AddNode(shutdown_def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ (*shutdown_node)->set_assigned_device_name(shutdown_device_name);
+ // Replace the output control edges.
+ for (const DistributedTPURewriteHelpers::OutputDependency& dep :
+ output_dependencies) {
+ if (dep.dst_input != Graph::kControlSlot) {
+ return errors::Internal("Shutdown node had non-control edge output");
+ }
+ graph->AddControlEdge(*shutdown_node, dep.dst);
+ }
+ return Status::OK();
+}
+
+Status AddHostDisconnectNode(const string& host_device_name,
+ const std::vector<Node*>& input_dependencies,
+ Node* post_disconnect_node, int output_index,
+ Graph* graph) {
+ NodeDef host_disconnect_def;
+ host_disconnect_def.set_name(graph->NewName("disconnect_tpu_host"));
+ host_disconnect_def.set_op(kHostDisconnectOp);
+ host_disconnect_def.set_device(host_device_name);
+ MergeDebugInfo(NodeDebugInfo(post_disconnect_node->def()),
+ &host_disconnect_def);
+
+ Status status;
+ Node* host_disconnect_node = graph->AddNode(host_disconnect_def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ host_disconnect_node->set_assigned_device_name(host_device_name);
+ // Replace the input control edges.
+ for (Node* src_node : input_dependencies) {
+ graph->AddControlEdge(src_node, host_disconnect_node);
+ }
+ if (output_index == -1) {
+ graph->AddControlEdge(host_disconnect_node, post_disconnect_node);
+ } else {
+ graph->AddEdge(host_disconnect_node, 0, post_disconnect_node, output_index);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status DistributedTPUConfigurationRewritePass::Run(
+ const GraphOptimizationPassOptions& options) {
+ VLOG(1) << "DistributedTPUConfigurationRewritePass::Run";
+
+ Graph* graph = options.graph->get();
+
+ if (VLOG_IS_ON(1)) {
+ DumpGraphToFile("distributed_tpu_configuration_before", *graph,
+ options.flib_def);
+ }
+
+ // This pass can only run in the session master, which should fill
+ // in the device_set field of the options.
+ TF_RET_CHECK(options.device_set != nullptr);
+
+ TF_RETURN_IF_ERROR(
+ DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
+ kConfigureOp, graph, *options.device_set,
+ [](const NodeDef& configuration_node_def,
+ const string& configuration_device_name,
+ const std::vector<Device*>& host_devices,
+ const std::vector<Node*>& input_dependencies,
+ const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
+ output_dependencies,
+ Graph* graph) -> Status {
+ const std::string& embedding_attr_string = GetNodeAttrString(
+ AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr);
+
+ if (!embedding_attr_string.empty()) {
+ return errors::InvalidArgument("embedding_config must be empty.");
+ }
+
+ bool is_global_init = false;
+ bool enable_whole_mesh_compilations = false;
+ TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
+ "is_global_init", &is_global_init));
+ TryGetNodeAttr(configuration_node_def,
+ "enable_whole_mesh_compilations",
+ &enable_whole_mesh_compilations);
+ TF_RETURN_IF_ERROR(SetTPUInitMode(
+ is_global_init ? TPUInitMode::kGlobal : TPUInitMode::kRegular));
+
+ bool compilation_failure_closes_chips;
+ TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
+ "compilation_failure_closes_chips",
+ &compilation_failure_closes_chips));
+ internal::SetTpuCompilationFailureClosesChips(
+ compilation_failure_closes_chips);
+
+ // Add the global TPU system configuration node.
+ Node* configuration_node;
+ TF_RETURN_IF_ERROR(AddConfigurationNode(
+ configuration_device_name, host_devices.size(), graph,
+ enable_whole_mesh_compilations, &configuration_node));
+
+ // Add the host disconnect nodes.
+ for (int i = 0; i < host_devices.size(); ++i) {
+ const auto host_device = host_devices[i];
+ TF_RETURN_IF_ERROR(
+ AddHostDisconnectNode(host_device->name(), input_dependencies,
+ configuration_node, i, graph));
+ }
+
+ // Add the host configuration nodes.
+ std::vector<Node*> host_configuration_nodes;
+ for (const auto host_device : host_devices) {
+ Node* host_configuration_node;
+ TF_RETURN_IF_ERROR(AddHostConfigNode(
+ host_device->name(), configuration_node, graph,
+ enable_whole_mesh_compilations, &host_configuration_node));
+ host_configuration_nodes.push_back(host_configuration_node);
+ }
+
+ // Add the node to wait for the system configuration to
+ // stabilize. Use the name of the original dummy Op in case it was
+ // the target of a Session::Run call.
+ Node* wait_node;
+ TF_RETURN_IF_ERROR(AddWaitNode(configuration_device_name,
+ host_configuration_nodes, graph,
+ &wait_node));
+
+ // Add the nodes to set the global TPU ids at each host.
+ std::vector<Node*> global_array_id_nodes;
+ for (const auto host_device : host_devices) {
+ Node* global_array_id_node;
+ TF_RETURN_IF_ERROR(AddGlobalTPUArrayNode(host_device->name(),
+ wait_node, graph,
+ &global_array_id_node));
+ global_array_id_nodes.push_back(global_array_id_node);
+ }
+
+ if (host_devices.empty()) {
+ return errors::InvalidArgument("TPU job contains no CPU devices");
+ }
+ TF_RET_CHECK(!host_devices.empty());
+
+ TF_RETURN_IF_ERROR(AddSynchronizationNode(
+ configuration_node_def, host_devices.front()->name(),
+ global_array_id_nodes, wait_node, output_dependencies, graph));
+
+ return Status::OK();
+ }));
+
+ if (VLOG_IS_ON(1)) {
+ DumpGraphToFile("distributed_tpu_configuration_after", *graph,
+ options.flib_def);
+ }
+
+ VLOG(1) << "DistributedTPUConfigurationRewritePass::Run() finished";
+ return Status::OK();
+}
+
+Status DistributedTPUShutdownRewritePass::Run(
+ const GraphOptimizationPassOptions& options) {
+ VLOG(1) << "DistributedTPUShutdownRewritePass::Run";
+
+ Graph* graph = options.graph->get();
+
+ if (VLOG_IS_ON(1)) {
+ DumpGraphToFile("distributed_tpu_shutdown_before", *graph,
+ options.flib_def);
+ }
+
+ // This pass can only run in the session master, which should fill
+ // in the device_set field of the options.
+ TF_RET_CHECK(options.device_set != nullptr);
+
+ TF_RETURN_IF_ERROR(
+ DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
+ kShutdownOp, graph, *options.device_set,
+ [](const NodeDef& shutdown_node_def,
+ const string& shutdown_device_name,
+ const std::vector<Device*>& host_devices,
+ const std::vector<Node*>& input_dependencies,
+ const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
+ output_dependencies,
+ Graph* graph) -> Status {
+ Node* shutdown_node;
+ TF_RETURN_IF_ERROR(
+ AddShutdownNode(shutdown_node_def, shutdown_device_name,
+ output_dependencies, graph, &shutdown_node));
+
+ // Add the host disconnect nodes.
+ for (const auto host_device : host_devices) {
+ TF_RETURN_IF_ERROR(
+ AddHostDisconnectNode(host_device->name(), input_dependencies,
+ shutdown_node, -1, graph));
+ }
+
+ return Status::OK();
+ }));
+
+ if (VLOG_IS_ON(1)) {
+ DumpGraphToFile("distributed_tpu_shutdown_after", *graph, options.flib_def);
+ }
+
+ VLOG(1) << "DistributedTPUShutdownRewritePass::Run() finished";
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h
new file mode 100644
index 0000000..191f32f
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Rewrites ConfigureDistributedTPU Op into a graph that configures each host.
+//
+// See the comment at the top of
+// third_party/tensorflow/core/ops/tpu_configuration_ops.cc to see the
+// sequence of Ops used to configure a distributed TPU system.
+
+#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
+#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+// Replaces dummy ConfigureDistributedTPU Ops assigned to TPU_SYSTEM
+// devices with _ConfigureDistributedTPU and _WaitForDistributedTPU
+// Ops on TPU_SYSTEM, and _InitializeHostForDistributedTPU on the CPU
+// device of each host in the same job as the given TPU_SYSTEM device.
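+//
+// In rough terms (see the implementation in the accompanying .cc file), a
+// single ConfigureDistributedTPU placeholder is replaced by a chain of
+//   _DisconnectHostFromDistributedTPUSystem (one per host)
+//   -> _ConfigureDistributedTPU (on the TPU_SYSTEM device)
+//   -> _InitializeHostForDistributedTPU (one per host)
+//   -> _WaitForDistributedTPU (on the TPU_SYSTEM device)
+//   -> _SetGlobalTPUArray (one per host)
+//   -> an Identity sync node that reuses the original node's name, so that
+//      existing Session::Run targets keep working.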
+class DistributedTPUConfigurationRewritePass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
+// Replaces dummy ShutdownDistributedTPU Ops assigned to TPU_SYSTEM
+// devices with _ShutdownDistributedTPU Ops on TPU_SYSTEM and
+// _DisconnectHostFromDistributedTPUSystem on the CPU device of each
+// host in the same job as the given TPU_SYSTEM device.
+class DistributedTPUShutdownRewritePass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc
new file mode 100644
index 0000000..db2b3a5
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
+
+namespace tensorflow {
+namespace {
+
+// Register the rewrite passes that replace the placeholder
+// ConfigureDistributedTPU and ShutdownDistributedTPU ops before placement.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
+ DistributedTPUConfigurationRewritePass);
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
+ DistributedTPUShutdownRewritePass);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc
new file mode 100644
index 0000000..965a174
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc
@@ -0,0 +1,255 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helper functions for TPU rewrite passes.
+
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
+
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// LINT.IfChange
+Status DistributedTPURewriteHelpers::GetSystemDevice(
+ const string& system_spec_string, const DeviceSet& device_set,
+ DeviceNameUtils::ParsedName* system_spec, Device** system_device) {
+ if (!DeviceNameUtils::ParseFullName(system_spec_string, system_spec)) {
+ system_spec->Clear();
+ }
+
+ // Callers may have relied on an Op only being registered on TPU_SYSTEM
+ // devices to ensure the Op is placed there. Augment the device spec to make
+ // the device type explicit.
+ if (!system_spec->has_type || system_spec->type != DEVICE_TPU_SYSTEM) {
+ system_spec->type = DEVICE_TPU_SYSTEM;
+ system_spec->has_type = true;
+ system_spec->id = 0;
+ system_spec->has_id = true;
+ }
+
+ std::vector<Device*> system_devices;
+ device_set.FindMatchingDevices(*system_spec, &system_devices);
+ if (system_devices.empty()) {
+ if (system_spec_string.empty()) {
+ return errors::InvalidArgument(
+ "No TPU_SYSTEM device found. Please ensure that you're connected to "
+ "a host with a TPU_SYSTEM device.");
+ }
+ return errors::InvalidArgument("No matching devices found for '",
+ system_spec_string, "'");
+ } else if (system_devices.size() > 1) {
+ // Validate that all system devices are part of the same job.
+ std::unordered_set<string> job_names;
+ for (auto device : system_devices) {
+ const auto& parsed_name = device->parsed_name();
+ TF_RET_CHECK(parsed_name.has_job);
+ job_names.insert(parsed_name.job);
+ }
+ if (job_names.size() > 1) {
+ return errors::InvalidArgument(
+ "System devices cannot be part "
+ "of multiple different jobs. Found: ",
+ str_util::Join(job_names, ","));
+ }
+
+ // Identify the lexicographically first device from the list of
+ // valid TPU SYSTEM devices, so that every process in the same
+ // 'cluster' definition uses the same system device.
+ std::sort(system_devices.begin(), system_devices.end(),
+ [](Device* i, Device* j) {
+ auto i_name = i->parsed_name();
+ auto j_name = j->parsed_name();
+ if (i_name.replica != j_name.replica) {
+ return i_name.replica < j_name.replica;
+ }
+ return i_name.task < j_name.task;
+ });
+ }
+
+ *system_device = system_devices[0];
+ if (!DeviceNameUtils::ParseFullName((*system_device)->name(), system_spec)) {
+ return errors::InvalidArgument("Unable to re-parse system device name ",
+ (*system_device)->name(),
+ " as a device spec.");
+ }
+ return Status::OK();
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
+
+// LINT.IfChange
+Status DistributedTPURewriteHelpers::GetHostSystemDevices(
+ const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
+ std::vector<Device*>* host_system_devices) {
+ DeviceNameUtils::ParsedName host_spec;
+ if (system_spec.has_job) {
+ // The system Op has been explicitly assigned to a job, so we want
+ // all the hosts in that job.
+ CHECK(DeviceNameUtils::ParseFullName(
+ strings::StrCat("/job:", system_spec.job, "/device:", DEVICE_TPU_SYSTEM,
+ ":0"),
+ &host_spec));
+ } else {
+ // The system Op has not been explicitly assigned to a
+ // job, so take all hosts in the system. There will be a runtime
+ // error if some of those hosts don't contain TPU devices.
+ CHECK(DeviceNameUtils::ParseFullName(
+ strings::StrCat("/device:", DEVICE_TPU_SYSTEM, ":0"), &host_spec));
+ }
+ device_set.FindMatchingDevices(host_spec, host_system_devices);
+
+ TF_RET_CHECK(!host_system_devices->empty())
+ << "No hosts found matching device spec "
+ << DeviceNameUtils::ParsedNameToString(host_spec);
+
+ // Check that all the devices belong to the same job.
+ TF_RET_CHECK((*host_system_devices)[0]->parsed_name().has_job);
+ const string& job_name = (*host_system_devices)[0]->parsed_name().job;
+ int replica = (*host_system_devices)[0]->parsed_name().replica;
+ for (const auto host_device : *host_system_devices) {
+ const auto& parsed_name = host_device->parsed_name();
+ TF_RET_CHECK(parsed_name.has_job);
+ if (parsed_name.job != job_name) {
+ return errors::InvalidArgument(
+ "All TPU host devices must be in the same job");
+ }
+ TF_RET_CHECK(parsed_name.has_replica);
+ if (parsed_name.replica != replica) {
+ return errors::InvalidArgument(
+ "All TPU host devices must be in the same replica");
+ }
+ }
+
+ // Sort the devices by task (all remaining devices share the same replica).
+ std::sort(host_system_devices->begin(), host_system_devices->end(),
+ [](Device* i, Device* j) {
+ auto i_name = i->parsed_name();
+ auto j_name = j->parsed_name();
+ return i_name.task < j_name.task;
+ });
+ return Status::OK();
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
+
+// LINT.IfChange
+Status DistributedTPURewriteHelpers::GetTPUDevices(
+ const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
+ int* num_tpus_per_host, std::vector<std::vector<Device*>>* tpu_devices) {
+ // GetHostSystemDevices returns the CPU device on each host that is
+ // going to be used for executing TPU code.
+ std::vector<Device*> host_system_devices;
+ TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetHostSystemDevices(
+ system_spec, device_set, &host_system_devices));
+
+ // Enumerate all the physical devices. Enumerate devices on task 0,
+ // then task 1, etc.
+ std::sort(host_system_devices.begin(), host_system_devices.end(),
+ [](Device* i, Device* j) {
+ return i->parsed_name().task < j->parsed_name().task;
+ });
+
+ *num_tpus_per_host = 0;
+ tpu_devices->clear();
+ tpu_devices->reserve(host_system_devices.size());
+ for (const auto device : host_system_devices) {
+ // Make a copy of the parsed name because we are going to change it.
+ DeviceNameUtils::ParsedName device_spec = device->parsed_name();
+ device_spec.has_type = true;
+ device_spec.type = "TPU";
+ // Enumerate all the available TPUs.
+ device_spec.has_id = false;
+ std::vector<Device*> host_tpu_devices;
+ device_set.FindMatchingDevices(device_spec, &host_tpu_devices);
+ // Sort the devices by device id.
+ std::sort(host_tpu_devices.begin(), host_tpu_devices.end(),
+ [](Device* i, Device* j) {
+ return i->parsed_name().id < j->parsed_name().id;
+ });
+ if (tpu_devices->empty()) {
+ // First iteration: set *num_tpus_per_host to the number of TPUs on the
+ // first host.
+ *num_tpus_per_host = host_tpu_devices.size();
+ } else if (*num_tpus_per_host != host_tpu_devices.size()) {
+ // Subsequent iterations: check the number of TPUs match the number on
+ // the first host.
+ return errors::InvalidArgument(
+ "Mismatched number of TPU devices in cluster ", *num_tpus_per_host,
+ " vs. ", host_tpu_devices.size());
+ }
+ tpu_devices->push_back(std::move(host_tpu_devices));
+ }
+ return Status::OK();
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
+
+Status DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
+ const string& node_type, Graph* graph, const DeviceSet& device_set,
+ const std::function<
+ Status(const NodeDef& configuration_node_def,
+ const string& configuration_device_name,
+ const std::vector<Device*>& host_devices,
+ const std::vector<Node*>& input_dependencies,
+ const std::vector<OutputDependency>& output_dependencies,
+ Graph* graph)>& action) {
+ // Find all the matching nodes before mutating the graph.
+ std::vector<Node*> nodes;
+ for (Node* node : graph->nodes()) {
+ if (node->type_string() == node_type) {
+ nodes.push_back(node);
+ }
+ }
+
+ for (Node* node : nodes) {
+ string spec_string = node->requested_device();
+ DeviceNameUtils::ParsedName spec;
+ Device* device;
+ TF_RETURN_IF_ERROR(
+ GetSystemDevice(spec_string, device_set, &spec, &device));
+ const string& device_name = device->name();
+
+ std::vector<Device*> host_devices;
+ TF_RETURN_IF_ERROR(GetHostSystemDevices(spec, device_set, &host_devices));
+
+ std::vector<Node*> input_dependencies;
+ for (const Edge* edge : node->in_edges()) {
+ // Config ops have no inputs, so all edges must be control edges.
+ CHECK(edge->IsControlEdge());
+ input_dependencies.push_back(edge->src());
+ }
+ std::vector<OutputDependency> output_dependencies;
+ for (const Edge* edge : node->out_edges()) {
+ OutputDependency dep;
+ dep.src_output = edge->src_output();
+ dep.dst = edge->dst();
+ dep.dst_input = edge->dst_input();
+ output_dependencies.push_back(dep);
+ }
+ NodeDef node_def = node->def();
+
+ // Remove the node now so we can insert a new node with the same
+ // name inside the action.
+ graph->RemoveNode(node);
+
+ TF_RETURN_IF_ERROR(action(node_def, device_name, host_devices,
+ input_dependencies, output_dependencies, graph));
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h
new file mode 100644
index 0000000..40aacce
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h
@@ -0,0 +1,98 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helper functions for TPU rewrite passes.
+
+#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
+#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
+
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+class DistributedTPURewriteHelpers {
+ public:
+ // Given a user-assigned device string, system_spec_string, parse it into
+ // system_spec. Verify that the device type is either TPU_SYSTEM or
+ // unassigned, and in the latter case set it to TPU_SYSTEM:0. Having set the
+ // type, verify that the spec matches a unique device in device_set, and
+ // return that device in system_device. The normal use case is for
+ // system_spec_string to identify the TPU_SYSTEM on replica 0, task 0 of the
+ // job that contains the TPU hardware.
+ // TODO(b/110910013): Possibly remove the tpu system device.
+ static Status GetSystemDevice(const string& system_spec_string,
+ const DeviceSet& device_set,
+ DeviceNameUtils::ParsedName* system_spec,
+ Device** system_device);
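+ //
+ // Minimal usage sketch (illustrative only; `device_set` is assumed to be a
+ // populated DeviceSet, e.g. taken from GraphOptimizationPassOptions):
+ //
+ //   DeviceNameUtils::ParsedName system_spec;
+ //   Device* system_device = nullptr;
+ //   TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
+ //       "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", device_set,
+ //       &system_spec, &system_device));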
+
+ // Given a parsed system spec (e.g., the one returned above from
+ // GetSystemDevice), return in host_system_devices the TPU_SYSTEM:0 device on
+ // every host in the spec's job. If the spec does not include an explicit job,
+ // "localhost" is used. Returns an error if system_spec matches devices from
+ // multiple jobs or replicas.
+ static Status GetHostSystemDevices(
+ const DeviceNameUtils::ParsedName& system_spec,
+ const DeviceSet& device_set, std::vector<Device*>* host_system_devices);
+
+ // Given a parsed system spec (e.g., the one returned above from
+ // GetSystemDevice), sets `*tpu_devices` to a per-host vector of the TPU
+ // devices on every host in the spec's job. If the spec does not include an
+ // explicit job, "localhost" is used. Sets `*num_tpus_per_host` to the number
+ // of TPU devices in each host, and verifies that each host in the job has
+ // the same number of TPU devices.
+ // Returns an error if system_spec matches devices from multiple jobs or
+ // replicas.
+ static Status GetTPUDevices(const DeviceNameUtils::ParsedName& system_spec,
+ const DeviceSet& device_set,
+ int* num_tpus_per_host,
+ std::vector<std::vector<Device*>>* tpu_devices);
+
+ // Perform 'action' on every node in 'graph' of type
+ // 'node_type'. This function is designed for use with configuration
+ // Ops that have no inputs or outputs. The arguments passed to 'action' are:
+ // 'configuration_node_def': the NodeDef of the node that matched
+ // 'configuration_device_name': the name of the device that the
+ // matching node is placed on
+ // 'host_devices': the set of TPU_SYSTEM devices on hosts with TPUs that are
+ // in the same system as the node that matched.
+ // 'input_dependencies': the set of nodes that have control edges to
+ // the matching node.
+ // 'output_dependencies': the set of output port, destination node, input port
+ // triples that have edges from the matching node. Input port is
+ // Graph::kControlSlot for a control edge.
+ // 'graph': the graph being mutated.
+ struct OutputDependency {
+ int src_output;
+ Node* dst;
+ int dst_input;
+ };
+ static Status ForConfigurationNodeMatchingType(
+ const string& node_type, Graph* graph, const DeviceSet& device_set,
+ const std::function<
+ Status(const NodeDef& configuration_node_def,
+ const string& configuration_device_name,
+ const std::vector<Device*>& host_devices,
+ const std::vector<Node*>& input_dependencies,
+ const std::vector<OutputDependency>& output_dependencies,
+ Graph* graph)>& action);
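+ //
+ // Minimal usage sketch, mirroring the configuration rewrite pass; the op
+ // name "MyTPUConfigOp" is illustrative only:
+ //
+ //   TF_RETURN_IF_ERROR(
+ //       DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
+ //           "MyTPUConfigOp", graph, device_set,
+ //           [](const NodeDef& node_def, const string& device_name,
+ //              const std::vector<Device*>& host_devices,
+ //              const std::vector<Node*>& input_dependencies,
+ //              const std::vector<
+ //                  DistributedTPURewriteHelpers::OutputDependency>&
+ //                  output_dependencies,
+ //              Graph* graph) -> Status {
+ //             // The matched node has already been removed from the graph;
+ //             // add replacement nodes here and rewire the recorded input
+ //             // and output dependencies to them.
+ //             return Status::OK();
+ //           }));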
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
new file mode 100644
index 0000000..0e5a91c
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -0,0 +1,288 @@
+# TPU Kernel Implementations
+load(
+ "//tensorflow/core/platform:build_config.bzl",
+ "tf_proto_library_cc",
+)
+
+package(
+ default_visibility = [
+ "//tensorflow/core/tpu:__subpackages__",
+ "//tensorflow/stream_executor/tpu:__subpackages__",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "tpu_compile_op_options",
+ srcs = ["tpu_compile_op_options.cc"],
+ hdrs = ["tpu_compile_op_options.h"],
+)
+
+cc_library(
+ name = "tpu_configuration_ops",
+ srcs = ["tpu_configuration_ops.cc"],
+ hdrs = ["tpu_configuration_ops.h"],
+ deps = [
+ ":tpu_mesh_state_interface",
+ "//tensorflow/c:tf_status",
+ "//tensorflow/c:tf_status_helper",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/platform:refcount",
+ "//tensorflow/core/tpu:tpu_config_c_api",
+ "//tensorflow/core/tpu:tpu_configuration",
+ "//tensorflow/core/tpu:tpu_defs",
+ "//tensorflow/core/tpu:tpu_library_loader",
+ "//tensorflow/stream_executor/tpu:proto_helper",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "tpu_compile_c_api_hdrs",
+ hdrs = ["tpu_compile_c_api.h"],
+ deps = [
+ ":tpu_mesh_state_c_api",
+ "//tensorflow/c:tf_datatype",
+ "//tensorflow/stream_executor/tpu:proto_helper",
+ "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
+ ],
+)
+
+tf_proto_library_cc(
+ name = "tpu_executable_info_proto",
+ srcs = ["tpu_executable_info.proto"],
+ cc_api_version = 2,
+ protodeps = [
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/core:protos_all",
+ ],
+)
+
+tf_proto_library_cc(
+ name = "tpu_compile_proto",
+ srcs = ["tpu_compile.proto"],
+ cc_api_version = 2,
+ protodeps = [
+ ":tpu_executable_info_proto",
+ "//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/core:protos_all",
+ "//tensorflow/core/protobuf/tpu:compile_metadata_proto",
+ ],
+)
+
+cc_library(
+ name = "tpu_compilation_cache_key",
+ srcs = [],
+ hdrs = [
+ "tpu_compilation_cache_key.h",
+ ],
+ deps = ["@com_google_absl//absl/types:optional"],
+)
+
+cc_library(
+ name = "tpu_compile_op_support",
+ srcs = ["tpu_compile_op_support.cc"],
+ hdrs = ["tpu_compile_op_support.h"],
+ deps = [
+ ":tpu_compilation_cache_key",
+ ":tpu_compile_c_api_hdrs",
+ ":tpu_compile_proto_cc",
+ ":tpu_executable_info_proto_cc",
+ "//tensorflow/cc:ops",
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:debug_options_flags",
+ "//tensorflow/compiler/xla:shape_tree",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto_cc",
+ "//tensorflow/compiler/xla/client:compile_only_client",
+ "//tensorflow/compiler/xla/service:computation_layout",
+ "//tensorflow/compiler/xla/service:dump",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service:hlo_module_group",
+ "//tensorflow/core:framework",
+ "//tensorflow/core/framework:protos_all_cc",
+ "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+ "//tensorflow/stream_executor/tpu:proto_helper",
+ "//tensorflow/stream_executor/tpu:status_helper",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+ name = "tpu_compilation_cache_entry",
+ hdrs = [
+ "tpu_compilation_cache_entry.h",
+ ],
+ deps = [
+ ":tpu_executable_info_proto_cc",
+ ":tpu_program",
+ "//tensorflow/compiler/xla/service:hlo_proto_cc",
+ "//tensorflow/core/lib/core:refcount",
+ ],
+)
+
+cc_library(
+ name = "tpu_compilation_cache_lookup",
+ srcs = ["tpu_compilation_cache_lookup.cc"],
+ hdrs = [
+ "tpu_compilation_cache_lookup.h",
+ ],
+ deps = [
+ ":tpu_compilation_cache_entry",
+ ":tpu_compilation_cache_external",
+ ":tpu_compilation_cache_proto_cc",
+ "//tensorflow/core/lib/core:refcount",
+ "//tensorflow/core/platform:status",
+ "//tensorflow/core/profiler/lib:traceme",
+ ],
+)
+
+cc_library(
+ name = "tpu_mesh_state_c_api",
+ hdrs = ["tpu_mesh_state_c_api.h"],
+)
+
+cc_library(
+ name = "tpu_mesh_state_interface",
+ srcs = [],
+ hdrs = ["tpu_mesh_state_interface.h"],
+ deps = [
+ ":tpu_compile_c_api_hdrs",
+ ":tpu_mesh_state_c_api",
+ "//tensorflow/compiler/xla/service",
+ "//tensorflow/core:framework",
+ "//tensorflow/core/platform:errors",
+ "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+ "//tensorflow/core/tpu:tpu_config_c_api",
+ ],
+)
+
+cc_library(
+ name = "tpu_program",
+ srcs = ["tpu_program.cc"],
+ hdrs = ["tpu_program.h"],
+ deps = [
+ ":tpu_compile_c_api_hdrs",
+ ":tpu_compile_op_support",
+ ":tpu_compile_proto_cc",
+ ":tpu_executable_info_proto_cc",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:xla_proto_cc",
+ "//tensorflow/compiler/xla/client:compile_only_client",
+ "//tensorflow/compiler/xla/service:computation_placer",
+ "//tensorflow/compiler/xla/service:hlo_module_group",
+ "//tensorflow/compiler/xla/service:hlo_proto_cc",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+ "//tensorflow/stream_executor/tpu:proto_helper",
+ "//tensorflow/stream_executor/tpu:status_helper",
+ "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_library(
+ name = "tpu_compilation_cache_external",
+ srcs = ["tpu_compilation_cache_external.cc"],
+ hdrs = [
+ "tpu_compilation_cache_external.h",
+ ],
+ deps = [
+ ":tpu_compilation_cache_entry",
+ ":tpu_compilation_cache_key",
+ ":tpu_compilation_cache_metrics", # buildcleaner: keep
+ ":tpu_compilation_cache_metrics_hdrs",
+ ":tpu_compilation_cache_proto_cc",
+ ":tpu_compile_c_api_hdrs",
+ ":tpu_compile_op_support",
+ ":tpu_mesh_state_interface",
+ ":tpu_program",
+ ":tpu_util",
+ ":trace_util_hdrs",
+ "//tensorflow/compiler/xla/service",
+ "//tensorflow/compiler/xla/service:hlo_proto_cc",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/platform:refcount",
+ "//tensorflow/core/profiler/lib:traceme",
+ "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+ name = "tpu_compilation_cache_metrics_hdrs",
+ hdrs = ["tpu_compilation_cache_metrics.h"],
+ deps = [
+ "//tensorflow/core/platform:types",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "tpu_compilation_cache_metrics",
+ srcs = ["tpu_compilation_cache_metrics.cc"],
+ deps = [
+ ":tpu_compilation_cache_metrics_hdrs",
+ ],
+)
+
+cc_library(
+ name = "trace_util_hdrs",
+ srcs = [],
+ hdrs = ["trace_util.h"],
+ deps = [
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "tpu_util_hdrs",
+ srcs = [],
+ hdrs = ["tpu_util.h"],
+ deps = [
+ ":tpu_compilation_cache_key",
+ "//tensorflow/cc:ops",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:compile_only_client",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "tpu_util",
+ srcs = ["tpu_util.cc"],
+ hdrs = ["tpu_util.h"],
+ deps = [
+ ":tpu_compilation_cache_key",
+ "//tensorflow/cc:ops",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:compile_only_client",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
+ ],
+ alwayslink = 1,
+)
+
+tf_proto_library_cc(
+ name = "tpu_compilation_cache_proto",
+ srcs = ["tpu_compilation_cache.proto"],
+ cc_api_version = 2,
+)
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto b/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto
new file mode 100644
index 0000000..8308cba
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto
@@ -0,0 +1,25 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+// Target type for compilation cache fetch operation.
+enum CompilationCacheFetchTarget {
+ INVALID = 0;
+ MAIN = 1;
+ SHARDING = 2;
+ UNSHARDING = 3;
+}
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h
new file mode 100644
index 0000000..d16b2d5
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h
@@ -0,0 +1,84 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_H_
+
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_program.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class CompilationCacheEntry {
+ public:
+ explicit CompilationCacheEntry(
+ std::unique_ptr<const TpuProgram> tpu_program)
+ : tpu_program_(std::move(tpu_program)) {}
+
+ // Constructor for an empty entry.
+ CompilationCacheEntry()
+ : tpu_program_(nullptr) {}
+
+ const TPUExecutableInfoProto* get_executable_info() const {
+ return &tpu_program_->executable_info();
+ }
+
+ const TPUHostTransferInfoProto* get_host_transfer_info() const {
+ return &tpu_program_->host_transfer_info();
+ }
+
+ const xla::HloProto* get_hlo_metadata() const {
+ return &tpu_program_->hlo_metadata();
+ }
+
+ // TODO(henrytan,jiawenhao): When should we expect more than one
+ // XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then.
+ const XLA_TpuProgram* get_tpu_program() const {
+ CHECK_EQ(tpu_program_->program_count(), 1);
+ return tpu_program_->tpu_programs()[0];
+ }
+
+ private:
+ std::unique_ptr<const TpuProgram> tpu_program_;
+};
+
+// Base class for a reference to a cached proto. A unique_ptr to a
+// CompilationCacheEntryRef is returned by the cache Lookup methods,
+// and ensures the underlying proto is not garbage-collected until the client
+// discards the ptr.
+class CompilationCacheEntryRef {
+ public:
+ virtual ~CompilationCacheEntryRef() = default;
+
+ // Returns a CompilationCacheEntry that should not be used beyond the lifetime
+ // of the CompilationCacheEntryRef.
+ virtual CompilationCacheEntry get() = 0;
+};
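+
+// Hypothetical client-side usage sketch. The Lookup call below is only
+// illustrative of how a ref is obtained; the concrete lookup API is declared
+// elsewhere, not in this header. The accessors are the ones declared above.
+//
+//   std::unique_ptr<CompilationCacheEntryRef> entry_ref;
+//   TF_RETURN_IF_ERROR(cache->Lookup(cache_key, &entry_ref));  // assumed API
+//   CompilationCacheEntry entry = entry_ref->get();
+//   const TPUExecutableInfoProto* executable = entry.get_executable_info();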
+
+// Base class that holds references to compiled protos so that the protos are
+// not garbage-collected before being used by execute ops. Use
+// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
+// ref holder object.
+class CompilationRefHolder : public ResourceBase {
+ public:
+ ~CompilationRefHolder() override = default;
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc
new file mode 100644
index 0000000..8dbf608
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc
@@ -0,0 +1,791 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
+
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/random.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+#include "tensorflow/core/tpu/kernels/tpu_program.h"
+#include "tensorflow/core/tpu/kernels/tpu_util.h"
+#include "tensorflow/core/tpu/kernels/trace_util.h"
+
+namespace tensorflow {
+namespace tpu {
+
+namespace {
+
+using CompilationEntry = TpuCompilationCacheInterface::CompilationEntry;
+
+int64 get_uid() {
+ uint64 unsigned_rand = random::New64() & INT64_MAX;
+ return static_cast<int64>(unsigned_rand);
+}
+
+void PopulateEntry(const std::string& key, CompilationEntry* entry,
+ std::unique_ptr<TpuProgram> tpu_program) {
+ // Make the unique keys for each cached proto.
+ for (int i = 0; i < tpu_program->program_count(); ++i) {
+ entry->proto_key.push_back(ProtoKeyForComputation(key, i));
+ }
+
+ entry->tpu_program = std::move(tpu_program);
+ entry->initialized = true;
+}
+
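+// Builds the full cache key string: just `prefix` when the key has no
+// guaranteed constants, otherwise
+// "<prefix>|<session_handle>|<guaranteed_const_fingerprint>".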
+std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
+ if (!key.has_guaranteed_const) {
+ return key.prefix;
+ }
+ return absl::StrCat(key.prefix, "|", key.session_handle, "|",
+ key.guaranteed_const_fingerprint());
+}
+
+// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
+// data to compute the fingerprint.
+std::string GuaranteedConstFingerprint(
+ const string& fingerprint_in_metadata,
+ const OpInputList& guaranteed_constants) {
+ if (fingerprint_in_metadata.empty()) {
+ uint64_t fingerprint = 0;
+ for (const auto& constant : guaranteed_constants) {
+ fingerprint = TpuCompile_CreateGuaranteedConstFingerprint(
+ fingerprint, constant.tensor_data().data(),
+ constant.tensor_data().size());
+ }
+ return std::to_string(fingerprint);
+ } else {
+ return fingerprint_in_metadata;
+ }
+}
+
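+// Encodes the dynamic argument shapes as a string, with dimension sizes
+// separated by "," and shapes separated by ";". For example, shapes [2,3]
+// and [4] produce "2,3,;4,;".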
+std::string CreateShapePrefix(
+ const std::vector<tensorflow::TensorShape>& dynamic_shapes) {
+ std::string shapes_prefix;
+ for (const TensorShape& shape : dynamic_shapes) {
+ for (int64 size : shape.dim_sizes()) {
+ absl::StrAppend(&shapes_prefix, size, ",");
+ }
+ absl::StrAppend(&shapes_prefix, ";");
+ }
+ return shapes_prefix;
+}
+
+// Include compilation configurations of the arguments that are not captured
+// by the called graph.
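+// For example, a single DT_FLOAT argument of shape [128,10] that is the same
+// across replicas, with XLA sharding and unrestricted layout disabled, would
+// yield ":s,type(1),shape(128,10,)" (the dtype is rendered as its enum
+// value).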
+std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
+ std::string config_prefix;
+ for (const auto& arg : metadata.args()) {
+ if (arg.is_same_data_across_replicas()) {
+ absl::StrAppend(&config_prefix, ":s");
+ // Same.
+ } else {
+ // Different.
+ absl::StrAppend(&config_prefix, ":");
+ }
+ if (arg.enable_xla_sharding() ==
+ tpu::TPUCompileMetadataProto::Arg::ALLOWED) {
+ // Enabled.
+ absl::StrAppend(&config_prefix, "e");
+ }
+ if (arg.unrestricted_layout()) {
+ // Unrestricted.
+ absl::StrAppend(&config_prefix, ":u");
+ }
+ absl::StrAppend(&config_prefix, ",type(", arg.dtype(), ")");
+ if (arg.has_shape()) {
+ absl::StrAppend(&config_prefix, ",shape(");
+ for (const auto& dim : arg.shape().dim()) {
+ absl::StrAppend(&config_prefix, dim.size(), ",");
+ }
+ absl::StrAppend(&config_prefix, ")");
+ }
+ }
+ return config_prefix;
+}
+
+} // namespace
+
+TpuCompilationCacheInterface::TpuCompilationCacheInterface(
+ int64_t max_cache_size)
+ : max_cache_size_(max_cache_size) {
+ if (max_cache_size < 0) {
+ LOG(FATAL) << "`max_cache_size` value must be greater than or equal to 0";
+ }
+ VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes.";
+}
+
+TpuCompilationCacheInterface::~TpuCompilationCacheInterface() {
+ VLOG(1) << "TpuCompilationCacheInterface::~TpuCompilationCacheInterface()";
+ // A buggy client may be holding onto a reference, or a client might have
+ // crashed while holding onto a reference. In either case, discard all
+ // outstanding client references to avoid leaking storage.
+ for (const auto& entry : entries_by_uid_) {
+ while (entry.second->external_references > 0) {
+ TF_CHECK_OK(Release(entry.first));
+ }
+ }
+ while (!entries_by_last_use_.empty()) {
+ UnloadAndDestroy(MarkOldestEntryForEviction());
+ }
+ // By the time the cache is deleted all reference holders should have already
+ // been deleted, since they were holding references to the cache. So all
+ // entries should be gone at this point.
+ CHECK_EQ(cache_store_.size(), 0);
+ CHECK_EQ(entries_by_uid_.size(), 0);
+ CHECK_EQ(entries_by_proto_key_.size(), 0);
+ CHECK_EQ(cache_size_, 0);
+ CHECK_EQ(marked_for_eviction_size_, 0);
+}
+
+std::string TpuCompilationCacheInterface::FindCacheKey(
+ const TpuCompilationCacheKey& subgraph_key) const {
+ if (!subgraph_key.has_guaranteed_const) {
+ return subgraph_key.prefix;
+ }
+ auto iter = session_key_map_.find(
+ strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
+ if (iter != session_key_map_.end()) {
+ return iter->second;
+ }
+ iter = fingerprint_key_map_.find(strings::StrCat(
+ subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
+ if (iter != fingerprint_key_map_.end()) {
+ return iter->second;
+ }
+ VLOG(1) << "No matching cache key found for key "
+ << ConstructCompilationCacheKey(subgraph_key);
+ return "";
+}
+
+void TpuCompilationCacheInterface::InsertEntry(
+ const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key,
+ CompilationEntry* entry) {
+ entry->parent = this;
+ entry->subgraph_key = cache_key;
+ entry->uid = get_uid();
+ TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
+ entry->cache_entry_debug_string = subgraph_key.prefix;
+ VLOG(1) << "Cache Initializing Entry Session Debug "
+ << entry->cache_entry_debug_string;
+
+ if (!subgraph_key.has_guaranteed_const) {
+ return;
+ }
+ session_key_map_.insert(std::make_pair(
+ strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle),
+ cache_key));
+ fingerprint_key_map_.insert(std::make_pair(
+ strings::StrCat(subgraph_key.prefix,
+ subgraph_key.guaranteed_const_fingerprint()),
+ cache_key));
+}
+
+CompilationEntry* TpuCompilationCacheInterface::InitializeEntry(
+ const string& key,
+ const std::function<Status(TpuProgram*)>& initialize_program,
+ const TpuCompilationCacheKey& subgraph_key) {
+ CompilationEntry* main_entry = new CompilationEntry();
+
+ // Add the entry to the cache, with size zero since there are no compiled
+ // programs in it. Once the subgraph has been compiled,
+ // UpdateEntryAfterCompilation will be called to potentially mark old entries
+ // that don't fit any more for eviction.
+ //
+ // At this point there is one reference to entry, which is owned by the caller
+ // who created the entry. A second reference, owned by the cache, will be
+ // added below since we leave the entry in the 'marked for eviction' state
+ // here.
+ InsertEntry(key, subgraph_key, main_entry);
+
+ // Initialize the programs outside the lock so that other cache operations
+ // can proceed during the (potentially lengthy) initialization.
+ Status initialization_status;
+
+ auto tpu_program = absl::make_unique<TpuProgram>();
+ {
+ mu_.Unlock();
+ {
+ profiler::TraceMe compile_programs_traceme(
+ "TPU compilation cache compile",
+ /*level=*/2);
+ initialization_status = initialize_program(tpu_program.get());
+ }
+ mu_.Lock();
+ }
+
+ main_entry->initialization_status = initialization_status;
+
+ // Add the entry to the uid index.
+ auto uid_inserted = entries_by_uid_.insert(
+ std::pair<int64, CompilationEntry*>(main_entry->uid, main_entry));
+ CHECK(uid_inserted.second);
+
+ if (initialization_status.ok()) {
+ // Compute the entry's total size once all members are initialized.
+ main_entry->total_size = tpu_program->program_size();
+ }
+
+ // TODO(henrytan): handle sharding/unsharding.
+ PopulateEntry(key, main_entry, std::move(tpu_program));
+
+ for (int64 i = 0; i < main_entry->proto_key.size(); ++i) {
+ auto entry_inserted = entries_by_proto_key_.insert(
+ std::pair<string, std::pair<CompilationEntry*, int>>(
+ main_entry->proto_key[i], std::make_pair(main_entry, i)));
+ CHECK(entry_inserted.second);
+ }
+
+ // Add the size to marked_for_eviction_size_ since it will be adjusted down
+ // again when the newly-created entry gets unmarked.
+ marked_for_eviction_size_ += main_entry->total_size;
+ return main_entry;
+}
+
+/*static*/ TpuCompilationCacheKey
+TpuCompilationCacheInterface::CreateCompilationCacheKey(
+ absl::string_view function_name, uint64 function_library_fingerprint,
+ absl::string_view mlir_module,
+ const tensorflow::OpInputList& guaranteed_constants,
+ const std::vector<tensorflow::TensorShape>& dynamic_shapes,
+ const tensorflow::tpu::TPUCompileMetadataProto& metadata,
+ const TpuMeshStateInterface& mesh_state) {
+ VLOG(1) << "FunctionLibraryFingerprint:" << function_library_fingerprint;
+ std::string shapes_prefix = CreateShapePrefix(dynamic_shapes);
+ VLOG(1) << "shapes_prefix = " << shapes_prefix;
+ std::string config_prefix = CreateConfigPrefix(metadata);
+ VLOG(1) << "config_prefix = " << config_prefix;
+ std::vector<int32_t> flattened_device_ids;
+ if (metadata.has_device_assignment()) {
+ for (const auto& device :
+ metadata.device_assignment().computation_devices()) {
+ flattened_device_ids.insert(flattened_device_ids.end(),
+ device.replica_device_ids().begin(),
+ device.replica_device_ids().end());
+ }
+ }
+ // TODO(henrytan): return the debug_string.
+ const char* prefix =
+ TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty{
+ config_prefix.data(),
+ shapes_prefix.data(),
+ function_name.data(),
+ mlir_module.data(),
+ flattened_device_ids.data(),
+ flattened_device_ids.size(),
+ guaranteed_constants.size(),
+ function_library_fingerprint,
+ metadata.num_cores_per_replica(),
+ metadata.num_replicas(),
+ mesh_state.data(),
+ });
+ auto buffer_cleanup = gtl::MakeCleanup([prefix]() { delete[] prefix; });
+ TpuCompilationCacheKey key;
+ key.prefix = prefix;
+
+ // Guaranteed constants can be different across sessions. Use session_handle
+ // and guaranteed_const fingerprint to guarantee no collision.
+ if (guaranteed_constants.size() > 0) {
+ key.has_guaranteed_const = true;
+ key.session_handle = metadata.session_handle();
+ // Both `metadata` and `guaranteed_constants` are captured by reference,
+ // on the assumption that their lifetimes are managed by the
+ // `TPUCompileOpKernelImpl`, which outlives the compilation cache lookups.
+ string fingerprint;
+ key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
+ fingerprint]() mutable {
+ if (fingerprint.empty()) {
+ fingerprint = GuaranteedConstFingerprint(
+ metadata.guaranteed_const_fingerprint(), guaranteed_constants);
+ }
+ return fingerprint;
+ };
+ }
+ return key;
+}
+
+TpuCompilationRefHolder* TpuCompilationCacheInterface::MakePerStepRefHolder() {
+ return new RefHolder(this);
+}
+
+Status TpuCompilationCacheInterface::MarkEntryForEviction(int64 subgraph_uid) {
+ profiler::TraceMe key_release_traceme(
+ "TPU compilation cache possibly evict uid",
+ /*level=*/2);
+ CompilationEntry* deleted_entry = nullptr;
+ {
+ absl::MutexLock lock(&mu_);
+ auto iter = entries_by_uid_.find(subgraph_uid);
+ if (iter == entries_by_uid_.end()) {
+ // If already evicted, return ok.
+ return Status::OK();
+ }
+
+ // Mark entry for eviction.
+ CompilationEntry* subgraph_to_evict = iter->second;
+ // If there are external references, this API should not be used.
+ if (subgraph_to_evict->external_references != 0) {
+ return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
+ " external_references greater than zero. Should "
+ "use TpuCompilationCache::Release.");
+ }
+
+ VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction";
+ entries_by_last_use_.erase(subgraph_to_evict->last_use);
+ cache_size_ -= subgraph_to_evict->total_size;
+ marked_for_eviction_size_ += subgraph_to_evict->total_size;
+
+ // Evict if the refcount is exactly one; otherwise only discard the cache's
+ // reference to the entry. The actual eviction happens when the ref holder's
+ // references go away.
+ deleted_entry = DiscardEntryRef(subgraph_to_evict);
+
+ VLOG(1) << "After possibly evicting entry " << subgraph_uid
+ << " refs cache is " << cache_store_.size() << " entries ("
+ << cache_size_ + marked_for_eviction_size_
+ << " bytes), marked for eviction "
+ << (cache_store_.size() - entries_by_last_use_.size())
+ << " entries (" << marked_for_eviction_size_ << " bytes).";
+ }
+
+ // Unload from device cache if entry is evicted from host cache.
+ UnloadAndDestroy(deleted_entry);
+ return Status::OK();
+}
+
+Status TpuCompilationCacheInterface::Release(int64 subgraph_uid) {
+ profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
+ /*level=*/2);
+
+ CompilationEntry* deleted_entry = nullptr;
+ {
+ absl::MutexLock lock(&mu_);
+ auto iter = entries_by_uid_.find(subgraph_uid);
+
+ if (iter == entries_by_uid_.end()) {
+ return errors::NotFound("No cache entry found for uid ", subgraph_uid);
+ }
+
+ CHECK_GT(iter->second->external_references, 0);
+ --iter->second->external_references;
+
+ deleted_entry = DiscardEntryRef(iter->second);
+
+ VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
+ << cache_store_.size() << " entries ("
+ << cache_size_ + marked_for_eviction_size_
+ << " bytes), marked for eviction "
+ << (cache_store_.size() - entries_by_last_use_.size())
+ << " entries (" << marked_for_eviction_size_ << " bytes).";
+ }
+ UnloadAndDestroy(deleted_entry);
+ return Status::OK();
+}
+
+void TpuCompilationCacheInterface::UnloadAndDestroy(CompilationEntry* entry) {
+ if (!entry) return;
+
+ CHECK(entry->RefCountIsOne());
+ entry->tpu_program->UnloadAndDestroyPrograms();
+ entry->Unref();
+}
+
+size_t TpuCompilationCacheInterface::RemoveEntry(const string& key) {
+ auto erased = cache_store_.erase(key);
+ TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
+ auto parsed_key_or_status = ParseCompilationCacheKey(key);
+ CHECK(parsed_key_or_status.status().ok());
+ const TpuCompilationCacheKey parsed_key =
+ parsed_key_or_status.ConsumeValueOrDie();
+ if (!parsed_key.has_guaranteed_const) {
+ return erased;
+ }
+ session_key_map_.erase(
+ strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
+ fingerprint_key_map_.erase(strings::StrCat(
+ parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
+ return erased;
+}
+
+ABSL_MUST_USE_RESULT CompilationEntry*
+TpuCompilationCacheInterface::DiscardEntryRef(CompilationEntry* entry) {
+ if (entry->RefCountIsOne()) {
+ // The last reference to this entry is going away, so really delete it from
+ // the cache in such a way that it can't be restored by being looked up
+ // again.
+
+ // Sanity-check that it has been marked for eviction.
+ CHECK(entries_by_last_use_.find(entry->last_use) ==
+ entries_by_last_use_.end());
+ // Update the counter tracking how much space is taken up by entries that
+ // are marked for eviction.
+ marked_for_eviction_size_ -= entry->total_size;
+
+ // Remove the entry from the cache.
+ auto erased = RemoveEntry(entry->subgraph_key);
+
+ if (erased == 0) {
+ LOG(FATAL) << "Tried to discard nonexistent cache entry";
+ }
+ erased = entries_by_uid_.erase(entry->uid);
+ CHECK_EQ(erased, 1);
+ for (const string& key : entry->proto_key) {
+ erased = entries_by_proto_key_.erase(key);
+ CHECK_EQ(erased, 1);
+ }
+ // The actual deletion will happen outside the lock in UnloadAndDestroy().
+ return entry;
+ }
+ entry->Unref();
+ return nullptr;
+}
+
+void TpuCompilationCacheInterface::DiscardEntryRefs(
+ gtl::ArraySlice<CompilationEntry*> entries) {
+ std::vector<CompilationEntry*> removed_entries;
+ {
+ absl::MutexLock lock(&mu_);
+
+ for (auto entry : entries) {
+ removed_entries.push_back(DiscardEntryRef(entry));
+ }
+
+ VLOG(1) << "After discarding entry refs cache is " << cache_store_.size()
+ << " entries (" << cache_size_ + marked_for_eviction_size_
+ << " bytes), marked for eviction "
+ << (cache_store_.size() - entries_by_last_use_.size())
+ << " entries (" << marked_for_eviction_size_ << " bytes).";
+ }
+ for (auto removed_entry : removed_entries) {
+ UnloadAndDestroy(removed_entry);
+ }
+}
+
+ABSL_MUST_USE_RESULT CompilationEntry*
+TpuCompilationCacheInterface::MarkOldestEntryForEviction() {
+ CompilationEntry* entry_to_mark = entries_by_last_use_.begin()->second;
+ VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction";
+ entries_by_last_use_.erase(entry_to_mark->last_use);
+ cache_size_ -= entry_to_mark->total_size;
+ marked_for_eviction_size_ += entry_to_mark->total_size;
+ // Discard the cache's reference to entry. If steps are holding onto
+ // references to entry it won't be deleted until the last step holding it
+ // completes. It stays in the cache in the meantime and can be resurrected
+ // by a call to CompileIfKeyAbsent if that occurs before the last reference
+ // expires.
+ return DiscardEntryRef(entry_to_mark);
+}
+
+void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
+ CompilationEntry* entry, std::vector<CompilationEntry*>* removed_entries) {
+ // The entry was previously marked for eviction (or is newly created) so
+ // unmark it. Add a reference (owned by the cache), update the cache size, and
+ // mark something old for eviction if necessary.
+ entry->Ref();
+ marked_for_eviction_size_ -= entry->total_size;
+ cache_size_ += entry->total_size;
+
+ // Mark the least-recently-used non-marked entry for eviction. Never mark the
+ // most-recently-used entry (i.e., do nothing if entries_by_last_use_.size()
+ // == 1, which means there is only one entry not already marked for eviction),
+ // so that an entry persists in the cache even if it is larger than the
+ // allocated cache size.
+ while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
+ if (auto entry_to_evict = MarkOldestEntryForEviction()) {
+ removed_entries->push_back(entry_to_evict);
+ }
+ }
+}
+
+Status TpuCompilationCacheInterface::ToSubEntryRef(
+ CompilationCacheEntryRef* entry,
+ CompilationCacheFetchTarget fetch_target) const {
+ return static_cast<EntryRefImpl*>(entry)->ToSubEntryRef(fetch_target);
+}
+
+TpuCompilationCacheInterface::EntryRefImpl::EntryRefImpl(
+ TpuCompilationCacheInterface* parent, CompilationEntry* entry, int index)
+ : parent_(parent), entry_(entry), index_(index) {
+ if (entry_ == nullptr) {
+ return;
+ }
+ if (entry_->main_entry == nullptr) {
+ entry_->Ref();
+ } else {
+ // This is a sharding/unsharding entry nested in a main entry. Only refcount
+ // the main entry.
+ entry_->main_entry->Ref();
+ }
+}
+
+TpuCompilationCacheInterface::EntryRefImpl::~EntryRefImpl() {
+ if (entry_ == nullptr) {
+ return;
+ }
+ if (entry_->main_entry == nullptr) {
+ parent_->DiscardEntryRefs({entry_});
+ } else {
+ parent_->DiscardEntryRefs({entry_->main_entry});
+ }
+}
+
+CompilationCacheEntry TpuCompilationCacheInterface::EntryRefImpl::get() {
+ if (entry_ == nullptr) {
+ // Create an empty entry if the entry is nullptr. This corresponds to
+ // non-existing sharding/unsharding entries.
+ return CompilationCacheEntry();
+ }
+ return CompilationCacheEntry(std::move(entry_->tpu_program));
+}
+
+Status TpuCompilationCacheInterface::EntryRefImpl::ToSubEntryRef(
+ CompilationCacheFetchTarget fetch_target) {
+ CompilationEntry* target = nullptr;
+ switch (fetch_target) {
+ case CompilationCacheFetchTarget::MAIN:
+ target = entry_;
+ break;
+ case CompilationCacheFetchTarget::SHARDING:
+ target = entry_->sharding_entry.get();
+ break;
+ case CompilationCacheFetchTarget::UNSHARDING:
+ target = entry_->unsharding_entry.get();
+ break;
+ default:
+ return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
+ }
+
+ if (target == nullptr) {
+ // The cache entry does not have the requested subentry. Unref and replace
+ // with nullptr.
+ parent_->DiscardEntryRefs({entry_});
+ }
+ // Otherwise, since the refcount is always on the main entry, we don't need
+ // ref/unref.
+ entry_ = target;
+ return Status::OK();
+}
+
+Status TpuCompilationCacheInterface::Lookup(
+ int64 uid, int proto_index,
+ std::unique_ptr<CompilationCacheEntryRef>* entry) {
+ entry->reset();
+
+ profiler::TraceMe proto_lookup_traceme(
+ "TPU compilation cache proto lookup by uid",
+ /*level=*/2);
+
+ absl::MutexLock lock(&mu_);
+ const auto iter = entries_by_uid_.find(uid);
+ if (iter == entries_by_uid_.end()) {
+ return errors::NotFound("No subgraph found for uid ", uid);
+ }
+ CompilationEntry* cache_entry = iter->second;
+ if (proto_index < 0 ||
+ proto_index >= cache_entry->tpu_program->program_size()) {
+ return errors::NotFound("No proto found for core index ", proto_index,
+ " in subgraph with uid ", uid);
+ }
+ *entry = std::unique_ptr<CompilationCacheEntryRef>(
+ new EntryRefImpl(this, cache_entry, proto_index));
+ return Status::OK();
+}
+
+Status TpuCompilationCacheInterface::Lookup(
+ const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry) {
+ entry->reset();
+
+ profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
+ /*level=*/2);
+
+ absl::MutexLock lock(&mu_);
+ const auto iter = entries_by_proto_key_.find(proto_key);
+ if (iter == entries_by_proto_key_.end()) {
+ return errors::NotFound("No proto found for key ", proto_key);
+ }
+ CompilationEntry* cache_entry = iter->second.first;
+ int proto_index = iter->second.second;
+ *entry = std::unique_ptr<CompilationCacheEntryRef>(
+ new EntryRefImpl(this, cache_entry, proto_index));
+ return Status::OK();
+}
+
+Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
+ const TpuCompilationCacheKey& subgraph_key,
+ const SessionMetadata* session_metadata,
+ TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
+ std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+ std::vector<CompilationEntry*>* removed_entries,
+ std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
+ const std::function<Status(TpuProgram*)>& compile_function) {
+ profiler::TraceMe subgraph_lookup_traceme(
+ "TPU compilation cache subgraph lookup",
+ /*level=*/2);
+
+ // NOTE: In spite of the fact that we use MutexLock, we do not hold the lock
+ // for the lifetime of the object, see InitializeEntry() call below.
+ absl::MutexLock lock(&mu_);
+
+ std::string cache_key = FindCacheKey(subgraph_key);
+ auto iter = cache_store_.find(cache_key);
+ bool is_new_key = iter == cache_store_.end();
+
+ const std::string session_name = SessionNameFromMetadata(session_metadata);
+
+ CompilationEntry* entry = nullptr;
+ if (is_new_key) {
+ cache_key = ConstructCompilationCacheKey(subgraph_key);
+ TpuCompilationCacheMetrics::IncrementCacheLookupCount(
+ /*is_cache_hit=*/false, session_name);
+ const string msg =
+ strings::StrCat("TPU host compilation cache miss: cache_key(",
+ cache_key, "), session_name(", session_name, ")");
+
+ TRACESTRING(msg);
+ LOG(INFO) << msg;
+
+ // Check if caller has disabled compilation. Set using
+ // internal::ScopedTpuCompileDisabler.
+ if (!IsTpuCompilationEnabled()) {
+ const string error_msg = strings::StrCat(
+ "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
+ "disabled, session_name(",
+ session_name, ") Debug String: ", subgraph_key.debug_string);
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Cache Missed. Current cache entries: ";
+ for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) {
+ // TODO(henrytan): add DebugKey as cache_entry_debug_string to
+ // TpuCompilationCacheKey.
+ VLOG(2) << "Cache Debug Info: ";
+ VLOG(2) << it->second->cache_entry_debug_string;
+ }
+ }
+
+ LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
+ return errors::NotFound(error_msg);
+ }
+
+ // The single ref on the newly-created entry is owned by the caller.
+ VLOG(1) << "Before adding new entry for key " << cache_key
+ << " with session_name( " << session_name << ");"
+ << "; cache is " << cache_store_.size() << " entries ("
+ << cache_size_ + marked_for_eviction_size_ << " bytes), "
+ << " marked for eviction "
+ << (cache_store_.size() - entries_by_last_use_.size())
+ << " entries (" << marked_for_eviction_size_ << " bytes).";
+ // Note that InitializeEntry() will Release/Reacquire mu_.
+ entry = InitializeEntry(cache_key, compile_function, subgraph_key);
+ TRACELITERAL("TPU host compilation cache: compilation done.");
+
+ LOG(INFO) << strings::StrCat(
+ "TPU host compilation cache: compilation done for cache_key(",
+ cache_key, "), session_name(", session_name, ")");
+ // If session_name is present, log some additional stats related to HBM
+ // here, so that they can be associated directly to the session.
+ if (!session_name.empty()) {
+ entry->tpu_program->LogProgramMemorySummary();
+ }
+ } else {
+ TpuCompilationCacheMetrics::IncrementCacheLookupCount(true, session_name);
+ const string msg =
+ strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
+ "), session_name(", session_name, ")");
+ TRACESTRING(msg);
+ VLOG(1) << msg;
+ VLOG(1) << "Before refreshing entry for key " << cache_key
+ << " with session_name( " << session_name << "); cache is "
+ << cache_store_.size() << " entries ("
+ << cache_size_ + marked_for_eviction_size_ << " bytes), "
+ << " marked for eviction "
+ << (cache_store_.size() - entries_by_last_use_.size())
+ << " entries (" << marked_for_eviction_size_ << " bytes).";
+ entry = iter->second;
+ // Make a new reference that is owned by the caller.
+ entry->Ref();
+ // Block if necessary until the subgraph has been initialized.
+ mu_.Await(absl::Condition(
+ +[](CompilationEntry* e) { return e->initialized; }, entry));
+ }
+
+ // Let the caller know the uid of the entry.
+ *uid = entry->uid;
+ // Let the caller know the keys for each of the cached protos.
+ *proto_key = entry->proto_key;
+ *may_modify_variables = entry->tpu_program->may_modify_variables();
+ *hlo_metadata = entry->hlo_metadata;
+
+ // If the caller didn't supply a per_step_ref_holder then the caller is going
+ // to manually release the reference later via a call to Release().
+ if (per_step_ref_holder == nullptr) {
+ ++entry->external_references;
+ } else {
+ // The caller wants its reference to be handed off to a per-step holder that
+ // will discard the reference when the step completes.
+ RefHolder* cast_ref_holder = static_cast<RefHolder*>(per_step_ref_holder);
+ TF_RET_CHECK(cast_ref_holder != nullptr);
+ cast_ref_holder->AddRef(entry);
+ }
+
+ // Remove the old LRU-table entry if it wasn't already marked for eviction.
+ auto erased = entries_by_last_use_.erase(entry->last_use);
+ // Update the LRU table indicating this entry is the most recently used.
+ entry->last_use = use_counter_++;
+ entries_by_last_use_[entry->last_use] = entry;
+ if (erased == 0) {
+ // The entry had been marked for eviction, or is newly created.
+ LookupEntryMarkedForEviction(entry, removed_entries);
+ }
+
+ // Log a little more verbosely when a key is added.
+ if (VLOG_IS_ON(1) || is_new_key) {
+ LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
+ << " entry for key " << cache_key << " with session_name "
+ << session_name << " cache is " << cache_store_.size()
+ << " entries (" << cache_size_ + marked_for_eviction_size_
+ << " bytes), "
+ << " marked for eviction "
+ << (cache_store_.size() - entries_by_last_use_.size())
+ << " entries (" << marked_for_eviction_size_ << " bytes).";
+ }
+ return entry->initialization_status;
+}
+
+tensorflow::Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
+ const TpuCompilationCacheKey& cache_key,
+ const tensorflow::SessionMetadata* session_metadata,
+ TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
+ std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+ std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
+ const std::function<tensorflow::Status(TpuProgram*)>& compile_function) {
+ std::vector<CompilationEntry*> removed_entries;
+ auto status = CompileIfKeyAbsentHelper(
+ cache_key, session_metadata, per_step_ref_holder, uid, proto_key,
+ may_modify_variables, &removed_entries, hlo_metadata, compile_function);
+ for (auto entry : removed_entries) {
+ UnloadAndDestroy(entry);
+ }
+ return status;
+}
+
+} // namespace tpu
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
new file mode 100644
index 0000000..b6cdbe9
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
@@ -0,0 +1,394 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_EXTERNAL_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_EXTERNAL_H_
+
+#include <functional>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/refcount.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
+#include "tensorflow/core/tpu/kernels/tpu_program.h"
+
+namespace tensorflow {
+namespace tpu {
+
+const char kCompilationCacheResourceName[] = "tpu_compilation_cache";
+const char kCompilationCacheUnloaderResourceName[] =
+ "tpu_compilation_cache_unloader";
+
+// Base class that holds references to compiled protos so that the protos are
+// not garbage-collected before being used by execute ops. Use
+// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
+// ref holder object.
+class TpuCompilationRefHolder : public ResourceBase {
+ public:
+ ~TpuCompilationRefHolder() override = default;
+};
+
+class TpuCompilationCacheInterface : public ResourceBase {
+ public:
+ using Status = ::stream_executor::port::Status;
+
+ // An entry in the compilation cache. The entry is deleted once it has been
+ // marked for eviction from the cache _and_ all steps that use it have
+ // completed. When the entry is first created, it is uninitialized and a
+ // client-supplied compilation function is run outside the cache's lock to
+ // generate the programs to be stored in the entry. Any other client that
+ // requests the entry will block until it has been initialized. Each entry has
+ // a last_use value that is set from a monotonically-increasing counter in the
+ // cache whenever the entry is referenced. When the cache becomes full,
+ // entries are marked for eviction in LRU order.
+ //
+ // The bridge can request XLA to generate separate sharding and unsharding
+ // programs along with the main program; we use nested fields sharding_entry,
+ // unsharding_entry to store them under the main entry, and these two fields
+ // must be either both present or both absent. They have a back pointer
+ // main_entry to refer to the main program. These nested entries share the
+ // same cache key and the same lifetime as the main entry, so we use the
+ // refcount on the main entry to track the access to any of them.
+ // /-------------------------------\
+ // v \
+ // main_entry (refcount) -> sharding_entry -> main_entry
+ // ^ \
+ // | \-> unsharding_entry -> main_entry
+ // \--------------------------------------/
+ struct CompilationEntry : public core::RefCounted {
+ TpuCompilationCacheInterface* parent = nullptr; // Not owned.
+ bool initialized = false;
+
+ // The Status returned by the compilation function when the entry is
+ // initialized. This status will be returned to any client that requests the
+ // entry.
+ Status initialization_status;
+
+ // The uid describing this entry.
+ int64 uid;
+ std::vector<string> proto_key;
+
+ // Counter to keep track of LRU entries for the eviction policy.
+ int64 last_use = -1;
+
+ // The unique key describing this entry.
+ std::string subgraph_key;
+
+ // Entries representing the associated sharding and unsharding programs,
+ // which share the same lifetime as the owning main entry, so we always use
+ // the main entry's ref count.
+ std::unique_ptr<CompilationEntry> sharding_entry;
+ std::unique_ptr<CompilationEntry> unsharding_entry;
+
+ // The number of 'external' client-held references to the entry.
+ int external_references = 0;
+
+ std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadata;
+
+ // The sum of the SpaceUsed of each of the elements of programs; an estimate
+ // of how much RAM the entry consumes, used to determine when entries must
+ // be marked for eviction.
+ int64 total_size = 0;
+
+ // Only used for the nested sharding/unsharding entries to point to the
+ // owning main entry.
+ CompilationEntry* main_entry = nullptr;
+
+ // Debug info logged when the cache misses.
+ string cache_entry_debug_string;
+
+ // Compiled TPU program.
+ std::unique_ptr<TpuProgram> tpu_program;
+ };
+
+ explicit TpuCompilationCacheInterface(int64_t max_cache_size);
+ ~TpuCompilationCacheInterface() override;
+ TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete;
+ TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&)
+ = delete;
+
+ Status CompileIfKeyAbsent(
+ const TpuCompilationCacheKey& cache_key,
+ const SessionMetadata* session_metadata,
+ TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
+ std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+ std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
+ const std::function<tensorflow::Status(TpuProgram*)>& compile_function);
+
+ static TpuCompilationCacheKey CreateCompilationCacheKey(
+ absl::string_view function_name, uint64 function_library_fingerprint,
+ absl::string_view mlir_module,
+ const tensorflow::OpInputList& guaranteed_constants,
+ const std::vector<tensorflow::TensorShape>& dynamic_shapes,
+ const tensorflow::tpu::TPUCompileMetadataProto& metadata,
+ const TpuMeshStateInterface& mesh_state);
+
+ string DebugString() const override { return "TpuCompilationCacheInterface"; }
+
+ // Makes a reference holder for this cache that can be stored in the per-step
+ // resource manager and will ensure that compiled entries persist until the
+ // end of a step.
+ TpuCompilationRefHolder* MakePerStepRefHolder();
+
+ // Differences between MarkEntryForEviction and Release:
+ // There are two modes of managing cache entries:
+ // 1) LRU eviction + pinning; 2) manual.
+ // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
+ // Otherwise it is manual mode (mainly used by XRT).
+ // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
+ // entries when callers know that they do not need them anymore.
+ // Release should only be used in mode 2) to explicitly remove an entry.
+
+ // Mark the entry indexed by `subgraph_uid` for eviction. This should only be
+ // called if per_step_ref_holder was NOT nullptr in the corresponding call to
+ // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
+ // subgraph_uid).
+ Status MarkEntryForEviction(int64 subgraph_uid);
+
+ // Manually discards a reference to the compiled subgraph. This should only be
+ // called if per_step_ref_holder was nullptr in the corresponding call to
+ // CompileIfKeyAbsent(subgraph_key, ...).
+ Status Release(int64 subgraph_uid);
+
+ // Looks up an executable corresponding to the model-parallel core index of
+ // the subgraph represented by key. On success a pointer to an EntryRef
+ // holding the program is returned in entry.
+ Status Lookup(const string& proto_key,
+ std::unique_ptr<CompilationCacheEntryRef>* entry);
+
+ // Looks up an executable corresponding to the model-parallel core index of
+ // the subgraph represented by uid. On success a pointer to an EntryRef
+ // holding the program is returned in entry.
+ Status Lookup(int64 uid, int proto_index,
+ std::unique_ptr<CompilationCacheEntryRef>* entry);
+
+ // Mutates the main entry ref to point to the entry's subentry
+ // (for sharding/unsharding) or main entry (unchanged) representing the
+ // fetch target. The entry ref needs to point to the main entry before this
+ // call.
+ //
+ // If the requested subentry does not exist, the ref will point to a nullptr
+ // entry.
+ Status ToSubEntryRef(CompilationCacheEntryRef* entry,
+ CompilationCacheFetchTarget fetch_target) const;
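+
+ // Illustrative usage sketch in manual mode (assumes `cache` points to a live
+ // TpuCompilationCacheInterface, `key` is a TpuCompilationCacheKey, and
+ // `compile_fn` is a std::function<Status(TpuProgram*)> that fills in the
+ // program):
+ //
+ // int64 uid;
+ // std::vector<string> proto_keys;
+ // std::vector<bool> may_modify_variables;
+ // std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadata;
+ // TF_RETURN_IF_ERROR(cache->CompileIfKeyAbsent(
+ // key, /*session_metadata=*/nullptr, /*per_step_ref_holder=*/nullptr,
+ // &uid, &proto_keys, &may_modify_variables, &hlo_metadata, compile_fn));
+ // std::unique_ptr<CompilationCacheEntryRef> entry;
+ // TF_RETURN_IF_ERROR(cache->Lookup(uid, /*proto_index=*/0, &entry));
+ // // ... use entry->get() while the reference is held ...
+ // entry.reset();
+ // // Without a per-step ref holder, the compile reference must be released
+ // // explicitly.
+ // TF_RETURN_IF_ERROR(cache->Release(uid));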
+
+ private:
+ // Wrapper for a cache entry that holds a reference to the entry until the
+ // wrapper is deleted. This wrapper is the concrete type of
+ // CompilationCacheEntryRef returned by Lookup.
+ class EntryRefImpl : public CompilationCacheEntryRef {
+ public:
+ EntryRefImpl(TpuCompilationCacheInterface* parent, CompilationEntry* entry,
+ int index);
+ ~EntryRefImpl() override;
+
+ CompilationCacheEntry get() override;
+
+ // Mutates this ref to point to the entry's subentry (for
+ // sharding/unsharding) or main entry (unchanged) as specified by
+ // fetch_target. The refcount is kept unchanged, since we only track the
+ // refcount of the main entry. The entry ref needs to point to the main
+ // entry before this call.
+ //
+ // If the requested subentry does not exist, the ref will point to a nullptr
+ // entry, and the original entry will be unref'ed.
+ Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);
+
+ private:
+ TpuCompilationCacheInterface* parent_; // Not owned.
+ // A reference to entry_ is acquired in the constructor and released via
+ // parent->DiscardEntryRefs in the destructor.
+ CompilationEntry* entry_;
+ // The program in entry_ that is returned by the get method.
+ int index_;
+ };
+
+ // Private implementation of the generic CompilationRefHolder that knows about
+ // CompilationEntry objects.
+ class RefHolder : public TpuCompilationRefHolder {
+ public:
+ explicit RefHolder(TpuCompilationCacheInterface* parent) : parent_(parent) {
+ parent_->Ref();
+ }
+ ~RefHolder() override {
+ // Release the references to the entries, then our reference to the parent.
+ parent_->DiscardEntryRefs(entries_);
+ parent_->Unref();
+ }
+
+ // Adds entry to the list of entries that will be released when the
+ // RefHolder is destroyed. Each entry is released via a call to
+ // parent_->DiscardEntryRefs.
+ void AddRef(CompilationEntry* entry) {
+ entries_.push_back(entry);
+ }
+
+ string DebugString() const override {
+ return "TpuCompilationCacheInterface::RefHolder";
+ }
+
+ private:
+ TpuCompilationCacheInterface* parent_; // Not owned.
+ std::vector<CompilationEntry*> entries_;
+ };
+
+ // The bulk of the implementation of CompileIfKeyAbsent(), with the exception
+ // of unloading programs that correspond to possibly removed cache
+ // entries. The split helps to manage locking since we prefer to perform
+ // unloading without holding extra locks.
+ Status CompileIfKeyAbsentHelper(
+ const TpuCompilationCacheKey& subgraph_key,
+ const SessionMetadata* session_metadata,
+ TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
+ std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+ std::vector<CompilationEntry*>* removed_entries,
+ std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
+ const std::function<Status(TpuProgram*)>& compile_function);
+
+ // This is called by the cache when entry is marked for eviction; by
+ // a RefHolder (via DiscardEntryRefs) when a step completes; and by
+ // an EntryRefImpl when it is destroyed. Releases one reference to entry
+ // if more than 1 remains. If only one reference is left, the entry is removed
+ // from cache_store_ and is returned to the caller, which must eventually call
+ // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
+ // to avoid holding the lock during program unloading.
+ ABSL_MUST_USE_RESULT CompilationEntry* DiscardEntryRef(
+ CompilationEntry* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Convenience method called by ~RefHolder without mu_ held. Calls
+ // DiscardEntryRef on every element of entries.
+ void DiscardEntryRefs(
+ gtl::ArraySlice<CompilationEntry*> entries);
+
+ // Marks the oldest unmarked entry for eviction. Requires that there is at
+ // least one such entry. In case the evicted entry had only 1 reference it
+ // is removed from the cache and returned to the caller which must eventually
+ // call UnloadAndDestroy.
+ CompilationEntry* MarkOldestEntryForEviction()
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Updates data structures to indicate that entry, which had been marked for
+ // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
+ // entry is newly created, or an entry that has been marked for eviction but
+ // not yet evicted is looked up.
+ //
+ // First the entry is unmarked for eviction, i.e. the cache gains a reference
+ // to entry, entry's last_use field is set to be the most recent value of
+ // use_counter_ and entries_by_last_use_ is updated accordingly.
+ //
+ // Next, the size of the cache is examined to see if any other entries need to
+ // be marked for eviction now that entry has been unmarked. While the total
+ // size of unmarked cached entries is greater than max_cache_size_, entries
+ // are marked for eviction in LRU order. The most recently used entry is never
+ // marked for eviction, so an entry larger than the max cache size will remain
+ // in the cache until it is replaced by something else. In case some entries
+ // actually were removed from the cache, they are returned to the caller via
+ // removed_entries. The caller must eventually delete them by calling
+ // UnloadAndDestroy.
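+ //
+ // For example, with max_cache_size_ = 100 and unmarked entries of sizes 40,
+ // 35, and 50 (oldest first), the unmarked total is 125 > 100, so the oldest
+ // entry (40 bytes) is marked for eviction, leaving 85 bytes unmarked.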
+ void LookupEntryMarkedForEviction(
+ CompilationEntry* entry, std::vector<CompilationEntry*>* removed_entries)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Removes the entry with given key from cache.
+ size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Inserts the given key and entry to cache.
+ void InsertEntry(const std::string& key,
+ const TpuCompilationCacheKey& subgraph_key,
+ CompilationEntry* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns the cache key matching given subgraph_key.
+ std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Creates a new entry by running initialize_program and places it in the
+ // cache to be looked up by key. The new entry is in the 'marked for eviction'
+ // state (not present in entries_by_last_use_) and the caller is expected to
+ // call LookupEntryMarkedForEviction after InitializeEntry.
+ //
+ // **InitializeEntry releases mu_ during the call to initialize_program.**
+ CompilationEntry* InitializeEntry(
+ const string& key,
+ const std::function<Status(TpuProgram*)>& initialize_program,
+ const TpuCompilationCacheKey& subgraph_key)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Unloads the program associated with the entry from all local devices
+ // and deletes the entry itself. It is assumed no one else has a reference
+ // to it and all related keys had already been removed from the cache.
+ // The call can perform device IO so no locks should be held while calling it.
+ void UnloadAndDestroy(CompilationEntry* entry) ABSL_LOCKS_EXCLUDED(mu_);
+
+ // The maximum size of entries that are stored in the cache before entries are
+ // marked for eviction.
+ const int64 max_cache_size_;
+
+ mutable absl::Mutex mu_;
+ // The total size of entries that are stored and not marked for eviction.
+ int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
+
+ // The total size of entries that are marked for eviction.
+ int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
+
+ // The value to assign to the last_use field of the next entry that is looked
+ // up.
+ int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
+
+ // session_key_map_ and fingerprint_key_map_ are used for looking up the
+ // cache_store_ key matching a given subgraph key. When doing a lookup, check
+ // session_key_map_ first to avoid unnecessary fingerprint computation.
+ // Map from key prefix + session_handle to a cache_store_ key.
+ std::unordered_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
+
+ // Map from key prefix + fingerprint to a cache_store_ key.
+ std::unordered_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
+
+ // All the subgraph entries that can be looked up in the cache. An entry is
+ // marked for eviction iff it is present in cache_store_ and not in
+ // entries_by_last_use_.
+ std::unordered_map<string, CompilationEntry*> cache_store_
+ ABSL_GUARDED_BY(mu_);
+
+ // All the subgraph entries that can be looked up in the cache, indexed by
+ // uid.
+ absl::node_hash_map<int64, CompilationEntry*> entries_by_uid_
+ ABSL_GUARDED_BY(mu_);
+
+ // All the protos that can be looked up in the cache, indexed by proto
+ // key. The value of the map is a subgraph and the index of the proto compiled
+ // for that subgraph.
+ std::unordered_map<string, std::pair<CompilationEntry*, int>>
+ entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
+
+ // Map from last_use to entry, used to mark entries for eviction in LRU
+ // order. If an entry's last_use counter is not present as a key in
+ // entries_by_last_use_ then the entry has been marked for eviction.
+ std::map<int64, CompilationEntry*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_EXTERNAL_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h
new file mode 100644
index 0000000..49c2eb6
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_
+
+#include <functional>
+#include <string>
+
+#include "absl/types/optional.h"
+
+namespace tensorflow {
+namespace tpu {
+
+struct TpuCompilationCacheKey {
+ // Prefix of the key.
+ std::string prefix;
+
+ // A boolean flag to specify if `guaranteed_const` is used. Guaranteed const is
+ // normally used in TPU inference to avoid re-copying unchanged variables onto
+ // the TPU device. It promises the value is identical for every execution in
+ // the same session even if the actual value changes in later executions.
+ bool has_guaranteed_const = false;
+
+ // Unique session identifier. It is set when `has_guaranteed_const` is true.
+ std::string session_handle;
+
+ // Fingerprint of `guaranteed_const` value. It is set when
+ // `has_guaranteed_const` is true. The value is produced lazily when needed.
+ std::function<std::string()> guaranteed_const_fingerprint;
+
+ // A more verbose key for debugging purposes.
+ std::string debug_string;
+
+ explicit TpuCompilationCacheKey() {}
+ explicit TpuCompilationCacheKey(const std::string& p) : prefix(p) {}
+};
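+
+// Illustrative construction sketch (assumes a `prefix` string, a `session`
+// handle, and a ComputeFingerprint() helper exist; the fingerprint is
+// computed at most once, the first time the callback is invoked):
+//
+// TpuCompilationCacheKey key(prefix);
+// key.has_guaranteed_const = true;
+// key.session_handle = session;
+// key.guaranteed_const_fingerprint = [fingerprint = std::string()]() mutable {
+// if (fingerprint.empty()) fingerprint = ComputeFingerprint();
+// return fingerprint;
+// };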
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc
new file mode 100644
index 0000000..f4f8dbf
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc
@@ -0,0 +1,93 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
+
+#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
+
+namespace tensorflow {
+namespace tpu {
+
+namespace {
+class CompilationCacheFetchTargetUtility {
+ public:
+ CompilationCacheFetchTargetUtility()
+ : names_({"Invalid", "Main", "Sharding", "Unsharding"}) {}
+
+ std::string name(CompilationCacheFetchTarget target) const {
+ return names_[static_cast<int>(target)];
+ }
+
+ private:
+ const std::vector<std::string> names_;
+};
+
+std::string GetName(CompilationCacheFetchTarget target) {
+ static const auto* util = new CompilationCacheFetchTargetUtility();
+ return util->name(target);
+}
+
+} // namespace
+
+TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
+ TpuCompilationCacheInterface* cache)
+ : cache_(cache) {}
+
+TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
+ cache_->Unref();
+}
+
+Status TpuCompilationCacheLocalLookup::Lookup(
+ const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
+ CompilationCacheFetchTarget fetch_target) {
+ profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
+ /*level=*/2);
+ Status s = cache_->Lookup(proto_key, entry);
+ VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
+ << s;
+ if (!s.ok()) {
+ return s;
+ }
+ s = cache_->ToSubEntryRef(entry->get(), fetch_target);
+
+ VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
+ << s;
+ return s;
+}
+
+Status TpuCompilationCacheLocalLookup::Lookup(
+ int64 uid, int proto_index,
+ std::unique_ptr<CompilationCacheEntryRef>* entry,
+ CompilationCacheFetchTarget fetch_target) {
+ profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
+ /*level=*/2);
+ Status s = cache_->Lookup(uid, proto_index, entry);
+ VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
+ << " in local subgraph cache status " << s;
+ if (!s.ok()) {
+ return s;
+ }
+ s = cache_->ToSubEntryRef(entry->get(), fetch_target);
+ VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
+ << s;
+ return s;
+}
+
+string TpuCompilationCacheLocalLookup::DebugString() const {
+ return "TpuCompilationCacheLocalLookup";
+}
+
+} // namespace tpu
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
new file mode 100644
index 0000000..138777a
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
@@ -0,0 +1,99 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
+
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Base class allowing Execute Ops to look up ISA protos. Different subclasses
+// are used when the execute Op is in the same address space as the compile Op,
+// and when they need to communicate over RPC.
+class TpuCompilationCacheLookup : public ResourceBase {
+ public:
+ ~TpuCompilationCacheLookup() override = default;
+
+ // Looks up an executable corresponding to the model-parallel core index of
+ // the subgraph represented by key. On success a wrapper for the proto is
+ // returned in entry. The wrapper is guaranteed to be valid only during the
+ // execution of the Op requesting the proto.
+ //
+ // Only one of the main, sharding, unsharding entries is fetched, as specified
+ // in fetch_target.
+ //
+ // If the compilation does not create sharding/unsharding programs, but the
+ // fetch_target requests one of them, then after this call
+ // (*entry)->get().get_executable() will return nullptr.
+ virtual Status Lookup(const string& proto_key,
+ std::unique_ptr<CompilationCacheEntryRef>* entry,
+ CompilationCacheFetchTarget fetch_target) = 0;
+
+ virtual Status Lookup(const string& proto_key,
+ std::unique_ptr<CompilationCacheEntryRef>* entry) {
+ return Lookup(proto_key, entry, CompilationCacheFetchTarget::MAIN);
+ }
+
+ // Looks up an executable corresponding to the model-parallel core index of
+ // the subgraph represented by uid. On success a wrapper for the proto is
+ // returned in entry. The wrapper is guaranteed to be valid only during the
+ // execution of the Op requesting the proto.
+ virtual Status Lookup(int64 uid, int proto_index,
+ std::unique_ptr<CompilationCacheEntryRef>* entry,
+ CompilationCacheFetchTarget fetch_target) = 0;
+
+ virtual Status Lookup(int64 uid, int proto_index,
+ std::unique_ptr<CompilationCacheEntryRef>* entry) {
+ return Lookup(uid, proto_index, entry, CompilationCacheFetchTarget::MAIN);
+ }
+};
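+
+// Illustrative lookup sketch (assumes `lookup` points to a
+// TpuCompilationCacheLookup implementation and `proto_key` names a compiled
+// subgraph):
+//
+// std::unique_ptr<CompilationCacheEntryRef> entry;
+// TF_RETURN_IF_ERROR(lookup->Lookup(
+// proto_key, &entry, CompilationCacheFetchTarget::SHARDING));
+// if (entry->get().get_executable() == nullptr) {
+// // The compilation did not produce a sharding program.
+// }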
+
+// Forward declaration to break cycle dependency graph.
+class TpuCompilationCacheInterface;
+
+// Class for looking up ISA protos when the execute and compile Op are in the
+// same address space. The proto is simply looked up in the compilation cache,
+// without any serialization taking place.
+class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
+ public:
+ explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
+ ~TpuCompilationCacheLocalLookup() override;
+
+ Status Lookup(const string& proto_key,
+ std::unique_ptr<CompilationCacheEntryRef>* entry,
+ CompilationCacheFetchTarget fetch_target) override;
+
+ Status Lookup(int64 uid, int proto_index,
+ std::unique_ptr<CompilationCacheEntryRef>* entry,
+ CompilationCacheFetchTarget fetch_target) override;
+
+ string DebugString() const override;
+
+ private:
+ // The subgraph compilation cache, in the same process address space where the
+ // lookups are happening.
+ TpuCompilationCacheInterface* cache_;
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.cc
new file mode 100644
index 0000000..ba4e2cc
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.cc
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
+
+namespace tensorflow {
+namespace tpu {
+
+/* static */
+void TpuCompilationCacheMetrics::IncrementCacheLookupCount(
+ bool is_cache_hit, absl::string_view session_name) {
+ // A placeholder for tracking metrics.
+}
+
+/* static */
+void TpuCompilationCacheMetrics::SetCacheEntryCount(int64 count) {
+ // A placeholder for tracking metrics.
+}
+
+} // namespace tpu
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h
new file mode 100644
index 0000000..e30a7a4
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Tracks TPU compilation cache metrics.
+class TpuCompilationCacheMetrics {
+ public:
+ // Increments the cache lookup count.
+ static void IncrementCacheLookupCount(bool is_cache_hit,
+ absl::string_view session_name);
+
+ // Sets the total count of cache entries.
+ static void SetCacheEntryCount(int64 count);
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compile.proto b/tensorflow/core/tpu/kernels/tpu_compile.proto
new file mode 100644
index 0000000..5b70de6
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile.proto
@@ -0,0 +1,144 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
+import "tensorflow/compiler/xla/service/hlo.proto";
+import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/protobuf/tpu/compile_metadata.proto";
+import "tensorflow/core/tpu/kernels/tpu_executable_info.proto";
+
+message PerCoreVariableIndices {
+ // For each resource variable output, what was the index of the corresponding
+ // input and was it updated? The indices are sorted by input order.
+ repeated TPUExecutableInfoProto.UpdateIndexPair variable_indices = 1;
+}
+
+message PerCoreArgShapes {
+ // Argument shapes for each TPU core.
+ repeated xla.ShapeProto shapes = 1;
+}
+
+message PerCoreOutputShapes {
+ // Output shapes for each TPU core.
+ repeated xla.ShapeProto shapes = 1;
+}
+
+message OutputDescriptionProto {
+ // Type and shape of the output. The shape is the unflattened shape.
+ // When `type` is DT_RESOURCE, `shape` is the shape of the resource
+ // variable's value.
+ tensorflow.DataType type = 1;
+ tensorflow.TensorShapeProto shape = 2;
+
+ // Constant output value, if known to be constant at JIT compilation time.
+ // 'Tensor' is in host memory.
+ bool is_constant = 3;
+ tensorflow.TensorProto constant_value = 4;
+
+ // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
+ // the index of the input that contains the resource.
+ int32 input_index = 5;
+
+ // Whether this output is a TensorList.
+ bool is_tensor_list = 6;
+}
+
+// Describes a variable write side effect of the computation.
+message ResourceUpdateProto {
+ // Index of the input that contains the variable resource to write to.
+ int32 input_index = 1;
+
+ // Type and shape of the tensor to be written back.
+ // The `shape` field has the same meaning as the Argument::shape field.
+ tensorflow.DataType type = 2;
+ tensorflow.TensorShapeProto shape = 3;
+
+ // Was the value of the variable modified by the computation?
+ // (Always true, unless `return_updated_values_for_all_resources` is true.)
+ bool modified = 4;
+
+ // If the resource is a TensorArray, the set of gradients read or written.
+ map<string, bool> tensor_array_gradients_accessed = 5;
+}
+
+// Describes the result of an XLA compiler compilation.
+message XlaCompilationResultProto {
+ // Vector that maps from the parameters of the XLA computation to their
+ // original argument positions. To handle compile-time constant inputs, the
+ // parameters to the XLA computation may be a subset of the original
+ // arguments. The relative ordering of parameters is maintained.
+ repeated int32 input_mappings = 1;
+
+ // Input shapes of the computation. If we are flattening inputs, these are
+ // the flattened shapes.
+ repeated xla.ShapeProto xla_input_shapes = 2;
+
+ // Output shape in XLA format. The output shape is always a tuple. If we
+ // are flattening outputs, these are the flattened shapes.
+ xla.ShapeProto xla_output_shape = 3;
+
+ // TensorFlow shapes of outputs, together with the values of any
+ // constant arguments. Vector indexed by Tensorflow _Retval number,
+ // containing both constant and non-constant results.
+ repeated OutputDescriptionProto outputs = 4;
+
+ // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
+ // matching RecvAtHost/SendFromHost Ops in the outer graph.
+ tf2xla.HostComputeMetadata host_compute_metadata = 5;
+
+ // Resources whose values were updated by the computation, ordered
+ // by return value position (which is the same as the order the resources
+ // were passed as arguments). Resource updates follow the non-constant
+ // results in the outputs of XLA computation.
+ repeated ResourceUpdateProto resource_updates = 6;
+
+ // The XLA computation built from the TensorFlow subgraph.
+ xla.HloModuleProto computation = 7;
+}
+
+// TpuAotCompilationRequestProto represents a compilation request for performing
+// ahead-of-time (AOT) compilation of XLA Computations into XLA HLO IR.
+message TpuAotCompilationRequestProto {
+ // A set of HLO modules built to run concurrently
+ // across different devices.
+ xla.HloModuleGroupProto hlo_module_group = 1;
+
+ // Compilation metadata.
+ TPUCompileMetadataProto metadata = 2;
+
+ // DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
+ // which represents the device ids assigned to a set of replicated
+ // computations. See the xla::DeviceAssignment class comment for more details.
+ xla.DeviceAssignmentProto device_assignment = 3;
+
+ // Program argument shapes for each TPU core.
+ repeated PerCoreArgShapes per_core_arg_shapes = 4;
+
+ // Program output shapes for each TPU core.
+ repeated PerCoreOutputShapes per_core_output_shapes = 5;
+
+ // Per-core resource variable information: the index of the corresponding
+ // input and whether it was updated. The indices are sorted by input order.
+ repeated PerCoreVariableIndices per_core_variable_indices = 6;
+
+ // XLA compiler compilation result.
+ XlaCompilationResultProto compilation_result = 7;
+}
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_c_api.h b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h
new file mode 100644
index 0000000..53e79aa
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h
@@ -0,0 +1,119 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
+
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+enum TpuCoreTypeEnum {
+ kTensorCore,
+ kEmbeddingV1,
+ kEmbeddingV2,
+};
+
+typedef struct XLA_TpuProgram XLA_TpuProgram;
+
+// Property for creating compilation cache key.
+struct CompilationCacheKeyProperty {
+ const char* config_prefix;
+ const char* shapes_prefix;
+ const char* function_name;
+ const char* mlir_module;
+ const int32_t* device_ids;
+ size_t device_ids_size;
+ int32_t guaranteed_constants_size;
+ uint64_t function_library_fingerprint;
+ int32_t num_cores_per_replica;
+ int32_t num_replicas;
+ const XLA_TpuMeshState* mesh_state;
+};
+
+extern "C" {
+
+// Creates a new TPU program.
+XLA_TpuProgram* TpuProgram_New();
+
+// Destroys the `tpu_program`.
+void TpuProgram_Free(XLA_TpuProgram* tpu_program);
+
+
+// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
+// destroyed, it is in an unusable state.
+void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
+ SE_Status* status);
+
+// Gets TPU program size in bytes from the `tpu_program`.
+int64_t TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);
+
+// Logs a summary of the current memory state snapshot of the `tpu_program`.
+bool TpuProgram_LogProgramMemorySummary(const XLA_TpuProgram* tpu_program);
+
+// Gets TPU program executable info from the `tpu_program`.
+void TpuProgram_GetExecutableInfo(const XLA_TpuProgram* tpu_program,
+ TpuSerializedProto* executable_info);
+
+// Gets host transfer info proto.
+void TpuProgram_GetHostTransferInfo(
+ const XLA_TpuProgram* tpu_program,
+ TpuSerializedProto* host_transfer_info);
+
+// Gets HLO metadata proto.
+void TpuProgram_GetHloMetadata(const XLA_TpuProgram* tpu_program,
+ TpuSerializedProto* hlo_metadata);
+
+// Returns the number of available TPU cores.
+int TpuTopology_AvailableCoreCount(const XLA_TpuMeshState* mesh_state,
+ TpuCoreTypeEnum tpu_core_type);
+
+// Creates a unique compilation cache `key` used for `put` and `get` operations.
+// The returned buffer is heap-allocated and ownership is transferred to the caller.
+const char* TpuCompile_CreateCompilationCacheKey(
+ CompilationCacheKeyProperty property);
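+//
+// A minimal usage sketch (illustrative only; the field values below are made
+// up for illustration and are not part of this API):
+//
+//   CompilationCacheKeyProperty property;
+//   property.config_prefix = "config";        // hypothetical prefix
+//   property.shapes_prefix = "shapes";        // hypothetical prefix
+//   property.function_name = "cluster_fn";    // hypothetical function name
+//   property.mlir_module = "";                // empty when MLIR is not used
+//   property.device_ids = device_ids;         // assumed int32_t array
+//   property.device_ids_size = num_devices;
+//   property.guaranteed_constants_size = 0;
+//   property.function_library_fingerprint = fingerprint;
+//   property.num_cores_per_replica = 1;
+//   property.num_replicas = 1;
+//   property.mesh_state = mesh_state;         // an existing XLA_TpuMeshState*
+//   const char* key = TpuCompile_CreateCompilationCacheKey(property);
+//   // The caller owns `key` and is responsible for releasing it.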
+
+// Creates a guaranteed const fingerprint. Guaranteed const is normally used in
+// TPU inference to avoid re-copying unchanged variables onto the TPU device.
+// It promises the value is identical for every execution in the same session
+// even if the actual value changes in later executions.
+uint64_t TpuCompile_CreateGuaranteedConstFingerprint(uint64_t fingerprint,
+ const char* data,
+ size_t size);
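+//
+// Illustrative sketch (not prescriptive): callers can fold several guaranteed
+// constants into one running fingerprint by chaining calls. The initial seed
+// value of 0 below is an assumption for illustration only.
+//
+//   uint64_t fp = TpuCompile_CreateGuaranteedConstFingerprint(0, data0, size0);
+//   fp = TpuCompile_CreateGuaranteedConstFingerprint(fp, data1, size1);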
+
+// Checks whether TPU compilation is enabled.
+bool TpuCompile_IsTpuCompilationEnabled();
+
+// Compiles the computations using the XLA TPU compiler and returns TPU programs
+// ready for execution.
+void TpuCompile_CompileAheadOfTime(
+ TpuSerializedProto aot_compilation_request,
+ XLA_TpuProgram** tpu_programs[],
+ size_t* count, SE_Status* status);
+
+// Builds a `DeviceAssignment` from a serialized `TpuCompileMetadata` proto.
+void TpuCompile_BuildXLADeviceAssignment(
+ TpuSerializedProto serialized_tpu_compile_metadata,
+ const XLA_TpuMeshState* mesh_state,
+ TpuSerializedProto* serialized_device_assignment, SE_Status* status);
+
+// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
+void TpuCompile_ToTpuShapeRepresentation(
+ TpuSerializedProto serialized_xla_shape, int data_type,
+ bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape,
+ SE_Status* status);
+
+} // extern "C"
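+
+// A rough end-to-end sketch of the C API above (illustrative only; error
+// handling and the construction of `aot_request` and `status` are omitted,
+// and those names are assumptions, not part of this header):
+//
+//   XLA_TpuProgram** tpu_programs = nullptr;
+//   size_t count = 0;
+//   TpuCompile_CompileAheadOfTime(aot_request, &tpu_programs, &count, status);
+//   for (size_t i = 0; i < count; ++i) {
+//     int64_t size = TpuProgram_GetProgramSize(tpu_programs[i]);
+//     TpuSerializedProto executable_info;
+//     TpuProgram_GetExecutableInfo(tpu_programs[i], &executable_info);
+//     // ... load and execute the program, then release it ...
+//     TpuProgram_UnloadAndDestroy(tpu_programs[i], status);
+//   }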
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_options.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_options.cc
new file mode 100644
index 0000000..49a2a08
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_options.cc
@@ -0,0 +1,42 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
+
+namespace tensorflow {
+namespace internal {
+
+namespace {
+static bool tpu_compilation_cancellation_terminates_process = true;
+static bool tpu_compilation_failure_closes_chips = true;
+} // namespace
+
+void SetTpuCompilationCancellationTerminatesProcess(bool b) {
+ tpu_compilation_cancellation_terminates_process = b;
+}
+
+bool TpuCompilationCancellationTerminatesProcess() {
+ return tpu_compilation_cancellation_terminates_process;
+}
+
+void SetTpuCompilationFailureClosesChips(bool value) {
+ tpu_compilation_failure_closes_chips = value;
+}
+
+bool TpuCompilationFailureClosesChips() {
+ return tpu_compilation_failure_closes_chips;
+}
+
+} // namespace internal
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_options.h b/tensorflow/core/tpu/kernels/tpu_compile_op_options.h
new file mode 100644
index 0000000..b81fe4a
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_options.h
@@ -0,0 +1,42 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_
+
+#include <string>
+
+namespace tensorflow {
+namespace internal {
+
+// Setter and getter that determine how TPUCompile responds to cancelled
+// compilation. By default this is true, meaning cancelled compilation will
+// abort the process, since that's the only mechanism we have available.
+//
+// Setting this to false allows the process to remain alive, and should only be
+// used in tests.
+void SetTpuCompilationCancellationTerminatesProcess(bool b);
+bool TpuCompilationCancellationTerminatesProcess();
+
+// Setter and getter that determine whether TPU compilation failure will cause
+// chips to close. By default this is true, which is suitable for training. For
+// inference, we never want servers to die, so chips are kept alive.
+// See b/109873767.
+void SetTpuCompilationFailureClosesChips(bool value);
+bool TpuCompilationFailureClosesChips();
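+
+// A minimal usage sketch (illustrative only): a test that exercises
+// cancellation or compilation failure without terminating the process or
+// closing chips might toggle both options before running, e.g.
+//
+//   tensorflow::internal::SetTpuCompilationCancellationTerminatesProcess(false);
+//   tensorflow::internal::SetTpuCompilationFailureClosesChips(false);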
+
+} // namespace internal
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc
new file mode 100644
index 0000000..d42c604
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc
@@ -0,0 +1,439 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+
+#include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/dump.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+
+namespace tensorflow {
+namespace tpu {
+
+using stream_executor::port::Status;
+using stream_executor::port::StatusOr;
+using xla::ComputationLayout;
+using xla::DebugOptions;
+using xla::DeviceAssignment;
+using xla::HloModuleConfig;
+using xla::HloSharding;
+using xla::InvalidArgument;
+using xla::ProgramShape;
+using xla::Shape;
+using xla::ShapeTree;
+using xla::ShapeUtil;
+
+Status ValidateResultShape(const Shape& client_shape,
+ const Shape& result_shape) {
+ TF_RETURN_IF_ERROR(
+ xla::ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
+ if (!xla::ShapeUtil::Compatible(client_shape, result_shape)) {
+ return InvalidArgument(
+ "Shape used to set computation result layout %s is not compatible "
+ "with result shape %s",
+ xla::ShapeUtil::HumanStringWithLayout(client_shape),
+ xla::ShapeUtil::HumanString(result_shape));
+ }
+ return Status::OK();
+}
+
+StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
+ const ProgramShape& program_shape, absl::Span<const Shape> argument_shapes,
+ absl::optional<const Shape> result_layout,
+ absl::optional<const DeviceAssignment> device_assignment, int replica_count,
+ int num_partitions, const DebugOptions* debug_options, const int* seed,
+ const int* launch_id, const bool* alias_passthrough_params,
+ const xla::FusionConfigCollection* fusion_config_collection,
+ const std::vector<std::vector<bool>>* fusion_config) {
+ auto config = absl::make_unique<HloModuleConfig>(program_shape);
+ ComputationLayout* computation_layout =
+ config->mutable_entry_computation_layout();
+ if (program_shape.parameters_size() != argument_shapes.size()) {
+ return InvalidArgument("computation takes %d parameters, but %u given",
+ program_shape.parameters_size(),
+ argument_shapes.size());
+ }
+ for (int i = 0; i < argument_shapes.size(); ++i) {
+ // Verify that the shape of each argument matches the shape of the
+ // corresponding parameter in the ProgramShape.
+ if (!ShapeUtil::Compatible(argument_shapes[i],
+ program_shape.parameters(i))) {
+ return InvalidArgument(
+ "Argument does not match shape of computation parameter %d: want "
+ "%s, got %s",
+ i, ShapeUtil::HumanString(program_shape.parameters(i)),
+ ShapeUtil::HumanString(argument_shapes[i]));
+ }
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+ argument_shapes[i]));
+ }
+
+ if (result_layout.has_value()) {
+ TF_RETURN_IF_ERROR(
+ ValidateResultShape(result_layout.value(), program_shape.result()));
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ result_layout.value()));
+ } else {
+ // If the result layout is not set, then choose the default.
+ computation_layout->mutable_result_layout()->SetToDefaultLayout();
+ }
+
+ config->set_replica_count(replica_count);
+ config->set_num_partitions(num_partitions);
+ if (seed != nullptr) {
+ config->set_seed(*seed);
+ }
+ if (launch_id != nullptr) {
+ config->set_launch_id(*launch_id);
+ }
+ if (debug_options != nullptr) {
+ config->set_debug_options(*debug_options);
+ } else {
+ config->set_debug_options(xla::GetDebugOptionsFromFlags());
+ }
+
+ // TODO(henrytan): set intra_op_parallelism_threads.
+ // Reference:
+ // tensorflow/compiler/xla/service/service.cc?l=324.
+
+ if (device_assignment.has_value()) {
+ config->set_static_device_assignment(device_assignment.value());
+ }
+
+ if (alias_passthrough_params != nullptr) {
+ config->set_alias_passthrough_params(*alias_passthrough_params);
+ }
+
+ if (fusion_config_collection != nullptr && fusion_config != nullptr &&
+ *fusion_config_collection != xla::FusionConfigCollection::kOff) {
+ config->set_fusion_config_collection(*fusion_config_collection);
+ *config->mutable_fusion_config() = *fusion_config;
+ }
+
+ return std::move(config);
+}
+
+StatusOr<std::unique_ptr<xla::HloModuleConfig>> CreateModuleConfig(
+ const xla::ProgramShape& program_shape,
+ absl::Span<const Shape> argument_shapes,
+ absl::optional<const Shape> result_layout,
+ absl::optional<const DeviceAssignment> device_assignment, int replica_count,
+ int num_partitions, const DebugOptions* debug_options) {
+ return CreateModuleConfig(program_shape, argument_shapes, result_layout,
+ device_assignment, replica_count, num_partitions,
+ debug_options, /*seed=*/nullptr,
+ /*launch_id=*/nullptr,
+ /*alias_passthrough_params=*/nullptr,
+ /*fusion_config_collection=*/nullptr,
+ /*fusion_config=*/nullptr);
+}
+
+ShapeTree<HloSharding> GetSubtree(
+ const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) {
+ ShapeTree<HloSharding> element_shape_tree(
+ xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(),
+ element_index),
+ HloSharding::Replicate());
+
+ xla::ShapeIndex src_index;
+ src_index.push_back(element_index);
+ element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {});
+ return element_shape_tree;
+}
+
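+// Worked example for GetPerDeviceShape below (illustrative): for a non-tuple
+// shape f32[8,128] whose sharding tiles dimension 0 across devices {0, 1}, the
+// tile offsets/limits are {0,0}/{4,128} for device 0 and {4,0}/{8,128} for
+// device 1, so both devices get the per-device shape f32[4,128]. A
+// tile-maximal sharding returns the full shape unchanged.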
+Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding,
+ int64 device) {
+ if (shape.IsTuple()) {
+ ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape);
+ std::vector<Shape> arg_shapes;
+ for (int64 i = 0; i < xla::ShapeUtil::TupleElementCount(shape); ++i) {
+ Shape element_shape = xla::ShapeUtil::GetTupleElementShape(shape, i);
+ HloSharding element_sharding = tuple_shape_tree.element({i});
+ if (element_shape.IsTuple()) {
+ element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i));
+ }
+ if (element_sharding.UsesDevice(device)) {
+ arg_shapes.push_back(
+ GetPerDeviceShape(element_shape, element_sharding, device));
+ }
+ }
+ return xla::ShapeUtil::MakeTupleShape(arg_shapes);
+ }
+
+ if (sharding.IsTileMaximal()) {
+ return shape;
+ }
+
+ std::vector<int64> dimensions;
+ std::vector<int64> offset = sharding.TileOffsetForDevice(shape, device);
+ std::vector<int64> limit = sharding.TileLimitForDevice(shape, device);
+ for (int64 i = 0; i < limit.size(); ++i) {
+ dimensions.push_back(limit[i] - offset[i]);
+ }
+ if (shape.has_layout()) {
+ return xla::ShapeUtil::MakeShapeWithLayout(shape.element_type(), dimensions,
+ shape.layout().minor_to_major());
+ }
+ return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions);
+}
+
+Status AddVariableUpdatesToCores(
+ const TPUCompileMetadataProto& metadata,
+ const XlaCompiler::CompilationResult& compilation_result,
+ const std::vector<ShardingAndIndex>& arg_core_mapping,
+ std::vector<bool>* may_modify_variables,
+ std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
+ std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices) {
+ // Add all variables to the corresponding core.
+ may_modify_variables->resize(metadata.num_cores_per_replica(), false);
+ int resource_update_pos = 0;
+ for (int i = 0; i < metadata.args_size(); ++i) {
+ const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
+ if (proto_arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
+ const auto& sharding = proto_arg.sharding();
+ bool updated = false;
+ if (resource_update_pos < compilation_result.resource_updates.size()) {
+ const XlaCompiler::ResourceUpdate& update =
+ compilation_result.resource_updates[resource_update_pos];
+ if (update.input_index == i) {
+ updated = true;
+ int pos = compilation_result.outputs.size() + resource_update_pos;
+ xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
+ compilation_result.xla_output_shape, pos);
+ auto add_to_core = [&](int64 core, const xla::Shape& per_core_shape) {
+ (*per_core_output_shapes)[core].push_back(per_core_shape);
+ (*may_modify_variables)[core] =
+ (*may_modify_variables)[core] || update.modified;
+ };
+ if (sharding.type() == xla::OpSharding::MAXIMAL) {
+ add_to_core(sharding.tile_assignment_devices(0), shape);
+ } else if (sharding.type() == xla::OpSharding::OTHER) {
+ auto sharding_or =
+ xla::HloSharding::FromProto(proto_arg.sharding());
+ TF_RET_CHECK(sharding_or.ok());
+ for (int64 core : proto_arg.sharding().tile_assignment_devices()) {
+ xla::Shape per_core_shape =
+ GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
+ add_to_core(core, per_core_shape);
+ }
+ } else {
+ TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
+ for (int64 core = 0; core < metadata.num_cores_per_replica();
+ ++core) {
+ add_to_core(core, shape);
+ }
+ }
+ ++resource_update_pos;
+ }
+ }
+ if (sharding.type() == xla::OpSharding::MAXIMAL) {
+ (*per_core_variable_indices)[sharding.tile_assignment_devices(0)]
+ .push_back(
+ std::pair<int, bool>(arg_core_mapping[i].indices[0], updated));
+ } else if (sharding.type() == xla::OpSharding::OTHER) {
+ for (int core : sharding.tile_assignment_devices()) {
+ (*per_core_variable_indices)[core].push_back(
+ std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
+ }
+ } else {
+ TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
+ for (int64 core = 0; core < metadata.num_cores_per_replica(); ++core) {
+ (*per_core_variable_indices)[core].push_back(
+ std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
+ }
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status ComputeOutputShapesForEachCore(
+ const tpu::TPUCompileMetadataProto& metadata,
+ const XlaCompiler::CompilationResult& compilation_result,
+ std::vector<std::vector<xla::Shape>>* per_core_output_shapes) {
+ for (int i = 0; i < metadata.retvals_size(); ++i) {
+ const tpu::TPUCompileMetadataProto::Retval& retval = metadata.retvals(i);
+ TF_RET_CHECK(!compilation_result.outputs[i].is_constant)
+ << "TPU compilation output " << i
+ << " has a compile-time constant value. "
+ "This should never happen.";
+
+ xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
+ compilation_result.xla_output_shape, i);
+ auto add_shape_to_core = [&](int core, xla::Shape per_core_shape) {
+ (*per_core_output_shapes)[core].push_back(std::move(per_core_shape));
+ };
+ if (retval.sharding().type() == xla::OpSharding::MAXIMAL) {
+ add_shape_to_core(retval.sharding().tile_assignment_devices(0),
+ std::move(shape));
+ } else if (retval.sharding().type() == xla::OpSharding::OTHER) {
+ auto sharding_or = xla::HloSharding::FromProto(retval.sharding());
+ TF_RET_CHECK(sharding_or.ok());
+ for (int64 core : retval.sharding().tile_assignment_devices()) {
+ xla::Shape per_core_shape =
+ GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
+ add_shape_to_core(core, std::move(per_core_shape));
+ }
+ } else {
+ TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED)
+ << "Not all of the constant tensors were consumed.";
+ for (int core = 0; core < per_core_output_shapes->size(); ++core) {
+ add_shape_to_core(core, shape);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status CreateHloModules(
+ const TPUCompileMetadataProto& metadata,
+ const tensorflow::XlaCompiler::CompilationResult& compilation_result,
+ const absl::optional<xla::DeviceAssignment>& device_assignment,
+ std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules) {
+ TF_RET_CHECK(
+ compilation_result.computation->proto().has_host_program_shape());
+
+ auto debug_options = xla::DebugOptions();
+ debug_options.set_xla_step_marker_location(metadata.step_marker_location());
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<xla::HloModuleConfig> module_config,
+ CreateModuleConfig(
+ xla::ProgramShape(
+ compilation_result.computation->proto().host_program_shape()),
+ compilation_result.xla_input_shapes,
+ compilation_result.xla_output_shape, device_assignment,
+ metadata.num_replicas(), metadata.num_cores_per_replica(),
+ &debug_options));
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<xla::HloModule> hlo_module,
+ xla::HloModule::CreateFromProto(compilation_result.computation->proto(),
+ *module_config));
+ DumpHloModuleIfEnabled(*hlo_module, "before_optimizations");
+ hlo_modules->push_back(std::move(hlo_module));
+
+ return Status::OK();
+}
+
+XlaCompilationResultProto SerializeCompilationResult(
+ const XlaCompiler::CompilationResult& compilation_result) {
+ XlaCompilationResultProto compilation_result_proto;
+ for (int input_mapping : compilation_result.input_mapping) {
+ compilation_result_proto.add_input_mappings(input_mapping);
+ }
+
+ for (const Shape& input_shape : compilation_result.xla_input_shapes) {
+ *(compilation_result_proto.add_xla_input_shapes()) = input_shape.ToProto();
+ }
+ *(compilation_result_proto.mutable_xla_output_shape()) =
+ compilation_result.xla_output_shape.ToProto();
+
+ for (const XlaCompiler::OutputDescription& output_description :
+ compilation_result.outputs) {
+ auto* new_output = compilation_result_proto.add_outputs();
+ new_output->set_type(output_description.type);
+ output_description.shape.AsProto(new_output->mutable_shape());
+ new_output->set_is_constant(output_description.is_constant);
+ output_description.constant_value.AsProtoField(
+ new_output->mutable_constant_value());
+ new_output->set_input_index(output_description.input_index);
+ new_output->set_is_tensor_list(output_description.is_tensor_list);
+ }
+
+ *compilation_result_proto.mutable_host_compute_metadata() =
+ compilation_result.host_compute_metadata;
+
+ for (const XlaCompiler::ResourceUpdate& resource_update :
+ compilation_result.resource_updates) {
+ auto* new_resource_update = compilation_result_proto.add_resource_updates();
+ new_resource_update->set_input_index(resource_update.input_index);
+ new_resource_update->set_type(resource_update.type);
+ resource_update.shape.AsProto(new_resource_update->mutable_shape());
+ new_resource_update->set_modified(resource_update.modified);
+ for (const std::string& gradient_access :
+ resource_update.tensor_array_gradients_accessed) {
+ new_resource_update->mutable_tensor_array_gradients_accessed()->insert(
+ {gradient_access, true});
+ }
+ }
+
+ if (compilation_result.computation != nullptr) {
+ *compilation_result_proto.mutable_computation() =
+ compilation_result.computation->proto();
+ }
+
+ return compilation_result_proto;
+}
+
+StatusOr<TpuAotCompilationRequestProto> CreateTpuAotCompilationRequest(
+ const xla::HloModuleGroup& module_group,
+ const XlaCompiler::CompilationResult& compilation_result,
+ const TPUCompileMetadataProto& metadata,
+ const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+ const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
+ const std::vector<std::vector<std::pair<int, bool>>>&
+ per_core_variable_indices,
+ const absl::optional<xla::DeviceAssignment>& device_assignment) {
+ VLOG(1) << "CreateTpuAotCompilationRequest.";
+ TpuAotCompilationRequestProto aot_request;
+ *(aot_request.mutable_hlo_module_group()) = module_group.ToProto();
+ *(aot_request.mutable_metadata()) = metadata;
+ if (device_assignment.has_value()) {
+ xla::DeviceAssignmentProto device_assignment_proto;
+ Status status = device_assignment->Serialize(&device_assignment_proto);
+ if (!status.ok()) {
+ return status;
+ }
+ *(aot_request.mutable_device_assignment()) = device_assignment_proto;
+ }
+
+ for (const auto& arg_shapes : per_core_arg_shapes) {
+ auto* new_shape_list = aot_request.add_per_core_arg_shapes();
+ for (const auto& arg_shape : arg_shapes) {
+ *new_shape_list->add_shapes() = arg_shape.ToProto();
+ }
+ }
+
+ for (const auto& output_shapes : per_core_output_shapes) {
+ auto* new_shape_list = aot_request.add_per_core_output_shapes();
+ for (const auto& output_shape : output_shapes) {
+ *new_shape_list->add_shapes() = output_shape.ToProto();
+ }
+ }
+
+ for (const auto& variable_indices : per_core_variable_indices) {
+ auto* new_list = aot_request.add_per_core_variable_indices();
+ for (const auto& variable_index : variable_indices) {
+ auto* core_index = new_list->add_variable_indices();
+ core_index->set_index(variable_index.first);
+ core_index->set_updated(variable_index.second);
+ }
+ }
+
+ XlaCompilationResultProto compilation_result_proto =
+ SerializeCompilationResult(compilation_result);
+ *aot_request.mutable_compilation_result() = compilation_result_proto;
+
+ VLOG(1) << "TpuAotCompilationRequest:\n" << aot_request.DebugString();
+ return aot_request;
+}
+} // namespace tpu
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h
new file mode 100644
index 0000000..0f21e45
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h
@@ -0,0 +1,122 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
+
+namespace tensorflow {
+namespace tpu {
+
+namespace se = ::stream_executor;
+
+// Describes the position of an argument or return value after the computation
+// has been partitioned into cores.
+struct ShardingAndIndex {
+ // Sharding across cores.
+ ::xla::OpSharding sharding;
+ // Argument/return value number. If sharding is single-core, `indices` has a
+ // single element; otherwise, it has num_cores elements.
+ std::vector<int> indices;
+};
+
+// TODO(b/158279168): Dedup with internal version.
+// Return the per-device shape for a `shape` with a given `sharding`.
+xla::Shape GetPerDeviceShape(const xla::Shape& shape,
+ const xla::HloSharding& sharding,
+ int64 device);
+
+stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
+CreateModuleConfig(
+ const xla::ProgramShape& program_shape,
+ absl::Span<const xla::Shape> argument_shapes,
+ absl::optional<const xla::Shape> result_layout,
+ absl::optional<const xla::DeviceAssignment> device_assignment,
+ int replica_count, int num_partitions,
+ const xla::DebugOptions* debug_options, const int* seed,
+ const int* launch_id, const bool* alias_passthrough_params,
+ const xla::FusionConfigCollection* fusion_config_collection,
+ const std::vector<std::vector<bool>>* fusion_config);
+
+stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
+CreateModuleConfig(
+ const xla::ProgramShape& program_shape,
+ absl::Span<const xla::Shape> argument_shapes,
+ absl::optional<const xla::Shape> result_layout,
+ absl::optional<const xla::DeviceAssignment> device_assignment,
+ int replica_count,
+ int num_partitions, const xla::DebugOptions* debug_options);
+
+xla::ShapeTree<xla::HloSharding> GetSubtree(
+ const xla::ShapeTree<xla::HloSharding>& tuple_shape_tree,
+ int element_index);
+
+xla::Shape GetPerDeviceShape(const xla::Shape& shape,
+ const xla::HloSharding& sharding,
+ int64 device);
+
+Status AddVariableUpdatesToCores(
+ const TPUCompileMetadataProto& metadata,
+ const XlaCompiler::CompilationResult& compilation_result,
+ const std::vector<ShardingAndIndex>& arg_core_mapping,
+ std::vector<bool>* may_modify_variables,
+ std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
+ std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices);
+
+se::port::Status ComputeOutputShapesForEachCore(
+ const tpu::TPUCompileMetadataProto& metadata,
+ const XlaCompiler::CompilationResult& compilation_result,
+ std::vector<std::vector<xla::Shape>>* per_core_output_shapes);
+
+se::port::Status CreateHloModules(
+ const TPUCompileMetadataProto& metadata,
+ const XlaCompiler::CompilationResult& compilation_result,
+ const absl::optional<xla::DeviceAssignment>& device_assignment,
+ std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules);
+
+se::port::StatusOr<TpuAotCompilationRequestProto>
+CreateTpuAotCompilationRequest(
+ const xla::HloModuleGroup& module_group,
+ const XlaCompiler::CompilationResult& compilation_result,
+ const TPUCompileMetadataProto& metadata,
+ const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+ const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
+ const std::vector<std::vector<std::pair<int, bool>>>&
+ per_core_variable_indices,
+ const absl::optional<xla::DeviceAssignment>& device_assignment);
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
new file mode 100644
index 0000000..7fa345d
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
@@ -0,0 +1,298 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"
+
+#include <cstdint>
+
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/refcount.h"
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
+#include "tensorflow/core/tpu/tpu_config_c_api.h"
+#include "tensorflow/core/tpu/tpu_configuration.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+
+namespace tensorflow {
+namespace {
+
+Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
+ tpu::TpuMeshStateInterface** state) {
+ if (!rmgr->Lookup(rmgr->default_container(),
+ tpu::kTpuMeshCommonStateResourceName, state)
+ .ok()) {
+ return errors::FailedPrecondition(
+ "The TPU system has not been initialized.");
+ }
+ return Status::OK();
+}
+
+// Attempt to delete resource_name from resource_manager's default_container.
+// Returns OK if the deletion succeeded or if the resource was not found.
+// Otherwise, returns the deletion error.
+template <class ResourceT>
+Status DeleteIfExists(ResourceMgr* resource_manager,
+ const char* resource_name) {
+ VLOG(1) << "Removing resource " << resource_name << " if it exists";
+ Status status = resource_manager->Delete<ResourceT>(
+ resource_manager->default_container(), resource_name);
+ if (status.ok()) {
+ VLOG(1) << "Removed existing resource " << resource_name;
+ return Status::OK();
+ }
+ if (status.code() == error::NOT_FOUND) {
+ VLOG(1) << "No resource " << resource_name << " to remove";
+ return Status::OK();
+ }
+ VLOG(1) << "Error removing resource " << resource_name << " : " << status;
+ return status;
+}
+
+} // namespace
+
+void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "ConfigureDistributedTpuOp";
+ XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
+
+ std::vector<int32_t> num_devices_per_host;
+ int chips_per_host = -1;
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ const Tensor& input_tensor = ctx->input(i);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(input_tensor.shape()),
+ errors::InvalidArgument("Input ", i, " should be a scalar but has ",
+ input_tensor.dims(), " dimensions"));
+ if (chips_per_host == -1) {
+ chips_per_host = input_tensor.scalar<int32_t>()();
+ } else {
+ OP_REQUIRES(
+ ctx, chips_per_host == input_tensor.scalar<int32>()(),
+ errors::Internal("Host ", i, " has ", input_tensor.scalar<int32>()(),
+ " TPU chips but host 0 has ", chips_per_host));
+ }
+ num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
+ }
+
+ TF_Status* status = TF_NewStatus();
+ size_t host_config_output_size;
+ char* host_config_output;
+
+ auto* rmgr = GetTPUConfigResourceMgr();
+ OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
+ rmgr, tpu::kTpuMeshCommonStateResourceName));
+
+ ConfigureDistributedTpuOp_DoWork(
+ num_devices_per_host.size(), num_devices_per_host.data(),
+ &host_config_output_size, &host_config_output, status);
+
+ OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
+ tpu::kTpuMeshCommonStateResourceName,
+ tpu::TpuMeshStateInterface::Create()));
+
+ Tensor* ctx_output;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
+ ctx_output->scalar<tstring>()() =
+ std::string(host_config_output, host_config_output_size);
+
+ OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+ TF_DeleteStatus(status);
+ TpuConfigurationApi_FreeCharArray(host_config_output);
+
+ VLOG(1) << "ConfigureDistributedTpuOp done";
+}
+
+void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "WaitForDistributedTpuOp";
+ XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp");
+
+ size_t num_devices_per_host = -1;
+ size_t num_hosts = ctx->num_inputs();
+
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
+ OP_REQUIRES(
+ ctx, host_ordinal_to_global_device_id_tensor.dims() == 1,
+ errors::InvalidArgument("Input ", i, " should be a vector but has ",
+ host_ordinal_to_global_device_id_tensor.dims(),
+ " dimensions"));
+ }
+
+ std::vector<std::vector<int32_t>> mapping;
+ std::vector<int32_t*> mapping_arg;
+
+ mapping.resize(ctx->num_inputs());
+
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
+ const auto host_ordinal_to_global_device_id =
+ host_ordinal_to_global_device_id_tensor.flat<int>();
+ if (num_devices_per_host == -1) {
+ num_devices_per_host =
+ host_ordinal_to_global_device_id_tensor.dim_size(0);
+ } else {
+ OP_REQUIRES(ctx,
+ num_devices_per_host ==
+ host_ordinal_to_global_device_id_tensor.dim_size(0),
+ errors::Internal(
+ "Host ", i, " has ",
+ host_ordinal_to_global_device_id_tensor.dim_size(0),
+ " TPU devices but host 0 has ", num_devices_per_host));
+ }
+ for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0);
+ ++j) {
+ int32_t global_device_id = host_ordinal_to_global_device_id(j);
+ mapping[i].push_back(global_device_id);
+ }
+ mapping_arg.push_back(mapping[i].data());
+ }
+
+ TF_Status* status = TF_NewStatus();
+ size_t tpu_topology_output_size;
+ char* tpu_topology_output;
+
+ tpu::TpuMeshStateInterface* mesh_state;
+ auto* rmgr = GetTPUConfigResourceMgr();
+ OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
+ core::ScopedUnref mesh_state_unref(mesh_state);
+
+ WaitForDistributedTpuOp_DoWork(
+ num_hosts, num_devices_per_host,
+ const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
+ &tpu_topology_output_size, &tpu_topology_output, status);
+
+ Tensor* ctx_output;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
+ ctx_output->scalar<tstring>()() =
+ std::string(tpu_topology_output, tpu_topology_output_size);
+
+ OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+ TF_DeleteStatus(status);
+ TpuConfigurationApi_FreeCharArray(tpu_topology_output);
+
+ VLOG(1) << "WaitForDistributedTpuOp done";
+}
+
+void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "ShutdownDistributedTpuOp";
+ XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");
+
+ TF_Status* status = TF_NewStatus();
+ OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
+ GetTPUConfigResourceMgr(),
+ tpu::kTpuMeshCommonStateResourceName));
+ ShutdownDistributedTpuOp_DoWork(status);
+ OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+ TF_DeleteStatus(status);
+
+ VLOG(1) << "ShutdownDistributedTpuOp done";
+}
+
+void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "InitializeHostForDistributedTpuOp";
+ XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");
+
+ auto tpu_host_config = ctx->input(0).scalar<tstring>()();
+
+ size_t device_id_output_size;
+ int32_t* device_id_output;
+ TF_Status* status = TF_NewStatus();
+
+ InitializeHostForDistributedTpuOp_DoWork(
+ tpu_host_config.size(), tpu_host_config.data(),
+ enable_whole_mesh_compilations_, &device_id_output_size,
+ &device_id_output, status);
+
+ Tensor* ctx_output;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(
+ 0, TensorShape({static_cast<long long>(device_id_output_size)}),
+ &ctx_output));
+
+ for (size_t i = 0; i < device_id_output_size; ++i) {
+ ctx_output->flat<int32>()(i) = device_id_output[i];
+ }
+
+ OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+ TF_DeleteStatus(status);
+ TpuConfigurationApi_FreeInt32Array(device_id_output);
+
+ VLOG(1) << "InitializeHostForDistributedTpuOp done";
+}
+
+void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "SetGlobalTPUArrayOp";
+ XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp");
+
+ auto tpu_topology = ctx->input(0).scalar<tstring>()();
+ TF_Status* status = TF_NewStatus();
+
+ SetGlobalTPUArrayOp_DoWork(tpu_topology.size(), tpu_topology.data(), status);
+
+ OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+ TF_DeleteStatus(status);
+
+ VLOG(1) << "SetGlobalTPUArrayOp done";
+}
+
+void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "DisconnectDistributedTpuChipsOp";
+ XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp");
+
+ TF_Status* status = TF_NewStatus();
+ int32_t number_of_chips_output = 0;
+
+ DisconnectDistributedTpuChipsOp_DoWork(&number_of_chips_output, status);
+
+ Tensor* ctx_output;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
+ ctx_output->scalar<int32_t>()() = number_of_chips_output;
+
+ OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+ TF_DeleteStatus(status);
+
+ VLOG(1) << "DisconnectDistributedTpuChipsOp done";
+}
+
+// These ops execute on the TPU_SYSTEM device only.
+REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU")
+ .Device(DEVICE_TPU_SYSTEM)
+ .HostMemory("output"),
+ ConfigureDistributedTpuOp);
+REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU")
+ .Device(DEVICE_TPU_SYSTEM)
+ .HostMemory("inputs")
+ .HostMemory("topology"),
+ WaitForDistributedTpuOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM),
+ ShutdownDistributedTpuOp);
+REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU")
+ .Device(DEVICE_TPU_SYSTEM)
+ .HostMemory("input")
+ .HostMemory("tpu_ids"),
+ InitializeHostForDistributedTpuOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"),
+ SetGlobalTPUArrayOp);
+REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem")
+ .Device(DEVICE_TPU_SYSTEM)
+ .HostMemory("number_of_tpu_chips"),
+ DisconnectDistributedTpuChipsOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.h b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h
new file mode 100644
index 0000000..f75a47e
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h
@@ -0,0 +1,156 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+// The ConfigureDistributedTpu op is used to start a TPUDriver from
+// TensorFlow. It should be run on a TPU_SYSTEM device and returns the
+// connection host:port for the CompilationCacheServer. The
+// CompilationCacheServer will remain live until the device's Resource Manager
+// is cleared or a ShutdownDistributedTpuOp is run on the same device.
+class ConfigureDistributedTpuOp : public OpKernel {
+ public:
+ explicit ConfigureDistributedTpuOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {
+ OP_REQUIRES(
+ ctx, ctx->num_inputs() > 0,
+ errors::Internal("_ConfigureDistributedTPU needs at least one input"));
+ }
+ void Compute(OpKernelContext* ctx) override;
+ ~ConfigureDistributedTpuOp() override {}
+
+ private:
+ // ConfigureDistributedTpuOp is neither copyable nor movable.
+ ConfigureDistributedTpuOp(const ConfigureDistributedTpuOp&) = delete;
+ ConfigureDistributedTpuOp& operator=(const ConfigureDistributedTpuOp&) =
+ delete;
+};
+
+// The WaitForDistributedTpuOp op is used to block execution until
+// the distributed TPU system has started up. It must be run on
+// the same TPU_SYSTEM device that ConfigureDistributedTpuOp was run
+// on, after all of the InitializeHostForDistributedTpuOp Ops have
+// completed.
+class WaitForDistributedTpuOp : public OpKernel {
+ public:
+ explicit WaitForDistributedTpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx,
+ ctx->GetAttr("startup_timeout_sec", &startup_timeout_sec_));
+ OP_REQUIRES(ctx, startup_timeout_sec_ > 0,
+ errors::InvalidArgument("startup_timeout_sec ",
+ startup_timeout_sec_, " must be >0"));
+ }
+ void Compute(OpKernelContext* ctx) override;
+ ~WaitForDistributedTpuOp() override {}
+
+ private:
+ // The time to wait for all hosts to start up.
+ int startup_timeout_sec_;
+
+ // WaitForDistributedTpuOp is neither copyable nor movable.
+ WaitForDistributedTpuOp(const WaitForDistributedTpuOp&) = delete;
+ WaitForDistributedTpuOp& operator=(const WaitForDistributedTpuOp&) = delete;
+};
+
+// The ShutdownDistributedTpu op is used to stop a running TPUDriver from
+// TensorFlow. It should be run on the TPU_SYSTEM device where
+// ConfigureDistributedTpuOp was run.
+class ShutdownDistributedTpuOp : public OpKernel {
+ public:
+ explicit ShutdownDistributedTpuOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override;
+
+ ~ShutdownDistributedTpuOp() override {}
+
+ private:
+ // ShutdownDistributedTpuOp is neither copyable nor movable.
+ ShutdownDistributedTpuOp(const ShutdownDistributedTpuOp&) = delete;
+ ShutdownDistributedTpuOp& operator=(const ShutdownDistributedTpuOp&) = delete;
+};
+
+// The InitializeHostForDistributedTpu op is used to initialize the
+// TPUPlatform on a host in a distributed TPU system. It should be
+// run on every host containing TPU devices before any other Ops that use
+// TPU are run.
+class InitializeHostForDistributedTpuOp : public OpKernel {
+ public:
+ explicit InitializeHostForDistributedTpuOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {
+ ctx->GetAttr("enable_whole_mesh_compilations",
+ &enable_whole_mesh_compilations_)
+ .IgnoreError();
+ }
+
+ void Compute(OpKernelContext* ctx) override;
+
+ ~InitializeHostForDistributedTpuOp() override {}
+
+ private:
+ // InitializeHostForDistributedTpuOp is neither copyable nor movable.
+ InitializeHostForDistributedTpuOp(const InitializeHostForDistributedTpuOp&) =
+ delete;
+ InitializeHostForDistributedTpuOp& operator=(
+ const InitializeHostForDistributedTpuOp&) = delete;
+
+ bool enable_whole_mesh_compilations_ = false;
+};
+
+// The SetGlobalTPUArray op is used to initialize the TPUPlatform on a
+// host in a distributed TPU system. It should be run on every host
+// containing TPU devices before any other Ops that use TPU are run.
+class SetGlobalTPUArrayOp : public OpKernel {
+ public:
+ explicit SetGlobalTPUArrayOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override;
+
+ ~SetGlobalTPUArrayOp() override {}
+
+ private:
+ // SetGlobalTPUArrayOp is neither copyable nor movable.
+ SetGlobalTPUArrayOp(const SetGlobalTPUArrayOp&) = delete;
+ SetGlobalTPUArrayOp& operator=(const SetGlobalTPUArrayOp&) = delete;
+};
+
+// The DisconnectDistributedTpuChips op is used to disconnect all the chips on a
+// host from a running TPUDriver instance. It should be run on every host
+// containing TPU devices before the ShutdownDistributedTpuOp is run on
+// the TPU_SYSTEM.
+class DisconnectDistributedTpuChipsOp : public OpKernel {
+ public:
+ explicit DisconnectDistributedTpuChipsOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override;
+
+ ~DisconnectDistributedTpuChipsOp() override {}
+
+ private:
+ // DisconnectDistributedTpuChipsOp is neither copyable nor movable.
+ DisconnectDistributedTpuChipsOp(const DisconnectDistributedTpuChipsOp&) =
+ delete;
+ DisconnectDistributedTpuChipsOp& operator=(
+ const DisconnectDistributedTpuChipsOp&) = delete;
+};
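+
+// A sketch of the typical ordering implied by the comments above (illustrative
+// only): _ConfigureDistributedTPU runs once on the TPU_SYSTEM device, then
+// _InitializeHostForDistributedTPU and _SetGlobalTPUArray run on every host,
+// then _WaitForDistributedTPU blocks until the system is up; at teardown,
+// _DisconnectHostFromDistributedTPUSystem runs on every host before
+// _ShutdownDistributedTPU runs on the TPU_SYSTEM device.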
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_executable_info.proto b/tensorflow/core/tpu/kernels/tpu_executable_info.proto
new file mode 100644
index 0000000..359dad0
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_executable_info.proto
@@ -0,0 +1,94 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+
+import "tensorflow/compiler/xla/service/hlo.proto";
+import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+
+// A serialization of TPUExecutable. Only includes fields necessary to load
+// and execute a program on a worker node.
+message TPUExecutableInfoProto {
+ reserved 1;
+
+ // The shapes of the inputs and outputs.
+ repeated xla.ShapeProto input_shapes = 2;
+ reserved 7; // was input_shape
+ xla.ShapeProto output_shape = 3;
+
+ message UpdateIndexPair {
+ int32 index = 1;
+ bool updated = 2;
+ }
+
+ message ShapeIndex {
+ repeated int32 index = 1;
+ }
+
+ // Dynamic output indices indicate which outputs have dynamic dimensions.
+ repeated ShapeIndex dynamic_output_indices = 11;
+
+ // For each resource variable output, what was the index of the corresponding
+ // input and was it updated? The indices are sorted by input order.
+ repeated UpdateIndexPair variable_indices = 10;
+
+ // The shapes of the outputs when represented as Tensors. These may not
+ // match the output_shape values because we may flatten tensors to avoid
+ // excess padding.
+ repeated TensorShapeProto output_tensor_shapes = 8;
+
+ reserved 4;
+
+ // Optional session module for passing XLA computations between TPUCompileOp
+ // and TPUExecuteOp. This is needed to support the
+ // --xla_dump_hlo_snapshots flag.
+ xla.HloSnapshot session_module = 5;
+
+ // The physical device ids assigned to the replicated cores.
+ xla.DeviceAssignmentProto device_assignment = 6;
+}
+
+// Metadata for a data transfer between device and host.
+message TPUHostTransferProto {
+ enum TransferDirection {
+ NONE = 0;
+ DEVICE_TO_HOST = 1;
+ HOST_TO_DEVICE = 2;
+ }
+ // Channel identifier assigned by the compiler and used in host commands.
+ int64 channel = 1;
+ // Direction of the transfer operation.
+ TransferDirection direction = 2;
+ // Channel identifier provided by the XLA client.
+ string key = 3;
+ // Depth of nested loops for this transfer operation.
+ int64 nested_while_level = 4;
+ // Shape of the data to be transferred (including layout).
+ xla.ShapeProto shape = 5;
+ // Address of the device buffer in HBM (byte offset).
+ int64 buffer_offset = 6;
+ // Original data type for this host transfer before X64 rewrite.
+ xla.PrimitiveType original_type = 7;
+ // If this host transfer is a split X64 transfer, specifies whether this
+ // transfer is for the lower bits.
+ bool is_lower_bits = 8;
+}
+
+message TPUHostTransferInfoProto {
+ repeated TPUHostTransferProto host_transfers = 1;
+}
diff --git a/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h b/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h
new file mode 100644
index 0000000..cb6a82e
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
+
+typedef struct XLA_TpuMeshState XLA_TpuMeshState;
+
+// Creates a new TPU mesh state object.
+XLA_TpuMeshState* TpuMeshState_Create();
+
+// Deletes the given TPU `mesh_state` object. Once deleted, the object is
+// unusable.
+void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);
+
+// Returns a pointer to an opaque mesh data structure used internally.
+void* TpuMeshState_MeshCommonState(XLA_TpuMeshState* mesh_state);
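+
+// Typical lifecycle sketch (illustrative only):
+//
+//   XLA_TpuMeshState* mesh_state = TpuMeshState_Create();
+//   void* common_state = TpuMeshState_MeshCommonState(mesh_state);
+//   // ... use the mesh state, e.g. wrapped in a TpuMeshStateInterface ...
+//   TpuMeshState_Free(mesh_state);  // `mesh_state` must not be used afterwards.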
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h b/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h
new file mode 100644
index 0000000..34202a7
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_INTERFACE_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_INTERFACE_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
+
+namespace tensorflow {
+
+class TpuMeshCommonState;
+
+namespace tpu {
+
+const char kTpuMeshCommonStateResourceName[] = "tpu_mesh_common_state";
+
+class TpuMeshStateInterface : public tensorflow::ResourceBase {
+ public:
+ explicit TpuMeshStateInterface(XLA_TpuMeshState* handle)
+ : mesh_state_(handle) {
+ }
+
+ ~TpuMeshStateInterface() override {
+ if (mesh_state_ != nullptr) {
+ TpuMeshState_Free(mesh_state_);
+ }
+ }
+
+ static TpuMeshStateInterface* Create() {
+ return new TpuMeshStateInterface(TpuMeshState_Create());
+ }
+
+ const XLA_TpuMeshState* data() const { return mesh_state_; }
+
+ tensorflow::TpuMeshCommonState* mesh_common_state() const {
+ return static_cast<tensorflow::TpuMeshCommonState*>(
+ TpuMeshState_MeshCommonState(mesh_state_));
+ }
+
+  // Returns whether we should include the device assignment as a static field
+  // of the TPU program. This also determines whether the device assignment is
+  // included as part of the compilation cache key.
+ bool NeedsStaticDeviceAssignment(
+ const TPUCompileMetadataProto& metadata,
+ TpuCoreTypeEnum tpu_core_type) const {
+    // Static device assignment enables XLA to perform certain optimizations
+    // when all cores are used in the replicated computation.
+ return metadata.num_cores_per_replica() * metadata.num_replicas() ==
+ TpuTopology_AvailableCoreCount(mesh_state_,
+ tpu_core_type);
+ }
+
+ string DebugString() const override { return "TpuMeshStateInterface"; }
+
+ private:
+ XLA_TpuMeshState* mesh_state_;
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_INTERFACE_H_
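A brief sketch, not part of this patch, of how a configuration kernel could create the mesh state and register it with a ResourceMgr under the resource name declared above; the helper function name and error handling are hypothetical.

  tensorflow::Status CreateAndRegisterMeshState(tensorflow::ResourceMgr* rm) {
    tensorflow::tpu::TpuMeshStateInterface* mesh_state =
        tensorflow::tpu::TpuMeshStateInterface::Create();
    // On success the ResourceMgr holds the reference; later lookups go through
    // rm->Lookup() with the same container and resource name.
    return rm->Create(rm->default_container(),
                      tensorflow::tpu::kTpuMeshCommonStateResourceName,
                      mesh_state);
  }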
diff --git a/tensorflow/core/tpu/kernels/tpu_program.cc b/tensorflow/core/tpu/kernels/tpu_program.cc
new file mode 100644
index 0000000..7d89ad1
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_program.cc
@@ -0,0 +1,201 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_program.h"
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+
+namespace tensorflow {
+namespace tpu {
+
+namespace {
+
+namespace se_tpu = ::stream_executor::tpu;
+
+using stream_executor::port::StatusOr;
+using xla::Shape;
+
+StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
+ std::unique_ptr<xla::HloModuleGroup> module_group,
+ const XlaCompiler::CompilationResult& compilation_result,
+ const TPUCompileMetadataProto& metadata,
+ const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+ const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
+ const std::vector<std::vector<std::pair<int, bool>>>&
+ per_core_variable_indices,
+ const absl::optional<xla::DeviceAssignment>& device_assignment) {
+ VLOG(1) << "Run CompileAheadOfTime.";
+ TF_ASSIGN_OR_RETURN(TpuAotCompilationRequestProto aot_request,
+ CreateTpuAotCompilationRequest(
+ *module_group, compilation_result, metadata,
+ per_core_arg_shapes, per_core_output_shapes,
+ per_core_variable_indices, device_assignment));
+ se_tpu::SerializedProto serialized_aot_request =
+ se_tpu::SerializeProto(aot_request);
+ auto cleanup = gtl::MakeCleanup([serialized_aot_request] {
+ se_tpu::SerializedProto_Free(serialized_aot_request);
+ });
+
+ XLA_TpuProgram** xla_tpu_programs = nullptr;
+ size_t count = 0;
+ StatusHelper status;
+ VLOG(1) << "Run TpuCompile_CompileAheadOfTime.";
+ TpuCompile_CompileAheadOfTime(serialized_aot_request, &xla_tpu_programs,
+ &count, status.c_status);
+ VLOG(1) << "Run CompileAheadOfTime completed.";
+ if (!status.status().ok()) {
+ return status.status();
+ }
+ std::vector<XLA_TpuProgram*> tpu_programs(count, nullptr);
+ for (size_t i = 0; i < count; ++i) {
+ tpu_programs[i] = xla_tpu_programs[i];
+ }
+ delete[] xla_tpu_programs;
+ return tpu_programs;
+}
+
+StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
+ const TPUCompileMetadataProto& metadata,
+ const XlaCompiler::CompilationResult& compilation_result,
+ const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+ const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
+ const std::vector<std::vector<std::pair<int, bool>>>&
+ per_core_variable_indices,
+ const absl::optional<xla::DeviceAssignment>& device_assignment) {
+ VLOG(1) << "Compile Tpu programs.";
+ std::vector<std::unique_ptr<xla::HloModule>> hlo_modules;
+ auto status = CreateHloModules(metadata, compilation_result,
+ device_assignment, &hlo_modules);
+ if (!status.ok()) {
+ return status;
+ }
+
+ return CompileAheadOfTime(
+ absl::make_unique<xla::HloModuleGroup>(hlo_modules[0]->name(),
+ absl::MakeSpan(hlo_modules)),
+ compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes,
+ per_core_variable_indices, device_assignment);
+}
+
+} // namespace
+
+int64_t TpuProgram::program_size() const {
+ int64_t total_size = 0;
+ for (XLA_TpuProgram* tpu_program : tpu_programs_) {
+ total_size += TpuProgram_GetProgramSize(tpu_program);
+ }
+ return total_size;
+}
+
+bool TpuProgram::LogProgramMemorySummary() {
+ bool success = true;
+ for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
+ success &= TpuProgram_LogProgramMemorySummary(tpu_program);
+ }
+ return success;
+}
+
+void TpuProgram::UnloadAndDestroyPrograms() {
+ for (XLA_TpuProgram* tpu_program : tpu_programs_) {
+ StatusHelper status;
+ TpuProgram_UnloadAndDestroy(tpu_program, status.c_status);
+ auto s = status.status();
+ if (!s.ok()) {
+ LOG(ERROR) << "TpuProgram::UnloadPrograms(): " << s.ToString();
+ }
+ }
+ tpu_programs_.clear();
+}
+
+/*static*/ Status TpuProgram::Build(
+ const TPUCompileMetadataProto& metadata,
+ const tensorflow::XlaCompiler::CompilationResult& compilation_result,
+ const std::vector<ShardingAndIndex>& arg_core_mapping,
+ const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+ const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
+ TpuProgram* tpu_program) {
+ std::vector<std::vector<xla::Shape>> per_core_output_shapes(
+ metadata.num_cores_per_replica());
+ TF_RETURN_IF_ERROR(ComputeOutputShapesForEachCore(
+ metadata, compilation_result, &per_core_output_shapes));
+
+ std::vector<std::vector<std::pair<int, bool>>> per_core_variable_indices(
+ metadata.num_cores_per_replica());
+ std::vector<bool> may_modify_variables;
+ TF_RETURN_IF_ERROR(AddVariableUpdatesToCores(
+ metadata, compilation_result, arg_core_mapping, &may_modify_variables,
+ &per_core_output_shapes, &per_core_variable_indices));
+ TF_RET_CHECK(per_core_arg_shapes.size() == metadata.num_cores_per_replica());
+ TF_RET_CHECK(per_core_output_shapes.size() == per_core_arg_shapes.size());
+ TF_RET_CHECK(per_core_output_shapes.size() ==
+ per_core_variable_indices.size());
+ tpu_program->set_may_modify_variables(may_modify_variables);
+
+ // With shardable input/output pairs, XLA could generate separate
+ // sharding/unsharding programs along with the main program. The
+ // sharding/unsharding programs will be in nested entries of the AOT
+ // compilation result.
+ auto status_or = CompileAheadOfTime(
+ metadata, compilation_result, per_core_arg_shapes, per_core_output_shapes,
+ per_core_variable_indices, xla_device_assignment);
+
+ TF_ASSIGN_OR_RETURN(std::vector<XLA_TpuProgram*> xla_tpu_programs,
+ std::move(status_or));
+ // SPMD could return 1 result for all partitions.
+ TF_RET_CHECK(xla_tpu_programs.size() == 1 ||
+ xla_tpu_programs.size() == metadata.num_cores_per_replica());
+ tpu_program->set_tpu_programs(xla_tpu_programs);
+
+ // TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1.
+ TpuSerializedProto serialized_executable_info;
+ TpuProgram_GetExecutableInfo(xla_tpu_programs[0],
+ &serialized_executable_info);
+ TPUExecutableInfoProto executable_info =
+ se_tpu::DeserializeProto<TPUExecutableInfoProto>(
+ serialized_executable_info);
+ tpu_program->set_executable_info(executable_info);
+ StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
+
+ TPUHostTransferInfoProto host_transfer_info;
+ TpuSerializedProto serialized_host_transfer_info;
+ TpuProgram_GetHostTransferInfo(xla_tpu_programs[0],
+ &serialized_host_transfer_info);
+ if (serialized_host_transfer_info.size > 0) {
+ host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
+ serialized_host_transfer_info);
+ StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
+ }
+ tpu_program->set_host_transfer_info(host_transfer_info);
+
+ TpuSerializedProto serialized_hlo_metadata;
+ TpuProgram_GetHloMetadata(xla_tpu_programs[0], &serialized_hlo_metadata);
+ xla::HloProto hlo_metadata =
+ se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
+ tpu_program->set_hlo_metadata(hlo_metadata);
+ StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
+
+ return Status::OK();
+}
+
+} // namespace tpu
+} // namespace tensorflow
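A hedged usage sketch, not part of this patch: the helper below shows how a compile path could drive TpuProgram::Build and then release the underlying XLA_TpuProgram handles. The function name and logging are illustrative, and all inputs are assumed to come from an earlier XlaCompiler run.

  tensorflow::Status BuildAndUnload(
      const tensorflow::tpu::TPUCompileMetadataProto& metadata,
      const tensorflow::XlaCompiler::CompilationResult& compilation_result,
      const std::vector<tensorflow::tpu::ShardingAndIndex>& arg_core_mapping,
      const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
      const absl::optional<xla::DeviceAssignment>& device_assignment) {
    tensorflow::tpu::TpuProgram tpu_program;
    TF_RETURN_IF_ERROR(tensorflow::tpu::TpuProgram::Build(
        metadata, compilation_result, arg_core_mapping, per_core_arg_shapes,
        device_assignment, &tpu_program));
    VLOG(1) << "Built " << tpu_program.program_count() << " TPU program(s), "
            << tpu_program.program_size() << " bytes in total.";
    // The wrapper owns the per-core XLA_TpuProgram handles until released.
    tpu_program.UnloadAndDestroyPrograms();
    return tensorflow::Status::OK();
  }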
diff --git a/tensorflow/core/tpu/kernels/tpu_program.h b/tensorflow/core/tpu/kernels/tpu_program.h
new file mode 100644
index 0000000..aee55bd
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_program.h
@@ -0,0 +1,161 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_H_
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuAotCompilationOptions : public xla::AotCompilationOptions {
+ public:
+ explicit TpuAotCompilationOptions(int64 replica_count)
+ : num_cores_(0), replica_count_(replica_count) {}
+
+ // Returns the ID of the platform to which these options apply.
+ se::Platform::Id PlatformId() const override {
+ LOG(FATAL) << "Not implemented.";
+ return nullptr;
+ };
+
+ void set_num_cores(int64 tpu_cores) { num_cores_ = tpu_cores; }
+ int64 replica_count() const override { return replica_count_; }
+ int64 num_cores() const override { return num_cores_; }
+
+ void set_allow_separate_sharding_programs(bool allow) {
+ allow_separate_sharding_programs_ = allow;
+ }
+ bool allow_separate_sharding_programs() const {
+ return allow_separate_sharding_programs_;
+ }
+
+ const std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
+ shardable_value_update_pairs() const {
+ return shardable_value_update_pairs_;
+ }
+ void set_shardable_value_update_pairs(
+ std::vector<xla::HloModuleConfig::ShardableValueUpdatePair> pairs) {
+ shardable_value_update_pairs_ = std::move(pairs);
+ }
+
+ private:
+ int64 num_cores_;
+ int64 replica_count_;
+
+  // Whether to allow the compiler to create separate sharding and unsharding
+  // programs, and modify the original program's input/output sharded size.
+  // This is used for XLA-chosen sharding on parameters without an on-device
+  // loop: the caller can invoke sharding first, then (repeatedly) invoke the
+  // sharded main program, and finally invoke the unsharding program when it
+  // needs the full output.
+ bool allow_separate_sharding_programs_ = false;
+
+ // The list of input/output pairs in the main program that could be sharded.
+ std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
+ shardable_value_update_pairs_;
+};
+
+// An executable capable of being fed to a TPU device.
+class TpuProgram {
+ public:
+ using Status = ::stream_executor::port::Status;
+
+ virtual ~TpuProgram() = default;
+
+ static Status Build(
+ const TPUCompileMetadataProto& metadata,
+ const tensorflow::XlaCompiler::CompilationResult& compilation_result,
+ const std::vector<ShardingAndIndex>& arg_core_mapping,
+ const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+ const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
+ TpuProgram* tpu_program);
+
+ size_t program_count() const {
+ return tpu_programs_.size();
+ }
+
+ int64_t program_size() const;
+
+ bool LogProgramMemorySummary();
+
+ void UnloadAndDestroyPrograms();
+
+ const std::vector<bool>& may_modify_variables() const {
+ return may_modify_variables_;
+ }
+ void set_may_modify_variables(const std::vector<bool>& may_modify_variables) {
+ may_modify_variables_ = may_modify_variables;
+ }
+
+ const tf2xla::HostComputeMetadata& host_compute_metadata() const {
+ return host_compute_metadata_;
+ }
+ void set_host_compute_metadata(
+ const tf2xla::HostComputeMetadata& host_compute_metadata) {
+ host_compute_metadata_ = host_compute_metadata;
+ }
+
+ const std::vector<XLA_TpuProgram*>& tpu_programs() const {
+ return tpu_programs_;
+ }
+ void set_tpu_programs(std::vector<XLA_TpuProgram*> tpu_programs) {
+ tpu_programs_ = tpu_programs;
+ }
+
+ const TPUExecutableInfoProto& executable_info() const {
+ return executable_info_;
+ }
+ void set_executable_info(const TPUExecutableInfoProto& executable_info) {
+ executable_info_ = executable_info;
+ }
+
+ const TPUHostTransferInfoProto& host_transfer_info() const {
+ return host_transfer_info_;
+ }
+ void set_host_transfer_info(
+ const TPUHostTransferInfoProto& host_transfer_info) {
+ host_transfer_info_ = host_transfer_info;
+ }
+
+ const xla::HloProto& hlo_metadata() const { return hlo_metadata_; }
+ void set_hlo_metadata(const xla::HloProto& hlo_metadata) {
+ hlo_metadata_ = hlo_metadata;
+ }
+
+ private:
+ std::vector<bool> may_modify_variables_;
+ tf2xla::HostComputeMetadata host_compute_metadata_;
+
+ std::vector<XLA_TpuProgram*> tpu_programs_; // Not owned.
+ TPUExecutableInfoProto executable_info_;
+ TPUHostTransferInfoProto host_transfer_info_;
+ xla::HloProto hlo_metadata_;
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_H_
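For context, a short illustration (not part of this patch) of configuring TpuAotCompilationOptions for a hypothetical two-replica compile that permits separate sharding/unsharding programs; all values are placeholders.

  tensorflow::tpu::TpuAotCompilationOptions options(/*replica_count=*/2);
  options.set_num_cores(2);
  options.set_allow_separate_sharding_programs(true);
  // Shardable input/output pairs chosen by XLA, if any, would be supplied
  // via options.set_shardable_value_update_pairs(...).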
diff --git a/tensorflow/core/tpu/kernels/tpu_util.cc b/tensorflow/core/tpu/kernels/tpu_util.cc
new file mode 100644
index 0000000..5c286de
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_util.cc
@@ -0,0 +1,100 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_util.h"
+
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/platform/random.h"
+
+namespace tensorflow {
+namespace tpu {
+
+std::string SessionNameFromMetadata(const SessionMetadata* session_metadata) {
+ return session_metadata ? session_metadata->name() : "";
+}
+
+std::string ProtoKeyForComputation(const std::string& key, int core) {
+ return absl::StrCat(key, ":", core);
+}
+
+xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
+ const std::string& key) {
+ const std::vector<std::string> splits = absl::StrSplit(key, '|');
+ if (splits.size() == 1) {
+ // No guaranteed_const.
+ return TpuCompilationCacheKey(key);
+ } else if (splits.size() != 3) {
+    return errors::InvalidArgument("Invalid TPU compilation cache key: ", key);
+ }
+
+ TpuCompilationCacheKey parsed_key(splits.at(0));
+ parsed_key.has_guaranteed_const = true;
+ parsed_key.session_handle = splits.at(1);
+ const string fingerprint = splits.at(2);
+ parsed_key.guaranteed_const_fingerprint = [fingerprint] {
+ return fingerprint;
+ };
+ return parsed_key;
+}
+
+xla::CompileOnlyClient::AotXlaComputationInstance
+BuildAotXlaComputationInstance(
+ const XlaCompiler::CompilationResult& compilation_result) {
+ xla::CompileOnlyClient::AotXlaComputationInstance instance;
+ instance.computation = compilation_result.computation.get();
+ for (const xla::Shape& shape : compilation_result.xla_input_shapes) {
+ instance.argument_layouts.push_back(&shape);
+ }
+ instance.result_layout = &compilation_result.xla_output_shape;
+ return instance;
+}
+
+Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape) {
+ if (tensor.dtype() != DT_INT64 ||
+ !TensorShapeUtils::IsVector(tensor.shape())) {
+ return errors::InvalidArgument("Shape tensor must be an int64 vector.");
+ }
+ const int64 rank = tensor.NumElements();
+ auto tensor_dims = tensor.flat<int64>();
+ std::vector<int64> dims(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ dims[i] = tensor_dims(i);
+ }
+ return TensorShapeUtils::MakeShape(dims, shape);
+}
+
+Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
+ std::vector<TensorShape>* shapes) {
+ shapes->resize(dynamic_shapes.size());
+ for (int i = 0; i < dynamic_shapes.size(); ++i) {
+ TF_RETURN_IF_ERROR(
+ ShapeTensorToTensorShape(dynamic_shapes[i], &(*shapes)[i]));
+ }
+ return Status::OK();
+}
+
+Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
+ std::vector<TensorShape>* shapes) {
+ shapes->resize(dynamic_shapes.end() - dynamic_shapes.begin());
+ size_t i = 0;
+ for (auto& dynamic_shape : dynamic_shapes) {
+ TF_RETURN_IF_ERROR(
+ ShapeTensorToTensorShape(dynamic_shape.tensor(), &(*shapes)[i]));
+ ++i;
+ }
+ return Status::OK();
+}
+
+} // namespace tpu
+} // namespace tensorflow
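A small example, not part of this patch, of the key format that ParseCompilationCacheKey accepts, based only on the parsing logic above: either a bare key, or "<key>|<session_handle>|<guaranteed_const_fingerprint>". The literal values are made up.

  auto parsed = tensorflow::tpu::ParseCompilationCacheKey(
      "prog_fp|session_handle_0|guaranteed_const_fp");
  if (parsed.ok()) {
    // has_guaranteed_const is true and the fingerprint is produced lazily by
    // the stored callback.
    const std::string fingerprint =
        parsed.ValueOrDie().guaranteed_const_fingerprint();
  }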
diff --git a/tensorflow/core/tpu/kernels/tpu_util.h b/tensorflow/core/tpu/kernels/tpu_util.h
new file mode 100644
index 0000000..0ca94d0
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_util.h
@@ -0,0 +1,67 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Utility to get session_name from `SessionMetadata`. `SessionMetadata` may
+// be null.
+std::string SessionNameFromMetadata(const SessionMetadata* session_metadata);
+
+// Generates cache proto key for a given computation on a TPU core.
+std::string ProtoKeyForComputation(const std::string& key, int core);
+
+// Returns a TpuCompilationCacheKey parsed from given key or an error.
+xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
+ const std::string& key);
+
+xla::CompileOnlyClient::AotXlaComputationInstance
+BuildAotXlaComputationInstance(
+ const XlaCompiler::CompilationResult& compilation_result);
+
+// Returns true if TPU compilation is enabled.
+bool IsTpuCompilationEnabled();
+
+// Converts an int64 host memory `tensor` to a `shape`.
+Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape);
+
+Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
+ std::vector<TensorShape>* shapes);
+Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
+ std::vector<TensorShape>* shapes);
+
+// Given a tensor of `shape` and `type`, as what shape should it be stored on
+// the TPU device? This function transposes or flattens excessively-padded
+// tensors to rank 1, but leaves other tensor shapes alone.
+xla::StatusOr<xla::Shape> TpuShapeRepresentation(const TensorShape& shape,
+ DataType type,
+ bool use_fast_memory);
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
diff --git a/tensorflow/core/tpu/kernels/trace_util.h b/tensorflow/core/tpu/kernels/trace_util.h
new file mode 100644
index 0000000..4e0b7c9
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/trace_util.h
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
+
+#ifdef PLATFORM_GOOGLE
+#include "base/tracer.h"
+#else
+#undef TRACESTRING
+#define TRACESTRING(x)
+#undef TRACELITERAL
+#define TRACELITERAL(x)
+#endif
+
+#endif // TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
diff --git a/tensorflow/core/tpu/tpu_library_loader.cc b/tensorflow/core/tpu/tpu_library_loader.cc
index 3bc835c..c89de14 100644
--- a/tensorflow/core/tpu/tpu_library_loader.cc
+++ b/tensorflow/core/tpu/tpu_library_loader.cc
@@ -64,13 +64,20 @@
}
Status InitializeTpuLibrary(void* library_handle) {
+ bool shared_object_loaded = true;
if (library_handle == nullptr) {
library_handle = dlopen(nullptr, RTLD_LAZY);
+ shared_object_loaded = false;
}
TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(library_handle));
TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
+ if (shared_object_loaded) {
+ // Initialize TPU platform when the platform code is loaded from a library.
+ InitializeApiFn()->TfTpu_InitializeFn();
+ }
+
return Status::OK();
}
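A minimal call sketch, not part of this patch, of the two initialization paths this hunk distinguishes: loading TPU support from a shared object versus using symbols already linked into the process. The library name and the tensorflow::tpu namespace for InitializeTpuLibrary are assumptions for illustration.

  void* handle = dlopen("libtpu.so", RTLD_NOW);
  if (handle != nullptr) {
    // Shared-object path: this path also initializes the TPU platform.
    TF_CHECK_OK(tensorflow::tpu::InitializeTpuLibrary(handle));
  } else {
    // In-process path: resolve the struct fns against the current binary.
    TF_CHECK_OK(tensorflow::tpu::InitializeTpuLibrary(nullptr));
  }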
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
new file mode 100644
index 0000000..52ca40f
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -0,0 +1,234 @@
+# Description: StreamExecutor Interface for TPUs
+
+package(
+ default_visibility = ["//tensorflow/core/tpu:__subpackages__"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "tpu_executor_c_api_hdrs",
+ hdrs = ["tpu_executor_c_api.h"],
+ deps = [
+ "//tensorflow/c:tf_attrtype",
+ "//tensorflow/c:tf_datatype",
+ "//tensorflow/c:tf_status",
+ ],
+)
+
+cc_library(
+ name = "tpu_node_context_c_api_hdrs",
+ hdrs = ["tpu_node_context_c_api.h"],
+ deps = [
+ ":tpu_executor_c_api_hdrs",
+ ],
+)
+
+cc_library(
+ name = "status_helper",
+ hdrs = ["status_helper.h"],
+ deps = [
+ ":tpu_executor_c_api_hdrs",
+ "//tensorflow/core/platform:status",
+ ],
+)
+
+cc_library(
+ name = "c_api_conversions",
+ hdrs = ["c_api_conversions.h"],
+ deps = [
+ ":tpu_executor_c_api_hdrs",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto_cc",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/stream_executor:stream",
+ "@com_google_absl//absl/container:inlined_vector",
+ ],
+)
+
+cc_library(
+ name = "proto_helper",
+ srcs = ["proto_helper.cc"],
+ hdrs = ["proto_helper.h"],
+ deps = ["//tensorflow/core:lib"],
+)
+
+cc_library(
+ name = "tpu_stream",
+ hdrs = ["tpu_stream.h"],
+ deps = [
+ ":tpu_executor_c_api_hdrs",
+ "//tensorflow/stream_executor:stream",
+ "//tensorflow/stream_executor/lib",
+ ],
+)
+
+cc_library(
+ name = "tpu_timer",
+ hdrs = ["tpu_timer.h"],
+ deps = [
+ ":tpu_executor_c_api_hdrs",
+ "//tensorflow/core/platform:types",
+ "//tensorflow/stream_executor:stream",
+ ],
+)
+
+cc_library(
+ name = "tpu_executor",
+ srcs = ["tpu_executor.cc"],
+ hdrs = ["tpu_executor.h"],
+ deps = [
+ ":c_api_conversions",
+ ":status_helper",
+ ":tpu_executor_c_api_hdrs",
+ ":tpu_executor_interface",
+ ":tpu_platform",
+ ":tpu_platform_interface",
+ ":tpu_stream",
+ ":tpu_timer",
+ "//tensorflow/c:tf_status",
+ "//tensorflow/core:lib",
+ "//tensorflow/stream_executor:stream",
+ "//tensorflow/stream_executor/lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ ],
+)
+
+cc_library(
+ name = "tpu_executor_hdrs",
+ hdrs = ["tpu_executor.h"],
+ deps = [
+ ":tpu_executor_c_api_hdrs",
+ ":tpu_executor_interface",
+ ":tpu_platform_hdrs",
+ ":tpu_platform_interface",
+ "//tensorflow/core/platform:types",
+ "//tensorflow/stream_executor:stream_header",
+ "//tensorflow/stream_executor/lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ ],
+)
+
+cc_library(
+ name = "tpu_platform_hdrs",
+ hdrs = ["tpu_platform.h"],
+ deps = [
+ ":tpu_executor_c_api_hdrs",
+ ":tpu_platform_interface",
+ "//tensorflow/core/platform:types",
+ "//tensorflow/stream_executor:stream_header",
+ "@com_google_absl//absl/container:flat_hash_map",
+ ],
+)
+
+cc_library(
+ name = "tpu_node_context",
+ srcs = ["tpu_node_context.cc"],
+ hdrs = ["tpu_node_context.h"],
+ deps = [
+ ":status_helper",
+ ":tpu_executor_c_api_hdrs",
+ ":tpu_node_context_c_api_hdrs",
+ ":tpu_platform_interface",
+ ":tpu_transfer_manager",
+ "//tensorflow/compiler/xla/service",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/service:stream_pool",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ "//tensorflow/core:framework",
+ "//tensorflow/stream_executor:device_memory_allocator",
+ "//tensorflow/stream_executor/lib",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+cc_library(
+ name = "tpu_platform",
+ srcs = ["tpu_platform.cc"],
+ hdrs = ["tpu_platform.h"],
+ deps = [
+ ":status_helper",
+ ":tpu_executor_c_api_hdrs",
+ ":tpu_executor_hdrs",
+ ":tpu_platform_interface",
+ "//tensorflow/c:tf_status",
+ "//tensorflow/core/platform:types",
+ "//tensorflow/stream_executor:stream",
+ "@com_google_absl//absl/container:flat_hash_map",
+ ],
+ alwayslink = True,
+)
+
+cc_library(
+ name = "tpu_transfer_manager",
+ srcs = ["tpu_transfer_manager_registration.cc"],
+ deps = [
+ ":tpu_platform",
+ ":tpu_transfer_manager_base",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ ],
+)
+
+cc_library(
+ name = "tpu_transfer_manager_base",
+ srcs = ["tpu_transfer_manager.cc"],
+ hdrs = ["tpu_transfer_manager.h"],
+ deps = [
+ ":c_api_conversions",
+ ":proto_helper",
+ ":status_helper",
+ ":tpu_executor_c_api_hdrs",
+ ":tpu_platform",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto_cc",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ "//tensorflow/stream_executor:stream",
+ ],
+)
+
+cc_library(
+ name = "tpu_computation_placer",
+ srcs = ["tpu_computation_placer.cc"],
+ hdrs = ["tpu_computation_placer.h"],
+ deps = [
+ ":tpu_executor_c_api_hdrs",
+ ":tpu_platform",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/service:computation_placer",
+ ],
+ alwayslink = True,
+)
+
+cc_library(
+ name = "tpu_platform_interface",
+ srcs = ["tpu_platform_interface.cc"],
+ hdrs = ["tpu_platform_interface.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core/platform:types",
+ "//tensorflow/stream_executor:multi_platform_manager",
+ "//tensorflow/stream_executor:stream_executor_headers",
+ ],
+)
+
+cc_library(
+ name = "tpu_stream_interface",
+ hdrs = ["tpu_stream_interface.h"],
+ visibility = ["//visibility:public"],
+ deps = ["//tensorflow/stream_executor:stream_executor_internal"],
+)
+
+cc_library(
+ name = "tpu_executor_interface",
+ hdrs = ["tpu_executor_interface.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tpu_platform_interface",
+ "//tensorflow/core/platform:errors",
+ "//tensorflow/stream_executor:stream_executor_internal",
+ "//tensorflow/stream_executor:stream_header",
+ ],
+)
diff --git a/tensorflow/stream_executor/tpu/c_api_conversions.h b/tensorflow/stream_executor/tpu/c_api_conversions.h
new file mode 100644
index 0000000..1bb9ece
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/c_api_conversions.h
@@ -0,0 +1,115 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
+
+#include "absl/container/inlined_vector.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+class TpuConversions {
+ public:
+ static stream_executor::DeviceMemoryBase
+ SE_DeviceMemoryBaseToDeviceMemoryBase(SE_DeviceMemoryBase se_base) {
+ stream_executor::DeviceMemoryBase base(se_base.opaque, se_base.size);
+ base.SetPayload(se_base.payload);
+ return base;
+ }
+
+ static SE_DeviceMemoryBase DeviceMemoryBaseToSE_DeviceMemoryBase(
+ const stream_executor::DeviceMemoryBase& base) {
+ SE_DeviceMemoryBase se_base;
+ se_base.opaque = const_cast<void*>(base.opaque());
+ se_base.payload = base.payload();
+ se_base.size = base.size();
+ return se_base;
+ }
+
+ static xla::Shape CShapeToXlaShape(XLA_Shape* shape) {
+ xla::ShapeProto p;
+ p.ParseFromArray(shape->bytes, shape->size);
+ return xla::Shape(p);
+ }
+
+ static void XlaShapeToCShape(const xla::Shape& xla_shape,
+ XLA_Shape* c_shape) {
+ xla::ShapeProto p = xla_shape.ToProto();
+ std::string p_str = p.SerializeAsString();
+ c_shape->bytes = new char[p_str.size()];
+ c_shape->size = p_str.size();
+ memcpy(c_shape->bytes, p_str.data(), p_str.size());
+ }
+
+ static void XLAShapedBufferToCShapedBuffer(
+ const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) {
+ XlaShapeToCShape(buffer.on_host_shape(), &c_device_buffer->on_host_shape);
+ XlaShapeToCShape(buffer.on_device_shape(),
+ &c_device_buffer->on_device_shape);
+ c_device_buffer->device_ordinal = buffer.device_ordinal();
+ absl::InlinedVector<SE_DeviceMemoryBase, 2> bases;
+ for (auto& pair : buffer.buffers()) {
+ bases.push_back(DeviceMemoryBaseToSE_DeviceMemoryBase(pair.second));
+ }
+ c_device_buffer->count = bases.size();
+ c_device_buffer->bases = new SE_DeviceMemoryBase[bases.size()];
+ for (int i = 0; i < bases.size(); ++i) {
+ c_device_buffer->bases[i] = bases[i];
+ }
+ }
+
+ static void XLALiteralToCLiteral(const xla::LiteralSlice& literal,
+ XLA_Literal* c_literal) {
+ XlaShapeToCShape(literal.shape(), &c_literal->shape);
+ auto shapes = xla::ShapeUtil::GetLeafShapes(literal.shape());
+ c_literal->buffers = new char*[shapes.size()];
+ c_literal->sizes = new size_t[shapes.size()];
+ c_literal->count = shapes.size();
+ for (int i = 0; i < shapes.size(); ++i) {
+ c_literal->buffers[i] = reinterpret_cast<char*>(
+ const_cast<void*>(literal.untyped_data(shapes[i].index)));
+ c_literal->sizes[i] = literal.size_bytes(shapes[i].index);
+ }
+ }
+
+ static xla::MutableBorrowingLiteral CLiteralToXLALiteral(
+ XLA_Literal* c_literal) {
+ xla::Shape shape = CShapeToXlaShape(&c_literal->shape);
+ LOG(INFO) << "Shape: " << shape.DebugString();
+ return xla::MutableBorrowingLiteral(
+ absl::MakeSpan(c_literal->buffers, c_literal->count), shape);
+ }
+
+ static void CShapeCleanup(XLA_Shape* c_shape) { delete[] c_shape->bytes; }
+
+ static void CLiteralCleanup(XLA_Literal* c_literal) {
+ delete[] c_literal->buffers;
+ delete[] c_literal->sizes;
+ CShapeCleanup(&c_literal->shape);
+ }
+
+ static void CShapedBufferCleanup(XLA_ShapedBuffer* c_buffer) {
+ CShapeCleanup(&c_buffer->on_device_shape);
+ CShapeCleanup(&c_buffer->on_host_shape);
+ delete[] c_buffer->bases;
+ }
+};
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
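A short sketch (not part of this patch) showing that XlaShapeToCShape heap-allocates the serialized shape bytes, so every converted shape should be released with CShapeCleanup; the shape used here is arbitrary.

  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {128, 256});
  XLA_Shape c_shape;
  TpuConversions::XlaShapeToCShape(shape, &c_shape);
  // ... hand &c_shape across the C API boundary ...
  xla::Shape round_tripped = TpuConversions::CShapeToXlaShape(&c_shape);
  TpuConversions::CShapeCleanup(&c_shape);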
diff --git a/tensorflow/stream_executor/tpu/proto_helper.cc b/tensorflow/stream_executor/tpu/proto_helper.cc
new file mode 100644
index 0000000..db663c6
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/proto_helper.cc
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+
+extern "C" {
+
+void StreamExecutor_Tpu_FreeSerializedProto(const TpuSerializedProto* proto) {
+ CHECK_NE(proto, nullptr);
+ CHECK_NE(proto->bytes, nullptr);
+ CHECK_GT(proto->size, 0);
+ delete[] proto->bytes;
+}
+
+} // extern "C"
diff --git a/tensorflow/stream_executor/tpu/proto_helper.h b/tensorflow/stream_executor/tpu/proto_helper.h
new file mode 100644
index 0000000..3bd2b09
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/proto_helper.h
@@ -0,0 +1,85 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
+
+#include <cstddef>
+
+#include "tensorflow/core/platform/logging.h"
+
+extern "C" {
+
+typedef struct TpuSerializedProto {
+ const char* bytes;
+ size_t size;
+} TpuSerializedProto;
+
+void StreamExecutor_Tpu_FreeSerializedProto(const TpuSerializedProto* proto);
+
+} // extern "C"
+
+namespace stream_executor {
+namespace tpu {
+
+using SerializedProto = TpuSerializedProto;
+
+// Serializes a proto and puts the result in the given SerializedProto* argument.
+//
+// Users should call SerializedProto_Free on `serialized_proto` afterwards.
+template <class Proto>
+inline void SerializeProto(const Proto& proto,
+ SerializedProto* serialized_proto) {
+ auto size = proto.ByteSizeLong();
+ auto bytes = new char[size];
+ CHECK(proto.SerializeToArray(bytes, size));
+ serialized_proto->size = size;
+ serialized_proto->bytes = bytes;
+}
+
+// Serializes a proto and returns the result as a SerializedProto value.
+//
+// Users should call SerializedProto_Free on the return value afterwards.
+template <class Proto>
+inline SerializedProto SerializeProto(const Proto& proto) {
+ SerializedProto serialized_proto;
+ SerializeProto(proto, &serialized_proto);
+ return serialized_proto;
+}
+
+// Deserializes a buffer and returns the corresponding proto. If the buffer is
+// empty, returns an empty proto.
+template <class Proto>
+inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
+ Proto proto;
+ if (serialized_proto.bytes != nullptr) {
+ CHECK_GT(serialized_proto.size, 0);
+ CHECK(proto.ParseFromArray(serialized_proto.bytes, serialized_proto.size))
+ << "Invalid buffer, failed to deserialize buffer.";
+ }
+ return proto;
+}
+
+// Releases the memory allocated for serialized protos.
+inline void SerializedProto_Free(const SerializedProto& serialized_proto) {
+ CHECK_NE(serialized_proto.bytes, nullptr);
+ CHECK_GT(serialized_proto.size, 0);
+ delete[] serialized_proto.bytes;
+}
+
+} // namespace tpu
+} // namespace stream_executor
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
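A minimal usage sketch, not part of this patch, of the serialize/free call pattern documented above; the proto type is an arbitrary example and assumes the corresponding generated header is included.

  namespace se_tpu = ::stream_executor::tpu;

  tensorflow::tpu::TPUCompileMetadataProto metadata;
  se_tpu::SerializedProto serialized = se_tpu::SerializeProto(metadata);
  // ... pass `serialized` across the C API; the callee parses or copies it ...
  auto parsed = se_tpu::DeserializeProto<
      tensorflow::tpu::TPUCompileMetadataProto>(serialized);
  se_tpu::SerializedProto_Free(serialized);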
diff --git a/tensorflow/stream_executor/tpu/status_helper.h b/tensorflow/stream_executor/tpu/status_helper.h
new file mode 100644
index 0000000..8fcf302
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/status_helper.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
+
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+struct StatusHelper {
+ StatusHelper() : c_status(TpuStatus_New()) {}
+ ~StatusHelper() { TpuStatus_Free(c_status); }
+ bool ok() { return TpuStatus_Code(c_status) == 0; }
+ tensorflow::Status status() {
+ if (!ok()) {
+ return tensorflow::Status(
+ tensorflow::error::Code(TpuStatus_Code(c_status)),
+ TpuStatus_Message(c_status));
+ } else {
+ return tensorflow::Status::OK();
+ }
+ }
+ SE_Status* c_status;
+};
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
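The typical pattern, used throughout this change, is to allocate a StatusHelper, hand its c_status to a C API call, and convert the result back to a tensorflow::Status. The wrapper below is only a hypothetical illustration of that pattern (assuming <functional> is available), not an API added by this patch.

  inline tensorflow::Status CallWithCStatus(
      const std::function<void(SE_Status*)>& c_call) {
    StatusHelper status;
    c_call(status.c_status);  // e.g. invokes a TpuExecutor_* C function.
    return status.status();
  }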
diff --git a/tensorflow/stream_executor/tpu/tpu_computation_placer.cc b/tensorflow/stream_executor/tpu/tpu_computation_placer.cc
new file mode 100644
index 0000000..660b446
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_computation_placer.cc
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h"
+
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+
+template <typename T>
+using StatusOr = TpuComputationPlacer::StatusOr<T>;
+
+TpuComputationPlacer::TpuComputationPlacer() {
+ placer_ = TpuComputationPlacer_New();
+}
+
+TpuComputationPlacer::~TpuComputationPlacer() {
+ TpuComputationPlacer_Free(placer_);
+}
+
+StatusOr<int> TpuComputationPlacer::DeviceId(int replica, int computation,
+ int replica_count,
+ int computation_count) {
+ LOG(FATAL) << "Unimplemented.";
+}
+
+StatusOr<xla::DeviceAssignment> TpuComputationPlacer::AssignDevices(
+ int replica_count, int computation_count) {
+ LOG(FATAL) << "Unimplemented.";
+}
+
+static std::unique_ptr<xla::ComputationPlacer> CreateTpuComputationPlacer() {
+ return std::make_unique<TpuComputationPlacer>();
+}
+
+static bool InitModule() {
+ xla::ComputationPlacer::RegisterComputationPlacer(
+ tensorflow::TpuPlatform::kId, CreateTpuComputationPlacer);
+ return true;
+}
+static bool module_initialized = InitModule();
diff --git a/tensorflow/stream_executor/tpu/tpu_computation_placer.h b/tensorflow/stream_executor/tpu/tpu_computation_placer.h
new file mode 100644
index 0000000..c8f4c9e
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_computation_placer.h
@@ -0,0 +1,41 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_
+
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+class TpuComputationPlacer : public xla::ComputationPlacer {
+ public:
+ template <typename T>
+ using StatusOr = xla::StatusOr<T>;
+
+ TpuComputationPlacer();
+ ~TpuComputationPlacer() override;
+
+ StatusOr<int> DeviceId(int replica, int computation, int replica_count,
+ int computation_count) override;
+
+ StatusOr<xla::DeviceAssignment> AssignDevices(int replica_count,
+ int computation_count) override;
+
+ private:
+ XLA_ComputationPlacer* placer_;
+};
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_executor.cc b/tensorflow/stream_executor/tpu/tpu_executor.cc
new file mode 100644
index 0000000..92808936
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_executor.cc
@@ -0,0 +1,355 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_executor.h"
+
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_stream.h"
+#include "tensorflow/stream_executor/tpu/tpu_timer.h"
+
+using stream_executor::DeviceMemoryBase;
+
+namespace tensorflow {
+
+namespace {
+using ::stream_executor::port::Status;
+} // namespace
+
+TpuExecutor::~TpuExecutor() { TpuExecutor_Free(executor_); }
+
+Status TpuExecutor::Init(int device_ordinal,
+ ::stream_executor::DeviceOptions device_options) {
+ StatusHelper status;
+ SE_DeviceOptions* options =
+ TpuExecutor_NewDeviceOptions(device_options.flags());
+ TpuExecutor_Init(executor_, device_ordinal, options, status.c_status);
+ TpuExecutor_FreeDeviceOptions(options);
+ return status.status();
+}
+
+int TpuExecutor::PlatformDeviceCount() {
+ return TpuExecutor_PlatformDeviceCount(executor_);
+}
+
+void TpuExecutor::SyncAndForgetFailedStreams() {
+ TpuExecutor_SyncAndForgetFailedStreams(executor_);
+}
+
+bool TpuExecutor::SynchronizeAllActivity() {
+ return TpuExecutor_SynchronizeAllActivity(executor_);
+}
+
+Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
+ StatusHelper status;
+ TpuExecutor_BlockHostUntilDone(
+ executor_, stream_map().at(stream->implementation()), status.c_status);
+ return status.status();
+}
+
+Status TpuExecutor::BlockUntilDoneOrFailed() {
+ StatusHelper status;
+ TpuExecutor_BlockUntilDoneOrFailed(executor_, status.c_status);
+ return status.status();
+}
+
+Status TpuExecutor::GetStatus(Stream* stream) {
+ StatusHelper status;
+ TpuExecutor_GetStatus(executor_, stream_map().at(stream->implementation()),
+ status.c_status);
+ return status.status();
+}
+
+bool TpuExecutor::AllocateStream(Stream* stream) {
+ return TpuExecutor_AllocateStream(executor_,
+ stream_map().at(stream->implementation()));
+}
+
+void TpuExecutor::DeallocateStream(Stream* stream) {
+ TpuExecutor_DeallocateStream(executor_,
+ stream_map().at(stream->implementation()));
+ stream_map().erase(stream->implementation());
+}
+
+bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
+ return TpuExecutor_CreateStreamDependency(
+ executor_, stream_map().at(dependent->implementation()),
+ stream_map().at(other->implementation()));
+}
+
+Status TpuExecutor::AllocateEvent(Event* event) { return Status::OK(); }
+
+Status TpuExecutor::DeallocateEvent(Event* event) { return Status::OK(); }
+
+// AllocateTimer/DeallocateTimer have no specialization.
+bool TpuExecutor::AllocateTimer(Timer* timer) { return true; }
+
+void TpuExecutor::DeallocateTimer(Timer* timer) {}
+
+bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
+ return TpuExecutor_StartTimer(executor_,
+ stream_map().at(stream->implementation()),
+ timer_map_.at(timer->implementation()));
+}
+
+bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
+ return TpuExecutor_StopTimer(executor_,
+ stream_map().at(stream->implementation()),
+ timer_map_.at(timer->implementation()));
+}
+
+stream_executor::Event::Status TpuExecutor::PollForEventStatus(
+ stream_executor::Event* event) {
+ return stream_executor::Event::Status(TpuExecutor_PollForEventStatus(
+ executor_, event_map_.at(event->implementation())));
+}
+
+Status TpuExecutor::RecordEvent(Stream* stream,
+ ::stream_executor::Event* event) {
+ StatusHelper status;
+ TpuExecutor_RecordEvent(executor_, stream_map().at(stream->implementation()),
+ event_map_.at(event->implementation()),
+ status.c_status);
+ return status.status();
+}
+
+Status TpuExecutor::WaitForEvent(Stream* stream,
+ ::stream_executor::Event* event) {
+ StatusHelper status;
+ TpuExecutor_WaitForEvent(executor_, stream_map().at(stream->implementation()),
+ event_map_.at(event->implementation()),
+ status.c_status);
+ return status.status();
+}
+
+// Implementations for Timer, Stream, Event
+// We need to map these implementations to internal equivalents -- thus we
+// allocate the internal Timer, Stream and Event operations here, and map
+// the implementations to the internal values. The "wrapper" interfaces are
+// responsible for deallocating the internal value when they are destroyed.
+
+// Called by Timer::Timer
+std::unique_ptr<::stream_executor::internal::TimerInterface>
+TpuExecutor::GetTimerImplementation() {
+ SE_Timer* tpu_timer = TpuTimer_New(executor_);
+ auto ptr = absl::make_unique<TpuTimer>(tpu_timer);
+ timer_map_[ptr.get()] = tpu_timer;
+ return ptr;
+}
+
+// Called by Stream::Stream
+std::unique_ptr<::stream_executor::internal::StreamInterface>
+TpuExecutor::GetStreamImplementation() {
+ SE_Stream* tpu_stream = TpuStream_New(executor_);
+ auto ptr = absl::make_unique<TpuStream>(tpu_stream);
+ stream_map()[ptr.get()] = tpu_stream;
+ return ptr;
+}
+
+// Called by Event::Event
+std::unique_ptr<::stream_executor::internal::EventInterface>
+TpuExecutor::CreateEventImplementation() {
+ SE_Event* tpu_event = TpuEvent_New(executor_);
+ auto ptr = absl::make_unique<TpuEvent>(tpu_event);
+ event_map_[ptr.get()] = tpu_event;
+ return ptr;
+}
+
+DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
+ SE_DeviceMemoryBase se_base =
+ TpuExecutor_Allocate(executor_, size, memory_space);
+ return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
+}
+
+void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
+ SE_DeviceMemoryBase se_base =
+ TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
+ TpuExecutor_Deallocate(executor_, &se_base);
+}
+
+void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
+ SE_DeviceMemoryBase se_base =
+ TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
+ TpuExecutor_Deallocate(executor_, &se_base);
+}
+
+bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
+ int64_t _free;
+ int64_t _total;
+ if (TpuExecutor_DeviceMemoryUsage(executor_, &_free, &_total)) {
+ *free = _free;
+ *total = _total;
+ return true;
+ }
+ return false;
+}
+
+absl::optional<stream_executor::AllocatorStats>
+TpuExecutor::GetAllocatorStats() {
+ SE_AllocatorStats c_stats;
+ if (TpuExecutor_GetAllocatorStats(executor_, &c_stats)) {
+ ::stream_executor::AllocatorStats stats;
+ stats.num_allocs = c_stats.num_allocs;
+ stats.bytes_in_use = c_stats.bytes_in_use;
+ stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
+ stats.largest_alloc_size = c_stats.largest_alloc_size;
+ if (c_stats.has_bytes_limit) {
+ stats.bytes_limit = c_stats.bytes_limit;
+ }
+ stats.bytes_reserved = c_stats.bytes_reserved;
+ stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
+ if (c_stats.has_bytes_reservable_limit) {
+ stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
+ }
+ stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
+ return stats;
+ }
+ return {};
+}
+
+Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) {
+ StatusHelper status;
+ TpuExecutor_WaitForInfeedReady(executor_, infeed_queue_index,
+ status.c_status);
+ return status.status();
+}
+
+Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) {
+ StatusHelper status;
+ TpuExecutor_WaitForOutfeedReady(executor_, outfeed_queue_index,
+ status.c_status);
+ return status.status();
+}
+
+void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index,
+ absl::Span<uint8> bytes, StatusCallback done) {
+ StatusHelper status;
+ TpuExecutor_DequeueOutfeed(executor_, outfeed_queue_index, bytes.data(),
+ bytes.size(), status.c_status);
+ done(status.status());
+}
+
+Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index,
+ absl::Span<const uint8> bytes) {
+ StatusHelper status;
+ TpuExecutor_EnqueueInfeed(executor_, infeed_queue_index, bytes.data(),
+ bytes.size(), status.c_status);
+ return status.status();
+}
+
+bool TpuExecutor::Memcpy(Stream* stream, void* host_dst,
+ const ::stream_executor::DeviceMemoryBase& device_src,
+ uint64 size) {
+ SE_DeviceMemoryBase se_base =
+ TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
+ return TpuExecutor_MemcpyToHost(executor_,
+ stream_map().at(stream->implementation()),
+ host_dst, &se_base, size);
+}
+
+bool TpuExecutor::Memcpy(Stream* stream,
+ ::stream_executor::DeviceMemoryBase* device_dst,
+ const void* host_src, uint64 size) {
+ SE_DeviceMemoryBase se_base =
+ TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
+ return TpuExecutor_MemcpyFromHost(executor_,
+ stream_map().at(stream->implementation()),
+ &se_base, host_src, size);
+}
+
+Status TpuExecutor::SynchronousMemcpy(
+ ::stream_executor::DeviceMemoryBase* device_dst, const void* host_src,
+ uint64 size) {
+ StatusHelper status;
+ SE_DeviceMemoryBase se_base =
+ TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
+ TpuExecutor_SynchronousMemcpyFromHost(executor_, &se_base, host_src, size,
+ status.c_status);
+ return status.status();
+}
+
+Status TpuExecutor::SynchronousMemcpy(
+ void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
+ uint64 size) {
+ StatusHelper status;
+ SE_DeviceMemoryBase se_base =
+ TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
+ TpuExecutor_SynchronousMemcpyToHost(executor_, host_dst, &se_base, size,
+ status.c_status);
+ return status.status();
+}
+
+Status TpuExecutor::SynchronousMemcpyDeviceToDevice(
+ ::stream_executor::DeviceMemoryBase* device_dst,
+ const ::stream_executor::DeviceMemoryBase& device_src, uint64 size) {
+ return ::stream_executor::port::UnimplementedError(
+ "This operation not supported on TPU");
+}
+
+bool TpuExecutor::MemcpyDeviceToDevice(
+ Stream* stream, ::stream_executor::DeviceMemoryBase* gpu_dst,
+ const ::stream_executor::DeviceMemoryBase& host_src, uint64 size) {
+ LOG(FATAL) << __func__ << " not supported on TpuExecutor";
+}
+
+struct HostCallbackContext {
+ std::function<Status()> callback;
+};
+
+SE_Status* HostCallbackTrampoline(void* ctx) {
+ HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
+ Status status = host_ctx->callback();
+ SE_Status* c_status =
+ TpuStatus_Create(status.code(), status.error_message().c_str());
+ delete host_ctx;
+ return c_status;
+}
+
+bool TpuExecutor::HostCallback(Stream* stream,
+ std::function<Status()> callback) {
+ HostCallbackContext* ctx = new HostCallbackContext{callback};
+ return TpuExecutor_HostCallback(executor_,
+ stream_map().at(stream->implementation()),
+ &HostCallbackTrampoline, ctx);
+}
+
+TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
+TpuExecutor::CreateDeviceDescription() const {
+ StatusHelper status;
+ SE_DeviceDescription* description = TpuDeviceDescription_New();
+ auto cleanup = tensorflow::gtl::MakeCleanup(
+ [description]() { TpuDeviceDescription_Free(description); });
+ TpuExecutor_CreateDeviceDescription(executor_, description, status.c_status);
+ if (status.status().ok()) {
+ stream_executor::internal::DeviceDescriptionBuilder builder;
+ CHECK_NE(description->device_vendor, nullptr);
+ builder.set_device_vendor(description->device_vendor);
+ builder.set_name(description->name);
+ builder.set_clock_rate_ghz(description->clock_rate_ghz);
+ builder.set_core_count(description->core_count);
+ builder.set_ecc_enabled(description->ecc_enabled);
+ builder.set_device_memory_size(description->device_memory_size);
+ builder.set_platform_version(description->platform_version);
+ return builder.Build();
+ }
+ return status.status();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_executor.h b/tensorflow/stream_executor/tpu/tpu_executor.h
new file mode 100644
index 0000000..5f36642
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_executor.h
@@ -0,0 +1,241 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/device_options.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/temporary_device_memory.h"
+#include "tensorflow/stream_executor/timer.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+
+class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
+ public:
+ using Status = ::stream_executor::port::Status;
+ template <typename T>
+ using StatusOr = ::stream_executor::port::StatusOr<T>;
+ using StatusCallback = std::function<void(const Status&)>;
+ using Stream = ::stream_executor::Stream;
+ using Event = ::stream_executor::Event;
+ using Timer = ::stream_executor::Timer;
+ using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase;
+ using StreamInterface = ::stream_executor::internal::StreamInterface;
+ using StreamExecutorInterface =
+ ::stream_executor::internal::StreamExecutorInterface;
+
+ using EventMap =
+ absl::flat_hash_map<stream_executor::internal::EventInterface*,
+ SE_Event*>;
+ using TimerMap =
+ absl::flat_hash_map<stream_executor::internal::TimerInterface*,
+ SE_Timer*>;
+
+ explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform,
+ SE_StreamExecutor* executor)
+ : platform_(platform), executor_(executor) {}
+
+ ~TpuExecutor() override;
+
+ Status Init(int device_ordinal,
+ ::stream_executor::DeviceOptions device_options) override;
+
+ DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override;
+
+ StatusOr<DeviceMemoryBase> AllocateDeviceMemoryBase(uint64 size,
+ int64 memory_space);
+
+ Status AllocateEvent(Event* event) override;
+
+ bool AllocateStream(Stream* stream) override;
+
+ bool AllocateTimer(Timer* timer) override;
+
+ Status BlockHostUntilDone(::stream_executor::Stream* stream) override;
+
+ Status BlockUntilDoneOrFailed();
+
+ StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
+ CreateDeviceDescription() const override;
+
+ bool CreateStreamDependency(Stream* dependent, Stream* other) override;
+
+ void DeallocateStream(Stream* stream) override;
+
+ void Deallocate(const DeviceMemoryBase& memory);
+
+ void Deallocate(DeviceMemoryBase* memory) override;
+
+ Status DeallocateEvent(Event* event) override;
+
+ void DeallocateTimer(Timer* timer) override;
+
+ bool DeviceMemoryUsage(int64* free, int64* total) const override;
+
+ void DequeueOutfeed(int32 outfeed_queue_index, absl::Span<uint8> bytes,
+ StatusCallback done);
+
+ Status EnqueueInfeed(int32 infeed_queue_index,
+ absl::Span<const uint8> bytes);
+
+ absl::optional<stream_executor::AllocatorStats> GetAllocatorStats() override;
+
+ Status GetStatus(Stream* stream) override;
+
+ std::unique_ptr<::stream_executor::internal::StreamInterface>
+ GetStreamImplementation() override;
+
+ std::unique_ptr<::stream_executor::internal::TimerInterface>
+ GetTimerImplementation() override;
+
+ std::unique_ptr<::stream_executor::internal::EventInterface>
+ CreateEventImplementation() override;
+
+ bool HostCallback(Stream* stream, std::function<Status()> callback) override;
+
+ bool Memcpy(Stream* stream, void* host_dst,
+ const ::stream_executor::DeviceMemoryBase& device_src,
+ uint64 size) override;
+
+ bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst,
+ const void* host_src, uint64 size) override;
+
+ bool MemcpyDeviceToDevice(Stream* stream,
+ ::stream_executor::DeviceMemoryBase* device_dst,
+ const ::stream_executor::DeviceMemoryBase& device_src,
+ uint64 size) override;
+
+ void SyncAndForgetFailedStreams();
+ bool SynchronizeAllActivity() override;
+
+ Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst,
+ const void* host_src, uint64 size) override;
+ Status SynchronousMemcpy(
+ void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
+ uint64 size) override;
+ Status SynchronousMemcpyDeviceToDevice(
+ ::stream_executor::DeviceMemoryBase* device_dst,
+ const ::stream_executor::DeviceMemoryBase& device_src,
+ uint64 size) override;
+
+ int PlatformDeviceCount() override;
+
+ Event::Status PollForEventStatus(Event* event) override;
+ Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override;
+ Status WaitForEvent(Stream* stream, ::stream_executor::Event* event) override;
+
+ bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override;
+ bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override;
+
+ Status WaitForInfeedReady(int32 infeed_queue_index);
+
+ Status WaitForOutfeedReady(int32 outfeed_queue_index);
+
+ const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
+ return *platform_;
+ }
+
+ ::tensorflow::tpu::TpuPlatformInterface& platform() override {
+ return *platform_;
+ }
+
+ // TODO(henrytan): convert this to override once the base interface is changed
+ // to TpuExecutorInterface.
+ StatusOr<std::unique_ptr<
+ tensorflow::tpu::TpuExecutorInterface::TemporaryDeviceMemory>>
+ CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset,
+ int64 size) override {
+ LOG(FATAL) << "Unimplemented.";
+ }
+
+ // -- Unimplemented (stubbed out) methods.
+ std::unique_ptr<stream_executor::internal::KernelInterface>
+ CreateKernelImplementation() override {
+ LOG(FATAL) << "Not yet implemented";
+ }
+
+ stream_executor::SharedMemoryConfig GetDeviceSharedMemoryConfig() override {
+ LOG(FATAL) << "not yet implemented";
+ }
+
+ void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
+ uint64 size) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ Status MemZero(Stream* stream, DeviceMemoryBase* location,
+ uint64 size) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern,
+ uint64 size) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ Status SetDeviceSharedMemoryConfig(
+ stream_executor::SharedMemoryConfig config) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ void* HostMemoryAllocate(uint64 size) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ void HostMemoryDeallocate(void* mem) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ bool HostMemoryRegister(void* mem, uint64 size) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ bool HostMemoryUnregister(void* mem) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ Status SynchronousMemZero(DeviceMemoryBase* location, uint64 size) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+ Status SynchronousMemSet(DeviceMemoryBase* location, int value,
+ uint64 size) override {
+ LOG(FATAL) << "not yet implemented";
+ }
+
+ private:
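+ // Maps from the StreamExecutor wrapper objects for events and timers to the
+ // corresponding C API handles. Streams are tracked in the platform-wide
+ // StreamMap, accessed through stream_map() below.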
+ EventMap event_map_;
+ TimerMap timer_map_;
+
+ TpuPlatform::StreamMap& stream_map() {
+ return *(static_cast<TpuPlatform*>(platform_)->stream_map());
+ }
+
+ ::tensorflow::tpu::TpuPlatformInterface* platform_;
+ SE_StreamExecutor* executor_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
new file mode 100644
index 0000000..8bf2ecb
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
@@ -0,0 +1,293 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "tensorflow/c/tf_attrtype.h"
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/c/tf_status.h"
+
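+// C API boundary for the TPU StreamExecutor implementation. The SE_* types
+// below are opaque handles that are created and destroyed only through the
+// functions declared here; tpu_executor.cc, tpu_platform.cc, and
+// tpu_transfer_manager.cc call into the TPU library through this interface.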
+typedef struct SE_Platform SE_Platform;
+typedef struct SE_StreamExecutor SE_StreamExecutor;
+typedef struct SE_Stream SE_Stream;
+typedef struct SE_Event SE_Event;
+typedef struct SE_Timer SE_Timer;
+typedef struct SE_Status SE_Status;
+
+typedef struct SE_PlatformId {
+ void* id; // aka stream_executor::Platform::Id
+} SE_PlatformId;
+typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig;
+typedef struct SE_DeviceOptions SE_DeviceOptions;
+typedef SE_Status* (*SE_StatusCallbackFn)(void*);
+
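+// Plain-C mirror of stream_executor::DeviceMemoryBase: an opaque device
+// pointer together with its size and allocator payload.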
+typedef struct SE_DeviceMemoryBase {
+ void* opaque;
+ uint64_t size;
+ uint64_t payload;
+} SE_DeviceMemoryBase;
+
+typedef struct SE_AllocatorStats {
+ int64_t num_allocs;
+ int64_t bytes_in_use;
+ int64_t peak_bytes_in_use;
+ int64_t largest_alloc_size;
+
+ bool has_bytes_limit;
+ int64_t bytes_limit;
+
+ int64_t bytes_reserved;
+ int64_t peak_bytes_reserved;
+
+ bool has_bytes_reservable_limit;
+ int64_t bytes_reservable_limit;
+
+ int64_t largest_free_block_bytes;
+} SE_AllocatorStats;
+
+typedef struct SE_DeviceDescription {
+ char* device_vendor;
+ char* platform_version;
+ char* driver_version;
+ char* runtime_version;
+ char* pci_bus_id;
+ char* name;
+
+ int64_t thread_dim_limit_x;
+ int64_t thread_dim_limit_y;
+ int64_t thread_dim_limit_z;
+ int64_t block_dim_limit_x;
+ int64_t block_dim_limit_y;
+ int64_t block_dim_limit_z;
+
+ int64_t threads_per_core_limit;
+ int64_t threads_per_block_limit;
+ int64_t threads_per_warp;
+
+ int64_t registers_per_core_limit;
+ int64_t registers_per_block_limit;
+
+ int64_t device_address_bits;
+ int64_t device_memory_size;
+ int64_t memory_bandwidth;
+
+ int64_t shared_memory_per_core;
+ int64_t shared_memory_per_block;
+
+ float clock_rate_ghz;
+
+ int cuda_compute_capability_major;
+ int cuda_compute_capability_minor;
+
+ int rocm_amdgpu_isa_version;
+
+ int numa_node;
+ int core_count;
+ bool ecc_enabled;
+} SE_DeviceDescription;
+
+typedef struct XLA_TransferManager XLA_TransferManager;
+
+typedef struct XLA_ComputationPlacer XLA_ComputationPlacer;
+
+// Represents an XLA shape tree.
+// Shapes are flattened in default traversal order.
+typedef struct XLA_Shape {
+ char* bytes;
+ size_t size;
+} XLA_Shape;
+
+// Represents an XLA shaped buffer: `bases` holds its flattened leaf device
+// memory buffers.
+typedef struct XLA_ShapedBuffer {
+ XLA_Shape on_host_shape;
+ XLA_Shape on_device_shape;
+ int device_ordinal;
+
+ SE_DeviceMemoryBase* bases;
+ size_t count;
+} XLA_ShapedBuffer;
+
+// Represents an XLA literal: `buffers` and `sizes` hold its flattened leaf
+// buffers.
+typedef struct XLA_Literal {
+ char** buffers;
+ size_t* sizes;
+ size_t count;
+ XLA_Shape shape;
+} XLA_Literal;
+
+typedef void (*XLA_CallbackFn)(void*);
+typedef void (*XLA_StatusCallbackFn)(void*, SE_Status*);
+
+extern "C" {
+
+SE_Platform* TpuPlatform_New();
+void TpuPlatform_Free(SE_Platform* platform);
+void TpuPlatform_Initialize(SE_Platform* platform, size_t options_size,
+ const char** options_key,
+ const char** options_value, SE_Status* status);
+bool TpuPlatform_Initialized(SE_Platform* platform);
+SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform,
+ SE_StreamExecutorConfig* config,
+ SE_Status* status);
+SE_PlatformId TpuPlatform_Id(SE_Platform* platform);
+int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform);
+int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform);
+
+void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal,
+ SE_DeviceOptions* device_options, SE_Status* status);
+void TpuExecutor_Free(SE_StreamExecutor* executor);
+
+int TpuExecutor_PlatformDeviceCount(SE_StreamExecutor* executor);
+
+SE_DeviceMemoryBase TpuExecutor_Allocate(SE_StreamExecutor* executor,
+ uint64_t size, int64_t memory_space);
+void TpuExecutor_Deallocate(SE_StreamExecutor* executor,
+ SE_DeviceMemoryBase* memory);
+bool TpuExecutor_GetAllocatorStats(SE_StreamExecutor* executor,
+ SE_AllocatorStats* stats);
+bool TpuExecutor_DeviceMemoryUsage(SE_StreamExecutor* executor, int64_t* free,
+ int64_t* total);
+
+bool TpuExecutor_AllocateStream(SE_StreamExecutor* executor, SE_Stream* stream);
+void TpuExecutor_DeallocateStream(SE_StreamExecutor* executor,
+ SE_Stream* stream);
+bool TpuExecutor_CreateStreamDependency(SE_StreamExecutor* executor,
+ SE_Stream* dependent, SE_Stream* other);
+void TpuExecutor_GetStatus(SE_StreamExecutor* executor, SE_Stream* stream,
+ SE_Status* status);
+
+void TpuExecutor_AllocateEvent(SE_StreamExecutor* executor, SE_Event* event,
+ SE_Status* status);
+void TpuExecutor_DeallocateEvent(SE_StreamExecutor* executor, SE_Event* event,
+ SE_Status* status);
+int TpuExecutor_PollForEventStatus(SE_StreamExecutor* executor,
+ SE_Event* event);
+void TpuExecutor_RecordEvent(SE_StreamExecutor* executor, SE_Stream* stream,
+ SE_Event* event, SE_Status* status);
+void TpuExecutor_WaitForEvent(SE_StreamExecutor* executor, SE_Stream* stream,
+ SE_Event* event, SE_Status* status);
+
+bool TpuExecutor_AllocateTimer(SE_StreamExecutor* executor, SE_Timer* timer);
+void TpuExecutor_DeallocateTimer(SE_StreamExecutor* executor, SE_Timer* timer);
+bool TpuExecutor_StartTimer(SE_StreamExecutor* executor, SE_Stream* stream,
+ SE_Timer* timer);
+bool TpuExecutor_StopTimer(SE_StreamExecutor* executor, SE_Stream* stream,
+ SE_Timer* timer);
+
+void TpuExecutor_SynchronousMemcpyToHost(SE_StreamExecutor* executor,
+ void* host_dst,
+ const SE_DeviceMemoryBase* device_src,
+ uint64_t size, SE_Status* status);
+void TpuExecutor_SynchronousMemcpyFromHost(SE_StreamExecutor* executor,
+ SE_DeviceMemoryBase* device_dst,
+ const void* host_src, uint64_t size,
+ SE_Status* status);
+bool TpuExecutor_MemcpyToHost(SE_StreamExecutor* executor, SE_Stream* stream,
+ void* host_dst,
+ const SE_DeviceMemoryBase* device_src,
+ uint64_t size);
+
+bool TpuExecutor_MemcpyFromHost(SE_StreamExecutor* executor, SE_Stream* stream,
+ SE_DeviceMemoryBase* device_dst,
+ const void* host_src, uint64_t size);
+
+void TpuExecutor_EnqueueInfeed(SE_StreamExecutor* executor,
+ int32_t infeed_queue_index, const uint8_t* data,
+ int64_t size, SE_Status* status);
+void TpuExecutor_DequeueOutfeed(SE_StreamExecutor* executor,
+ int32_t outfeed_queue_index, uint8_t* data,
+ int64_t size, SE_Status* status);
+void TpuExecutor_WaitForInfeedReady(SE_StreamExecutor* executor,
+ int32_t infeed_queue_index,
+ SE_Status* status);
+void TpuExecutor_WaitForOutfeedReady(SE_StreamExecutor* executor,
+ int32_t outfeed_queue_index,
+ SE_Status* status);
+
+void TpuExecutor_BlockHostUntilDone(SE_StreamExecutor* executor,
+ SE_Stream* stream, SE_Status* status);
+void TpuExecutor_BlockUntilDoneOrFailed(SE_StreamExecutor* executor,
+ SE_Status* status);
+void TpuExecutor_SyncAndForgetFailedStreams(SE_StreamExecutor* executor);
+bool TpuExecutor_SynchronizeAllActivity(SE_StreamExecutor* executor);
+
+SE_Stream* TpuStream_New(SE_StreamExecutor* parent);
+void TpuStream_Free(SE_Stream*);
+void* TpuStream_Stream(SE_Stream*);
+bool TpuStream_Status(SE_Stream*);
+
+SE_Event* TpuEvent_New(SE_StreamExecutor* parent);
+void TpuEvent_Free(SE_Event*);
+
+SE_Timer* TpuTimer_New(SE_StreamExecutor* parent);
+void TpuTimer_Free(SE_Timer*);
+int64_t TpuTimer_Nanoseconds(SE_Timer*);
+int64_t TpuTimer_Microseconds(SE_Timer*);
+
+SE_Status* TpuStatus_New();
+SE_Status* TpuStatus_Create(int32_t code, const char* msg);
+void TpuStatus_Free(SE_Status* status);
+const char* TpuStatus_Message(SE_Status* status);
+int TpuStatus_Code(SE_Status* status);
+bool TpuStatus_Ok(SE_Status* status);
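+
+// Typical status-out-parameter pattern for the calls in this header
+// (illustrative sketch only; `executor` and `options` are assumed to have been
+// obtained from the functions above):
+//
+//   SE_Status* status = TpuStatus_New();
+//   TpuExecutor_Init(executor, /*device_ordinal=*/0, options, status);
+//   if (!TpuStatus_Ok(status)) {
+//     LOG(ERROR) << TpuStatus_Message(status);
+//   }
+//   TpuStatus_Free(status);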
+
+SE_StreamExecutorConfig* TpuStreamExecutorConfig_Default();
+void TpuStreamExecutorConfig_SetOrdinal(SE_StreamExecutorConfig*, int ordinal);
+void TpuStreamExecutorConfig_Free(SE_StreamExecutorConfig*);
+
+SE_DeviceDescription* TpuDeviceDescription_New();
+void TpuDeviceDescription_Free(SE_DeviceDescription* description);
+void TpuExecutor_CreateDeviceDescription(SE_StreamExecutor* executor,
+ SE_DeviceDescription* description,
+ SE_Status* status);
+
+SE_DeviceOptions* TpuExecutor_NewDeviceOptions(unsigned flags);
+void TpuExecutor_FreeDeviceOptions(SE_DeviceOptions* options);
+
+bool TpuExecutor_HostCallback(SE_StreamExecutor* executor, SE_Stream* stream,
+ SE_StatusCallbackFn callback_fn, void* ctx);
+
+XLA_TransferManager* TpuTransferManager_New();
+void TpuTransferManager_Free(XLA_TransferManager* manager);
+SE_PlatformId TpuTransferManager_PlatformId(XLA_TransferManager* manager);
+void TpuTransferManager_HostShapeToDeviceShape(XLA_TransferManager* manager,
+ XLA_Shape* host_shape,
+ XLA_Shape* device_shape);
+void TpuTransferManager_TransferLiteralToDeviceAsync(
+ XLA_TransferManager* manager, SE_Stream* stream, XLA_Literal* literal,
+ XLA_ShapedBuffer* device_buffer, SE_Status* status);
+void TpuTransferManager_TransferLiteralFromDevice(
+ XLA_TransferManager* manager, SE_Stream* stream,
+ XLA_ShapedBuffer* device_buffer, XLA_Literal* literal,
+ XLA_StatusCallbackFn callback, void* ctx);
+
+int64_t TpuTransferManager_GetByteSizeRequirement(XLA_TransferManager* manager,
+ XLA_Shape* shape);
+void TpuTransferManager_WriteSingleTupleIndexTable(
+ XLA_TransferManager* manager, SE_Stream* stream,
+ SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
+ SE_DeviceMemoryBase* region, SE_Status* status);
+
+XLA_ComputationPlacer* TpuComputationPlacer_New();
+void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
+}
+
+// extern "C"
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_executor_interface.h b/tensorflow/stream_executor/tpu/tpu_executor_interface.h
new file mode 100644
index 0000000..5b00f61
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_executor_interface.h
@@ -0,0 +1,64 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
+
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/timer.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tpu {
+class TpuCore;
+} // namespace tpu
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuExecutorInterface
+ : public ::stream_executor::internal::StreamExecutorInterface {
+ public:
+ using Status = ::stream_executor::port::Status;
+ template <typename T>
+ using StatusOr = ::stream_executor::port::StatusOr<T>;
+
+ class TemporaryDeviceMemory {
+ public:
+ virtual ~TemporaryDeviceMemory() {}
+ virtual stream_executor::DeviceMemoryBase AsDeviceMemoryBase() const = 0;
+ };
+
+ virtual StatusOr<std::unique_ptr<TemporaryDeviceMemory>>
+ CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset,
+ int64 size) {
+ LOG(FATAL) << "Unimplemented.";
+ }
+
+ virtual const TpuPlatformInterface& platform() const {
+ LOG(FATAL) << "Unimplemented.";
+ }
+
+ virtual TpuPlatformInterface& platform() { LOG(FATAL) << "Unimplemented."; }
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_node_context.cc b/tensorflow/stream_executor/tpu/tpu_node_context.cc
new file mode 100644
index 0000000..2a4954d
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_node_context.cc
@@ -0,0 +1,100 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/stream_executor/device_memory_allocator.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
+
+namespace tensorflow {
+namespace tpu {
+
+using stream_executor::port::Status;
+using stream_executor::port::StatusOr;
+
+/*static*/ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Initialize(
+ int device_ordinal) {
+ StatusHelper status;
+ XLA_TpuNodeContext* node_context =
+ TpuNodeContext_Create(device_ordinal, status.c_status);
+ if (!status.status().ok()) {
+ TpuNodeContext_Free(node_context);
+ return status.status();
+ }
+ return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
+}
+
+TpuNodeContext::~TpuNodeContext() { TpuNodeContext_Free(node_context_); }
+
+/* static */
+Status TpuNodeContext::StopChipHeartbeats() {
+ StatusHelper status;
+ TpuNodeContext_StopChipHeartbeats(status.c_status);
+ return status.status();
+}
+
+/* static */
+Status TpuNodeContext::CloseTpuHost() {
+ StatusHelper status;
+ TpuNodeContext_CloseTpuHost(status.c_status);
+ return status.status();
+}
+
+/* static */
+tensorflow::tpu::TpuPlatformInterface* TpuNodeContext::platform() {
+ return TpuPlatformInterface::GetRegisteredPlatform();
+}
+
+/* static */
+stream_executor::DeviceMemoryAllocator* TpuNodeContext::memory_allocator() {
+ static stream_executor::StreamExecutorMemoryAllocator* memory_allocator =
+ new stream_executor::StreamExecutorMemoryAllocator(
+ platform(),
+ xla::PlatformUtil::GetStreamExecutors(platform()).ValueOrDie());
+ return memory_allocator;
+}
+
+/* static */
+xla::Backend* TpuNodeContext::backend() {
+ static xla::Backend* backend =
+ xla::Backend::CreateBackend(
+ xla::BackendOptions().set_platform(platform()))
+ .ValueOrDie()
+ .release();
+ return backend;
+}
+
+/* static */
+StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
+ int device_ordinal) {
+ return backend()->BorrowStream(device_ordinal);
+}
+
+/* static */
+StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
+ stream_executor::StreamExecutor* executor) {
+ return backend()->BorrowStream(executor);
+}
+
+/* static */
+xla::TransferManager* TpuNodeContext::transfer_manager() {
+ return xla::TransferManager::GetForPlatform(platform()).ValueOrDie();
+}
+
+} // namespace tpu
+} // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_node_context.h b/tensorflow/stream_executor/tpu/tpu_node_context.h
new file mode 100644
index 0000000..e1e1ffc
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_node_context.h
@@ -0,0 +1,89 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_
+
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/stream_executor/device_memory_allocator.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuNodeContext final {
+ public:
+ using Status = stream_executor::port::Status;
+ template <typename T>
+ using StatusOr = stream_executor::port::StatusOr<T>;
+
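+ // Creates an XLA_TpuNodeContext for `device_ordinal` and wraps it in a
+ // TpuNodeContext that frees it on destruction. Usage sketch (illustrative
+ // only, using the xla status macros):
+ //
+ //   TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuNodeContext> node,
+ //                       TpuNodeContext::Initialize(device_ordinal));
+ //   TF_ASSIGN_OR_RETURN(xla::StreamPool::Ptr stream,
+ //                       TpuNodeContext::BorrowStream(node->device_ordinal()));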
+ static StatusOr<std::unique_ptr<TpuNodeContext>> Initialize(
+ int device_ordinal);
+
+ explicit TpuNodeContext(int device_ordinal, XLA_TpuNodeContext* node_context)
+ : device_ordinal_(device_ordinal), node_context_(node_context) {
+ CHECK_NE(node_context, nullptr);
+ }
+ ~TpuNodeContext();
+
+ TpuNodeContext(const TpuNodeContext&) = delete;
+ TpuNodeContext& operator=(const TpuNodeContext&) = delete;
+
+ static Status StopChipHeartbeats();
+
+ static Status CloseTpuHost();
+
+ static tensorflow::tpu::TpuPlatformInterface* platform();
+
+ static stream_executor::DeviceMemoryAllocator* memory_allocator();
+
+ static xla::TransferManager* transfer_manager();
+
+ static xla::Backend* backend();
+
+ static StatusOr<xla::StreamPool::Ptr> BorrowStream(int device_ordinal);
+
+ static StatusOr<xla::StreamPool::Ptr> BorrowStream(
+ stream_executor::StreamExecutor* executor);
+
+ stream_executor::StreamExecutor* stream_executor() {
+ LOG(FATAL) << "Not implemented yet.";
+ }
+
+ std::string tensor_core_location() { LOG(FATAL) << "Not implemented yet."; }
+
+ int index_on_host() { LOG(FATAL) << "Not implemented yet."; }
+
+ int device_ordinal() const { return device_ordinal_; }
+
+ private:
+ const int device_ordinal_;
+ XLA_TpuNodeContext* const node_context_;
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_node_context_c_api.h b/tensorflow/stream_executor/tpu/tpu_node_context_c_api.h
new file mode 100644
index 0000000..d2684e4
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_node_context_c_api.h
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
+
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;
+
+XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
+ SE_Status* status);
+void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
+
+void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
+void TpuNodeContext_CloseTpuHost(SE_Status* status);
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_platform.cc b/tensorflow/stream_executor/tpu/tpu_platform.cc
new file mode 100644
index 0000000..c44926d
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform.cc
@@ -0,0 +1,125 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+namespace tensorflow {
+
+PLATFORM_DEFINE_ID(TpuPlatform::kId);
+TpuPlatform* tpu_registered_platform = nullptr;
+
+using Status = ::stream_executor::port::Status;
+template <typename T>
+using StatusOr = ::stream_executor::port::StatusOr<T>;
+
+TpuPlatform::TpuPlatform() { platform_ = TpuPlatform_New(); }
+
+TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
+ return tpu_registered_platform;
+}
+
+Status TpuPlatform::Initialize(
+ const std::map<std::string, std::string>& platform_options) {
+ StatusHelper status;
+
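+ // Marshal the options map into parallel C-string key/value arrays for the
+ // C API; the strings themselves remain owned by `platform_options`.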
+ size_t options_size = platform_options.size();
+ const char** options_key =
+ static_cast<const char**>(malloc(sizeof(const char*) * options_size));
+ const char** options_value =
+ static_cast<const char**>(malloc(sizeof(const char*) * options_size));
+
+ size_t i = 0;
+ for (const auto& option : platform_options) {
+ options_key[i] = option.first.c_str();
+ options_value[i] = option.second.c_str();
+ i++;
+ }
+
+ TpuPlatform_Initialize(platform_, options_size, options_key, options_value,
+ status.c_status);
+
+ free(options_key);
+ free(options_value);
+
+ return status.status();
+}
+
+TpuPlatform::~TpuPlatform() { TpuPlatform_Free(platform_); }
+
+int TpuPlatform::VisibleDeviceCount() const {
+ return TpuPlatform_VisibleDeviceCount(platform_);
+}
+
+StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
+ const ::stream_executor::StreamExecutorConfig& config) {
+ return executor_cache_.GetOrCreate(
+ config, [&]() { return GetUncachedExecutor(config); });
+}
+
+StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
+TpuPlatform::GetUncachedExecutor(
+ const ::stream_executor::StreamExecutorConfig& config) {
+ SE_StreamExecutorConfig* c_config = TpuStreamExecutorConfig_Default();
+
+ TpuStreamExecutorConfig_SetOrdinal(c_config, config.ordinal);
+
+ StatusHelper status;
+ SE_StreamExecutor* executor =
+ TpuPlatform_GetExecutor(platform_, c_config, status.c_status);
+ TpuStreamExecutorConfig_Free(c_config);
+ if (!status.status().ok()) {
+ return status.status();
+ }
+ return std::make_unique<stream_executor::StreamExecutor>(
+ this, absl::make_unique<tensorflow::TpuExecutor>(this, executor),
+ config.ordinal);
+}
+
+::stream_executor::Platform::Id TpuPlatform::id() const {
+ return TpuPlatform::kId;
+}
+
+const std::string& TpuPlatform::Name() const {
+ static std::string* name = new std::string(kName);
+ return *name;
+}
+
+int64 TpuPlatform::TpuMemoryLimit() {
+ return TpuPlatform_TpuMemoryLimit(platform_);
+}
+
+} // namespace tensorflow
+
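+// Creates the process-wide TpuPlatform, hands ownership of it to
+// MultiPlatformManager, and keeps a raw pointer in tpu_registered_platform so
+// TpuPlatform::GetRegisteredPlatform() can return it.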
+void RegisterTpuPlatform() {
+ tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
+ std::unique_ptr<stream_executor::Platform> platform(
+ tensorflow::tpu_registered_platform);
+ SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
+ std::move(platform)));
+}
+
+REGISTER_MODULE_INITIALIZER(tpu_platform, RegisterTpuPlatform());
+
+// Note that module initialization sequencing is not supported in the
+// open-source project, so this will be a no-op there.
+REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
+REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
+ tpu_platform);
diff --git a/tensorflow/stream_executor/tpu/tpu_platform.h b/tensorflow/stream_executor/tpu/tpu_platform.h
new file mode 100644
index 0000000..9a67045
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform.h
@@ -0,0 +1,121 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
+
+#include <memory>
+
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/executor_cache.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+
+class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
+ public:
+ using StreamMap =
+ absl::flat_hash_map<stream_executor::internal::StreamInterface*,
+ SE_Stream*>;
+
+ static const ::stream_executor::Platform::Id kId;
+ static constexpr char kName[] = "TPU";
+
+ using Status = ::stream_executor::port::Status;
+ template <typename T>
+ using StatusOr = ::stream_executor::port::StatusOr<T>;
+
+ TpuPlatform();
+
+ ~TpuPlatform() override;
+
+ static TpuPlatform* GetRegisteredPlatform();
+
+ Id id() const override;
+
+ const std::string& Name() const override;
+
+ int VisibleDeviceCount() const override;
+
+ int64 TpuMemoryLimit() override;
+
+ bool Initialized() const override {
+ return TpuPlatform_Initialized(platform_);
+ }
+
+ Status Initialize(
+ const std::map<std::string, std::string>& platform_options) override;
+
+ Status Reset() override { return Reset(false); }
+
+ Status Reset(bool only_tear_down) override {
+ LOG(FATAL) << "Not yet implemented";
+ }
+
+ StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
+ DescriptionForDevice(int ordinal) const override {
+ LOG(FATAL) << "Not yet implemented";
+ }
+
+ StatusOr<::stream_executor::StreamExecutor*> ExecutorForDevice(
+ int ordinal) override {
+ stream_executor::StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ return GetExecutor(config);
+ }
+
+ StatusOr<::stream_executor::StreamExecutor*>
+ ExecutorForDeviceWithPluginConfig(
+ int ordinal,
+ const ::stream_executor::PluginConfig& plugin_config) override {
+ stream_executor::StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ config.plugin_config = plugin_config;
+ return GetExecutor(config);
+ }
+
+ StatusOr<::stream_executor::StreamExecutor*> GetExecutor(
+ const ::stream_executor::StreamExecutorConfig& config) override;
+
+ StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
+ GetUncachedExecutor(
+ const ::stream_executor::StreamExecutorConfig& config) override;
+
+ void RegisterTraceListener(
+ std::unique_ptr<stream_executor::TraceListener> listener) override {
+ LOG(FATAL) << "Not yet implemented";
+ }
+
+ void UnregisterTraceListener(
+ stream_executor::TraceListener* listener) override {
+ LOG(FATAL) << "Not yet implemented";
+ }
+
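+ // Shared map from StreamInterface implementations to their SE_Stream handles;
+ // TpuExecutor and TpuTransferManager use it to translate streams when calling
+ // into the C API.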
+ StreamMap* stream_map() { return &stream_map_; }
+
+ private:
+ SE_Platform* platform_;
+
+ stream_executor::ExecutorCache executor_cache_;
+ StreamMap stream_map_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
new file mode 100644
index 0000000..c5b8ece
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+
+namespace tensorflow {
+namespace tpu {
+
+/* static */
+TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
+ // Prefer TpuPlatform if it's registered.
+ auto status_or_tpu_platform =
+ stream_executor::MultiPlatformManager::PlatformWithName("TPU");
+ if (status_or_tpu_platform.ok()) {
+ return static_cast<TpuPlatformInterface*>(
+ status_or_tpu_platform.ValueOrDie());
+ }
+ if (status_or_tpu_platform.status().code() != error::NOT_FOUND) {
+ LOG(WARNING) << "Error when getting the TPU platform: "
+ << status_or_tpu_platform.status();
+ return nullptr;
+ }
+
+ // Use any other registered TPU platform.
+ auto status_or_other_tpu_platforms =
+ stream_executor::MultiPlatformManager::PlatformsWithFilter(
+ [](const stream_executor::Platform* platform) {
+ return dynamic_cast<const TpuPlatformInterface*>(platform) !=
+ nullptr;
+ });
+ if (!status_or_other_tpu_platforms.ok()) {
+ LOG(WARNING) << "Error when getting other TPU platforms: "
+ << status_or_tpu_platform.status();
+ return nullptr;
+ }
+ auto other_tpu_platforms = status_or_other_tpu_platforms.ValueOrDie();
+ if (!other_tpu_platforms.empty()) {
+ LOG(WARNING) << other_tpu_platforms.size()
+ << " TPU platforms registered, selecting "
+ << other_tpu_platforms[0]->Name();
+ return static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
+ }
+
+ LOG(WARNING) << "No TPU platform registered";
+ return nullptr;
+}
+
+} // namespace tpu
+} // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.h b/tensorflow/stream_executor/tpu/tpu_platform_interface.h
new file mode 100644
index 0000000..5c7aa8e
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.h
@@ -0,0 +1,44 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/platform.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuPlatformInterface : public stream_executor::Platform {
+ public:
+ using Status = stream_executor::port::Status;
+
+ // Returns a TPU platform to be used by TPU ops. If multiple TPU platforms are
+ // registered, finds the most suitable one. Returns nullptr if no TPU platform
+ // is registered or an error occurred.
+ static TpuPlatformInterface* GetRegisteredPlatform();
+
+ virtual Status Reset() { return Reset(false); }
+
+ virtual Status Reset(bool only_tear_down) = 0;
+
+ virtual int64 TpuMemoryLimit() = 0;
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_stream.h b/tensorflow/stream_executor/tpu/tpu_stream.h
new file mode 100644
index 0000000..b8fd10d
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_stream.h
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
+
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+class TpuStream : public stream_executor::internal::StreamInterface {
+ public:
+ explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
+ ~TpuStream() override { TpuStream_Free(stream_); }
+
+ private:
+ SE_Stream* stream_;
+};
+
+class TpuEvent : public ::stream_executor::internal::EventInterface {
+ public:
+ explicit TpuEvent(SE_Event* event) : event_(event) {}
+ ~TpuEvent() override { TpuEvent_Free(event_); }
+
+ private:
+ SE_Event* event_;
+};
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_stream_interface.h b/tensorflow/stream_executor/tpu/tpu_stream_interface.h
new file mode 100644
index 0000000..2e5e02d
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_stream_interface.h
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
+
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuStreamInterface : public ::stream_executor::internal::StreamInterface {
+};
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_timer.h b/tensorflow/stream_executor/tpu/tpu_timer.h
new file mode 100644
index 0000000..246a0b7
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_timer.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+namespace tensorflow {
+
+class TpuTimer : public ::stream_executor::internal::TimerInterface {
+ public:
+ explicit TpuTimer(SE_Timer* timer) : timer_(timer) {}
+ ~TpuTimer() override { TpuTimer_Free(timer_); }
+ uint64 Microseconds() const override { return TpuTimer_Microseconds(timer_); }
+ uint64 Nanoseconds() const override { return TpuTimer_Nanoseconds(timer_); }
+
+ private:
+ SE_Timer* timer_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc b/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc
new file mode 100644
index 0000000..473585a
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc
@@ -0,0 +1,167 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+
+namespace tensorflow {
+
+using Status = stream_executor::port::Status;
+template <typename T>
+using StatusOr = stream_executor::port::StatusOr<T>;
+
+TpuTransferManager::TpuTransferManager() {
+ manager_ = TpuTransferManager_New();
+}
+
+TpuTransferManager::~TpuTransferManager() { TpuTransferManager_Free(manager_); }
+
+stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
+ return TpuPlatform::kId;
+}
+
+xla::Shape TpuTransferManager::HostShapeToDeviceShape(
+ const xla::Shape& host_shape) const {
+ XLA_Shape c_host_shape;
+ XLA_Shape c_device_shape;
+
+ TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
+
+ TpuTransferManager_HostShapeToDeviceShape(manager_, &c_host_shape,
+ &c_device_shape);
+ xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
+ TpuConversions::CShapeCleanup(&c_host_shape);
+ TpuConversions::CShapeCleanup(&c_device_shape);
+ return device_shape;
+}
+
+Status TpuTransferManager::TransferLiteralToDeviceAsync(
+ stream_executor::Stream* stream, const xla::LiteralSlice& literal,
+ const xla::ShapedBuffer& device_buffer,
+ const TransferMetadata* transfer_metadata) {
+ StatusHelper status;
+
+ XLA_Literal c_literal;
+ TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
+
+ XLA_ShapedBuffer c_device_buffer;
+ TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
+ &c_device_buffer);
+
+ TpuTransferManager_TransferLiteralToDeviceAsync(
+ manager_,
+ TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
+ stream->implementation()),
+ &c_literal, &c_device_buffer, status.c_status);
+ TpuConversions::CShapedBufferCleanup(&c_device_buffer);
+ TpuConversions::CLiteralCleanup(&c_literal);
+ return status.status();
+}
+
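+// Completion state shared by the outstanding transfers of a single
+// TransferLiteralFromDevice call: the first failing SE_Status is retained in
+// status_helper, and when the last transfer finishes `done` is invoked and the
+// state deletes itself.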
+struct TransferFromDeviceState {
+ std::atomic<int64_t> remaining_transfers;
+ StatusHelper status_helper;
+ std::function<void(Status)> done;
+
+ void TransferFinished(SE_Status* status) {
+ if (!TpuStatus_Ok(status) && TpuStatus_Ok(status_helper.c_status)) {
+ status_helper.c_status = status;
+ } else {
+ TpuStatus_Free(status);
+ }
+
+ if (--remaining_transfers == 0) {
+ done(status_helper.status());
+ delete this;
+ }
+ }
+};
+
+void TransferLiteralFromDeviceTrampoline(void* ctx, SE_Status* status) {
+ reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
+}
+
+void TpuTransferManager::TransferLiteralFromDevice(
+ stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
+ xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
+ const TransferMetadata* transfer_metadata) {
+ TransferFromDeviceState* state = new TransferFromDeviceState;
+ state->remaining_transfers = 1;
+ state->done = done;
+ XLA_ShapedBuffer c_device_buffer;
+ TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
+ &c_device_buffer);
+ XLA_Literal c_literal;
+ TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
+
+ TpuTransferManager_TransferLiteralFromDevice(
+ manager_,
+ TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
+ stream->implementation()),
+ &c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state);
+ TpuConversions::CShapedBufferCleanup(&c_device_buffer);
+ TpuConversions::CLiteralCleanup(&c_literal);
+}
+
+int64 TpuTransferManager::GetByteSizeRequirement(
+ const xla::Shape& shape) const {
+ XLA_Shape c_shape;
+ TpuConversions::XlaShapeToCShape(shape, &c_shape);
+
+ int64 size_in_bytes =
+ TpuTransferManager_GetByteSizeRequirement(manager_, &c_shape);
+
+ TpuConversions::CShapeCleanup(&c_shape);
+ return size_in_bytes;
+}
+
+Status TpuTransferManager::WriteSingleTupleIndexTable(
+ stream_executor::Stream* stream,
+ absl::Span<const stream_executor::DeviceMemoryBase> elements,
+ const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) {
+ CHECK_GT(elements.size(), 0);
+ SE_DeviceMemoryBase* elements_bases =
+ new SE_DeviceMemoryBase[elements.size()];
+ for (int i = 0; i < elements.size(); i++) {
+ elements_bases[i] =
+ SE_DeviceMemoryBase{const_cast<void*>(elements[i].opaque()),
+ elements[i].size(), elements[i].payload()};
+ }
+ XLA_Shape c_shape;
+ TpuConversions::XlaShapeToCShape(shape, &c_shape);
+ SE_DeviceMemoryBase region_base{region->opaque(), region->size(),
+ region->payload()};
+ StatusHelper status;
+
+ TpuTransferManager_WriteSingleTupleIndexTable(
+ manager_,
+ TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
+ stream->implementation()),
+ elements_bases, elements.size(), &c_shape, &region_base, status.c_status);
+
+ delete[] elements_bases;
+ TpuConversions::CShapeCleanup(&c_shape);
+ return status.status();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager.h b/tensorflow/stream_executor/tpu/tpu_transfer_manager.h
new file mode 100644
index 0000000..163ac81
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager.h
@@ -0,0 +1,83 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+namespace tensorflow {
+
+class TpuTransferManager : public xla::TransferManager {
+ public:
+ TpuTransferManager();
+ ~TpuTransferManager() override;
+
+ using Status = stream_executor::port::Status;
+ template <typename T>
+ using StatusOr = stream_executor::port::StatusOr<T>;
+
+ stream_executor::Platform::Id PlatformId() const override;
+
+ xla::Shape HostShapeToDeviceShape(
+ const xla::Shape& host_shape) const override;
+
+ Status TransferLiteralToDeviceAsync(
+ stream_executor::Stream* stream, const xla::LiteralSlice& literal,
+ const xla::ShapedBuffer& device_buffer,
+ const TransferMetadata* transfer_metadata) override;
+
+ void TransferLiteralFromDevice(
+ stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
+ xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
+ const TransferMetadata* transfer_metadata) override;
+
+ Status TransferLiteralToInfeed(stream_executor::StreamExecutor* executor,
+ const xla::LiteralSlice& literal) override {
+ LOG(FATAL) << "Not yet implemented";
+ }
+
+ Status TransferLiteralFromOutfeed(
+ stream_executor::StreamExecutor* executor,
+ const xla::Shape& literal_shape,
+ xla::MutableBorrowingLiteral literal) override {
+ LOG(FATAL) << "Not yet implemented";
+ }
+
+ Status ResetDevices(
+ absl::Span<stream_executor::StreamExecutor* const> executor) override {
+ LOG(FATAL) << "Not yet implemented";
+ }
+
+ int64 GetByteSizeRequirement(const xla::Shape& shape) const override;
+
+ Status WriteSingleTupleIndexTable(
+ stream_executor::Stream* stream,
+ absl::Span<const stream_executor::DeviceMemoryBase> elements,
+ const xla::Shape& shape,
+ stream_executor::DeviceMemoryBase* region) override;
+
+ private:
+ XLA_TransferManager* manager_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc b/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc
new file mode 100644
index 0000000..f7f0c6f
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
+
+namespace tensorflow {
+
+static std::unique_ptr<xla::TransferManager> CreateTpuTransferManager() {
+ return std::make_unique<TpuTransferManager>();
+}
+
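+// Registers CreateTpuTransferManager with XLA at static-initialization time so
+// that transfer manager lookups for TpuPlatform::kId return a
+// TpuTransferManager.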
+static bool InitModule() {
+ xla::TransferManager::RegisterTransferManager(TpuPlatform::kId,
+ CreateTpuTransferManager);
+ return true;
+}
+static bool module_initialized = InitModule();
+
+} // namespace tensorflow