Open sourcing some TPU-related work

PiperOrigin-RevId: 315431095
Change-Id: I734632c0e5723dfca37acf53bbbd2b378b04c95d
diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD
new file mode 100644
index 0000000..ef9e4a0
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/BUILD
@@ -0,0 +1,55 @@
+# Contains graph rewrites for TPU runtimes and optimizations.
+
+package(
+    default_visibility = [
+        "//tensorflow/core/tpu:__subpackages__",
+        "//tensorflow/stream_executor/tpu:__subpackages__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "distributed_tpu_configuration_rewrite_registration",
+    srcs = ["distributed_tpu_configuration_rewrite_registration.cc"],
+    deps = [
+        ":distributed_tpu_configuration_rewrite_pass",
+        "//tensorflow/core:core_cpu",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "distributed_tpu_configuration_rewrite_pass",
+    srcs = [
+        "distributed_tpu_configuration_rewrite_pass.cc",
+    ],
+    hdrs = [
+        "distributed_tpu_configuration_rewrite_pass.h",
+    ],
+    deps = [
+        ":distributed_tpu_rewrite_helpers",
+        "//tensorflow/cc:scope",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/protobuf/tpu:topology_proto_cc",
+        "//tensorflow/core/tpu:tpu_init_mode",
+        "//tensorflow/core/tpu/kernels:tpu_compile_op_options",
+    ],
+)
+
+cc_library(
+    name = "distributed_tpu_rewrite_helpers",
+    srcs = ["distributed_tpu_rewrite_helpers.cc"],
+    hdrs = ["distributed_tpu_rewrite_helpers.h"],
+    deps = [
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/tpu:tpu_defs",
+    ],
+)
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc
new file mode 100644
index 0000000..3b1e9d7
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc
@@ -0,0 +1,402 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Configuration for distributed TPU jobs
+
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
+
+#include <unordered_map>
+
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_constructor.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
+#include "tensorflow/core/tpu/tpu_init_mode.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/dump_graph.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr char kIdentityOp[] = "Identity";
+constexpr char kConfigureOp[] = "ConfigureDistributedTPU";
+constexpr char kInternalConfigureOp[] = "_ConfigureDistributedTPU";
+constexpr char kWaitOp[] = "_WaitForDistributedTPU";
+constexpr char kHostConfigureOp[] = "_InitializeHostForDistributedTPU";
+constexpr char kGlobalTPUArrayOp[] = "_SetGlobalTPUArray";
+constexpr char kShutdownOp[] = "ShutdownDistributedTPU";
+constexpr char kInternalShutdownOp[] = "_ShutdownDistributedTPU";
+constexpr char kHostDisconnectOp[] = "_DisconnectHostFromDistributedTPUSystem";
+constexpr char kEmbeddingConfigurationAttr[] = "embedding_config";
+constexpr int kDefaultStartupTimeout = 20;
+
+Status AddConfigurationNode(const string& configuration_device_name,
+                            int number_of_hosts, Graph* graph,
+                            bool enable_whole_mesh_compilations,
+                            Node** configuration_node) {
+  NodeDef config_def;
+  config_def.set_name(graph->NewName("configure_distributed_tpu"));
+  config_def.set_op(kInternalConfigureOp);
+  config_def.set_device(configuration_device_name);
+  AddNodeAttr("N", number_of_hosts, &config_def);
+  AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
+              &config_def);
+  // TODO(shikharagarwal): Fill with appropriate original node debug info.
+
+  Status status;
+  *configuration_node = graph->AddNode(config_def, &status);
+  if (!status.ok()) {
+    return status;
+  }
+  (*configuration_node)->set_assigned_device_name(configuration_device_name);
+  return Status::OK();
+}
+
+Status AddHostConfigNode(const string& host_device_name,
+                         Node* configuration_node, Graph* graph,
+                         bool enable_whole_mesh_compilations,
+                         Node** host_configuration_node) {
+  NodeDef host_config_def;
+  host_config_def.set_name(graph->NewName("configure_tpu_host"));
+  host_config_def.set_op(kHostConfigureOp);
+  host_config_def.set_device(host_device_name);
+  AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
+              &host_config_def);
+  MergeDebugInfo(NodeDebugInfo(configuration_node->def()), &host_config_def);
+
+  Status status;
+  *host_configuration_node = graph->AddNode(host_config_def, &status);
+  if (!status.ok()) {
+    return status;
+  }
+  (*host_configuration_node)->set_assigned_device_name(host_device_name);
+  graph->AddEdge(configuration_node, 0, *host_configuration_node, 0);
+  return Status::OK();
+}
+
+Status AddWaitNode(const string& configuration_device_name,
+                   const std::vector<Node*>& host_configuration_nodes,
+                   Graph* graph, Node** wait_node) {
+  NodeDef wait_def;
+  wait_def.set_name(graph->NewName("wait_for_distributed_tpu_system"));
+  wait_def.set_op(kWaitOp);
+  wait_def.set_device(configuration_device_name);
+  AddNodeAttr("N", static_cast<int32>(host_configuration_nodes.size()),
+              &wait_def);
+  AddNodeAttr("startup_timeout_sec", kDefaultStartupTimeout, &wait_def);
+  if (!host_configuration_nodes.empty()) {
+    MergeDebugInfo(NodeDebugInfo(host_configuration_nodes[0]->def()),
+                   &wait_def);
+  }
+
+  Status status;
+  *wait_node = graph->AddNode(wait_def, &status);
+  if (!status.ok()) {
+    return status;
+  }
+  (*wait_node)->set_assigned_device_name(configuration_device_name);
+  // Get the inputs from the host configuration nodes.
+  for (int i = 0; i < host_configuration_nodes.size(); ++i) {
+    graph->AddEdge(host_configuration_nodes[i], 0, *wait_node, i);
+  }
+  return Status::OK();
+}
+
+Status AddGlobalTPUArrayNode(const string& host_device_name, Node* wait_node,
+                             Graph* graph, Node** global_tpu_array_node) {
+  NodeDef global_tpu_array_def;
+  global_tpu_array_def.set_name(graph->NewName("set_global_tpu_array"));
+  global_tpu_array_def.set_op(kGlobalTPUArrayOp);
+  global_tpu_array_def.set_device(host_device_name);
+  MergeDebugInfo(NodeDebugInfo(wait_node->def()), &global_tpu_array_def);
+
+  Status status;
+  *global_tpu_array_node = graph->AddNode(global_tpu_array_def, &status);
+  if (!status.ok()) {
+    return status;
+  }
+  (*global_tpu_array_node)->set_assigned_device_name(host_device_name);
+  graph->AddEdge(wait_node, 0, *global_tpu_array_node, 0);
+  return Status::OK();
+}
+
+Status AddSynchronizationNode(
+    const NodeDef& sync_node_def, const string& device_name,
+    const std::vector<Node*>& global_array_id_nodes, Node* wait_node,
+    const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
+        output_dependencies,
+    Graph* graph) {
+  NodeDef sync_def;
+  sync_def.set_name(sync_node_def.name());
+  sync_def.set_op(kIdentityOp);
+  sync_def.set_device(device_name);
+  AddNodeAttr("T", DT_STRING, &sync_def);
+  MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def);
+
+  Status status;
+  Node* sync_node = graph->AddNode(sync_def, &status);
+  if (!status.ok()) {
+    return status;
+  }
+  sync_node->set_assigned_device_name(device_name);
+  // Add control edges from the global array id nodes.
+  for (auto node : global_array_id_nodes) {
+    graph->AddControlEdge(node, sync_node);
+  }
+  // Forward the data from the wait node.
+  graph->AddEdge(wait_node, 0, sync_node, 0);
+  // Replace the output edges.
+  for (const DistributedTPURewriteHelpers::OutputDependency& dep :
+       output_dependencies) {
+    if (dep.dst_input == Graph::kControlSlot) {
+      graph->AddControlEdge(sync_node, dep.dst);
+    } else {
+      graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input);
+    }
+  }
+  return Status::OK();
+}
+
+
+Status AddShutdownNode(
+    const NodeDef& shutdown_node_def, const string& shutdown_device_name,
+    const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
+        output_dependencies,
+    Graph* graph, Node** shutdown_node) {
+  NodeDef shutdown_def;
+  shutdown_def.set_name(shutdown_node_def.name());
+  shutdown_def.set_op(kInternalShutdownOp);
+  shutdown_def.set_device(shutdown_device_name);
+  MergeDebugInfo(NodeDebugInfo(shutdown_node_def), &shutdown_def);
+
+  Status status;
+  *shutdown_node = graph->AddNode(shutdown_def, &status);
+  if (!status.ok()) {
+    return status;
+  }
+  (*shutdown_node)->set_assigned_device_name(shutdown_device_name);
+  // Replace the output control edges.
+  for (const DistributedTPURewriteHelpers::OutputDependency& dep :
+       output_dependencies) {
+    if (dep.dst_input != Graph::kControlSlot) {
+      return errors::Internal("Shutdown node had non-control edge output");
+    }
+    graph->AddControlEdge(*shutdown_node, dep.dst);
+  }
+  return Status::OK();
+}
+
+Status AddHostDisconnectNode(const string& host_device_name,
+                             const std::vector<Node*>& input_dependencies,
+                             Node* post_disconnect_node, int output_index,
+                             Graph* graph) {
+  NodeDef host_disconnect_def;
+  host_disconnect_def.set_name(graph->NewName("disconnect_tpu_host"));
+  host_disconnect_def.set_op(kHostDisconnectOp);
+  host_disconnect_def.set_device(host_device_name);
+  MergeDebugInfo(NodeDebugInfo(post_disconnect_node->def()),
+                 &host_disconnect_def);
+
+  Status status;
+  Node* host_disconnect_node = graph->AddNode(host_disconnect_def, &status);
+  if (!status.ok()) {
+    return status;
+  }
+  host_disconnect_node->set_assigned_device_name(host_device_name);
+  // Replace the input control edges.
+  for (Node* src_node : input_dependencies) {
+    graph->AddControlEdge(src_node, host_disconnect_node);
+  }
+  if (output_index == -1) {
+    graph->AddControlEdge(host_disconnect_node, post_disconnect_node);
+  } else {
+    graph->AddEdge(host_disconnect_node, 0, post_disconnect_node, output_index);
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+Status DistributedTPUConfigurationRewritePass::Run(
+    const GraphOptimizationPassOptions& options) {
+  VLOG(1) << "DistributedTPUConfigurationRewritePass::Run";
+
+  Graph* graph = options.graph->get();
+
+  if (VLOG_IS_ON(1)) {
+    DumpGraphToFile("distributed_tpu_configuration_before", *graph,
+                    options.flib_def);
+  }
+
+  // This pass can only run in the session master, which should fill
+  // in the device_set field to the options.
+  TF_RET_CHECK(options.device_set != nullptr);
+
+  TF_RETURN_IF_ERROR(
+      DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
+          kConfigureOp, graph, *options.device_set,
+          [](const NodeDef& configuration_node_def,
+             const string& configuration_device_name,
+             const std::vector<Device*>& host_devices,
+             const std::vector<Node*>& input_dependencies,
+             const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
+                 output_dependencies,
+             Graph* graph) -> Status {
+            const std::string& embedding_attr_string = GetNodeAttrString(
+                AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr);
+
+            if (!embedding_attr_string.empty()) {
+              return errors::InvalidArgument("embedding_config must be empty.");
+            }
+
+            bool is_global_init = false;
+            bool enable_whole_mesh_compilations = false;
+            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
+                                           "is_global_init", &is_global_init));
+            TryGetNodeAttr(configuration_node_def,
+                           "enable_whole_mesh_compilations",
+                           &enable_whole_mesh_compilations);
+            TF_RETURN_IF_ERROR(SetTPUInitMode(
+                is_global_init ? TPUInitMode::kGlobal : TPUInitMode::kRegular));
+
+            bool compilation_failure_closes_chips;
+            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
+                                           "compilation_failure_closes_chips",
+                                           &compilation_failure_closes_chips));
+            internal::SetTpuCompilationFailureClosesChips(
+                compilation_failure_closes_chips);
+
+            // Add the global TPU system configuration node.
+            Node* configuration_node;
+            TF_RETURN_IF_ERROR(AddConfigurationNode(
+                configuration_device_name, host_devices.size(), graph,
+                enable_whole_mesh_compilations, &configuration_node));
+
+            // Add the host disconnect nodes.
+            for (int i = 0; i < host_devices.size(); ++i) {
+              const auto host_device = host_devices[i];
+              TF_RETURN_IF_ERROR(
+                  AddHostDisconnectNode(host_device->name(), input_dependencies,
+                                        configuration_node, i, graph));
+            }
+
+            // Add the host configuration nodes.
+            std::vector<Node*> host_configuration_nodes;
+            for (const auto host_device : host_devices) {
+              Node* host_configuration_node;
+              TF_RETURN_IF_ERROR(AddHostConfigNode(
+                  host_device->name(), configuration_node, graph,
+                  enable_whole_mesh_compilations, &host_configuration_node));
+              host_configuration_nodes.push_back(host_configuration_node);
+            }
+
+            // Add the node to wait for the system configuration to
+            // stabilize. Use the name of the original dummy Op in case it was
+            // the target of a Session::Run call.
+            Node* wait_node;
+            TF_RETURN_IF_ERROR(AddWaitNode(configuration_device_name,
+                                           host_configuration_nodes, graph,
+                                           &wait_node));
+
+            // Add the nodes to set the global TPU ids at each host.
+            std::vector<Node*> global_array_id_nodes;
+            for (const auto host_device : host_devices) {
+              Node* global_array_id_node;
+              TF_RETURN_IF_ERROR(AddGlobalTPUArrayNode(host_device->name(),
+                                                       wait_node, graph,
+                                                       &global_array_id_node));
+              global_array_id_nodes.push_back(global_array_id_node);
+            }
+
+            if (host_devices.empty()) {
+              return errors::InvalidArgument("TPU job contains no CPU devices");
+            }
+            TF_RET_CHECK(!host_devices.empty());
+
+            TF_RETURN_IF_ERROR(AddSynchronizationNode(
+                configuration_node_def, host_devices.front()->name(),
+                global_array_id_nodes, wait_node, output_dependencies, graph));
+
+            return Status::OK();
+          }));
+
+  if (VLOG_IS_ON(1)) {
+    DumpGraphToFile("distributed_tpu_configuration_after", *graph,
+                    options.flib_def);
+  }
+
+  VLOG(1) << "DistributedTPUConfigurationRewritePass::Run() finished";
+  return Status::OK();
+}
+
+Status DistributedTPUShutdownRewritePass::Run(
+    const GraphOptimizationPassOptions& options) {
+  VLOG(1) << "DistributedTPUShutdownRewritePass::Run";
+
+  Graph* graph = options.graph->get();
+
+  if (VLOG_IS_ON(1)) {
+    DumpGraphToFile("distributed_tpu_shutdown_before", *graph,
+                    options.flib_def);
+  }
+
+  // This pass can only run in the session master, which should fill
+  // in the device_set field to the options.
+  TF_RET_CHECK(options.device_set != nullptr);
+
+  TF_RETURN_IF_ERROR(
+      DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
+          kShutdownOp, graph, *options.device_set,
+          [](const NodeDef& shutdown_node_def,
+             const string& shutdown_device_name,
+             const std::vector<Device*>& host_devices,
+             const std::vector<Node*>& input_dependencies,
+             const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
+                 output_dependencies,
+             Graph* graph) -> Status {
+            Node* shutdown_node;
+            TF_RETURN_IF_ERROR(
+                AddShutdownNode(shutdown_node_def, shutdown_device_name,
+                                output_dependencies, graph, &shutdown_node));
+
+            // Add the host disconnect nodes.
+            for (const auto host_device : host_devices) {
+              TF_RETURN_IF_ERROR(
+                  AddHostDisconnectNode(host_device->name(), input_dependencies,
+                                        shutdown_node, -1, graph));
+            }
+
+            return Status::OK();
+          }));
+
+  if (VLOG_IS_ON(1)) {
+    DumpGraphToFile("distributed_tpu_shutdown_after", *graph, options.flib_def);
+  }
+
+  VLOG(1) << "DistributedTPUShutdownRewritePass::Run() finished";
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h
new file mode 100644
index 0000000..191f32f
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Rewrites ConfigureDistributedTPU Op into a graph that configures each host.
+//
+// See the comment at the top of
+// third_party/tensorflow/core/ops/tpu_configuration_ops.cc to see the
+// sequence of Ops used to configure a distributed TPU system.
+
+#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
+#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+// Replaces dummy ConfigureDistributedTPU Ops assigned to TPU_SYSTEM
+// devices with _ConfigureDistributedTPU and _WaitForDistributedTPU
+// Ops on TPU_SYSTEM, and _InitializeHostForDistributedTPU on the CPU
+// device of each host in the same job as the given TPU_SYSTEM device.
+class DistributedTPUConfigurationRewritePass : public GraphOptimizationPass {
+ public:
+  Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
+// Replaces dummy ShutdownDistributedTPU Ops assigned to TPU_SYSTEM
+// devices with _ShutdownDistributedTPU Ops on TPU_SYSTEM and
+// _DisconnectHostFromDistributedTPUSystem on the CPU device of each
+// host in the same job as the given TPU_SYSTEM device.
+class DistributedTPUShutdownRewritePass : public GraphOptimizationPass {
+ public:
+  Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc
new file mode 100644
index 0000000..db2b3a5
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
+
+namespace tensorflow {
+namespace {
+
+// This pass removes the TPUEmbeddingConfiguration in ConfigureDistributedTPU.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
+                      DistributedTPUConfigurationRewritePass);
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
+                      DistributedTPUShutdownRewritePass);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc
new file mode 100644
index 0000000..965a174
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc
@@ -0,0 +1,255 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helper functions for TPU rewrite passes.
+
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
+
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// LINT.IfChange
+Status DistributedTPURewriteHelpers::GetSystemDevice(
+    const string& system_spec_string, const DeviceSet& device_set,
+    DeviceNameUtils::ParsedName* system_spec, Device** system_device) {
+  if (!DeviceNameUtils::ParseFullName(system_spec_string, system_spec)) {
+    system_spec->Clear();
+  }
+
+  // Callers may have relied on an Op only being registered on TPU_SYSTEM
+  // devices to ensure the Op is placed there. Augment the device spec to make
+  // the device type explicit.
+  if (!system_spec->has_type || system_spec->type != DEVICE_TPU_SYSTEM) {
+    system_spec->type = DEVICE_TPU_SYSTEM;
+    system_spec->has_type = true;
+    system_spec->id = 0;
+    system_spec->has_id = true;
+  }
+
+  std::vector<Device*> system_devices;
+  device_set.FindMatchingDevices(*system_spec, &system_devices);
+  if (system_devices.empty()) {
+    if (system_spec_string.empty()) {
+      return errors::InvalidArgument(
+          "No TPU_SYSTEM device found. Please ensure that you're connected to "
+          "a host with a TPU_SYSTEM device.");
+    }
+    return errors::InvalidArgument("No matching devices found for '",
+                                   system_spec_string, "'");
+  } else if (system_devices.size() > 1) {
+    // Validate that all system devices are part of the same job.
+    std::unordered_set<string> job_names;
+    for (auto device : system_devices) {
+      const auto& parsed_name = device->parsed_name();
+      TF_RET_CHECK(parsed_name.has_job);
+      job_names.insert(parsed_name.job);
+    }
+    if (job_names.size() > 1) {
+      return errors::InvalidArgument(
+          "System devices cannot be part "
+          "of multiple different jobs.  Found: ",
+          str_util::Join(job_names, ","));
+    }
+
+    // Identify the lexicographically first device from the list of
+    // valid TPU SYSTEM devices, so that every process in the same
+    // 'cluster' definition uses the same system device.
+    std::sort(system_devices.begin(), system_devices.end(),
+              [](Device* i, Device* j) {
+                auto i_name = i->parsed_name();
+                auto j_name = j->parsed_name();
+                if (i_name.replica != j_name.replica) {
+                  return i_name.replica < j_name.replica;
+                }
+                return i_name.task < j_name.task;
+              });
+  }
+
+  *system_device = system_devices[0];
+  if (!DeviceNameUtils::ParseFullName((*system_device)->name(), system_spec)) {
+    return errors::InvalidArgument("Unable to re-parse system device name ",
+                                   (*system_device)->name(),
+                                   " as a device spec.");
+  }
+  return Status::OK();
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
+
+// LINT.IfChange
+Status DistributedTPURewriteHelpers::GetHostSystemDevices(
+    const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
+    std::vector<Device*>* host_system_devices) {
+  DeviceNameUtils::ParsedName host_spec;
+  if (system_spec.has_job) {
+    // The system Op has been explicitly assigned to a job, so we want
+    // all the hosts in that job.
+    CHECK(DeviceNameUtils::ParseFullName(
+        strings::StrCat("/job:", system_spec.job, "/device:", DEVICE_TPU_SYSTEM,
+                        ":0"),
+        &host_spec));
+  } else {
+    // The system Op has not been explicitly assigned to a
+    // job, so take all hosts in the system. There will be a runtime
+    // error if some of those hosts don't contain TPU devices.
+    CHECK(DeviceNameUtils::ParseFullName(
+        strings::StrCat("/device:", DEVICE_TPU_SYSTEM, ":0"), &host_spec));
+  }
+  device_set.FindMatchingDevices(host_spec, host_system_devices);
+
+  TF_RET_CHECK(!host_system_devices->empty())
+      << "No hosts found matching device spec "
+      << DeviceNameUtils::ParsedNameToString(host_spec);
+
+  // Check that all the devices belong to the same job.
+  TF_RET_CHECK((*host_system_devices)[0]->parsed_name().has_job);
+  const string& job_name = (*host_system_devices)[0]->parsed_name().job;
+  int replica = (*host_system_devices)[0]->parsed_name().replica;
+  for (const auto host_device : *host_system_devices) {
+    const auto& parsed_name = host_device->parsed_name();
+    TF_RET_CHECK(parsed_name.has_job);
+    if (parsed_name.job != job_name) {
+      return errors::InvalidArgument(
+          "All TPU host devices must be in the same job");
+    }
+    TF_RET_CHECK(parsed_name.has_replica);
+    if (parsed_name.replica != replica) {
+      return errors::InvalidArgument(
+          "All TPU host devices must be in the same replica");
+    }
+  }
+
+  // Sort the devices by replica and then task.
+  std::sort(host_system_devices->begin(), host_system_devices->end(),
+            [](Device* i, Device* j) {
+              auto i_name = i->parsed_name();
+              auto j_name = j->parsed_name();
+              return i_name.task < j_name.task;
+            });
+  return Status::OK();
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
+
+// LINT.IfChange
+Status DistributedTPURewriteHelpers::GetTPUDevices(
+    const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
+    int* num_tpus_per_host, std::vector<std::vector<Device*>>* tpu_devices) {
+  // GetHostSystemDevices returns the CPU device on each host that is
+  // going to be used for executing TPU code.
+  std::vector<Device*> host_system_devices;
+  TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetHostSystemDevices(
+      system_spec, device_set, &host_system_devices));
+
+  // Enumerate all the physical devices. Enumerate devices on task 0,
+  // then task 1, etc.
+  std::sort(host_system_devices.begin(), host_system_devices.end(),
+            [](Device* i, Device* j) {
+              return i->parsed_name().task < j->parsed_name().task;
+            });
+
+  *num_tpus_per_host = 0;
+  tpu_devices->clear();
+  tpu_devices->reserve(host_system_devices.size());
+  for (const auto device : host_system_devices) {
+    // Make a copy of the parsed name because we are going to change it.
+    DeviceNameUtils::ParsedName device_spec = device->parsed_name();
+    device_spec.has_type = true;
+    device_spec.type = "TPU";
+    // Enumerate all the available TPUs.
+    device_spec.has_id = false;
+    std::vector<Device*> host_tpu_devices;
+    device_set.FindMatchingDevices(device_spec, &host_tpu_devices);
+    // Sort the devices by device id.
+    std::sort(host_tpu_devices.begin(), host_tpu_devices.end(),
+              [](Device* i, Device* j) {
+                return i->parsed_name().id < j->parsed_name().id;
+              });
+    if (tpu_devices->empty()) {
+      // First iteration: set *num_tpus_per_host to the number of TPUs on the
+      // first host.
+      *num_tpus_per_host = host_tpu_devices.size();
+    } else if (*num_tpus_per_host != host_tpu_devices.size()) {
+      // Subsequent iterations: check the number of TPUs match the number on
+      // the first host.
+      return errors::InvalidArgument(
+          "Mismatched number of TPU devices in cluster ", *num_tpus_per_host,
+          " vs. ", host_tpu_devices.size());
+    }
+    tpu_devices->push_back(std::move(host_tpu_devices));
+  }
+  return Status::OK();
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
+
+Status DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
+    const string& node_type, Graph* graph, const DeviceSet& device_set,
+    const std::function<
+        Status(const NodeDef& configuration_node_def,
+               const string& configuration_device_name,
+               const std::vector<Device*>& host_devices,
+               const std::vector<Node*>& input_dependencies,
+               const std::vector<OutputDependency>& output_dependencies,
+               Graph* graph)>& action) {
+  // Find all the matching nodes before mutating the graph.
+  std::vector<Node*> nodes;
+  for (Node* node : graph->nodes()) {
+    if (node->type_string() == node_type) {
+      nodes.push_back(node);
+    }
+  }
+
+  for (Node* node : nodes) {
+    string spec_string = node->requested_device();
+    DeviceNameUtils::ParsedName spec;
+    Device* device;
+    TF_RETURN_IF_ERROR(
+        GetSystemDevice(spec_string, device_set, &spec, &device));
+    const string& device_name = device->name();
+
+    std::vector<Device*> host_devices;
+    TF_RETURN_IF_ERROR(GetHostSystemDevices(spec, device_set, &host_devices));
+
+    std::vector<Node*> input_dependencies;
+    for (const Edge* edge : node->in_edges()) {
+      // Config ops have no inputs, so all edges must be control edges.
+      CHECK(edge->IsControlEdge());
+      input_dependencies.push_back(edge->src());
+    }
+    std::vector<OutputDependency> output_dependencies;
+    for (const Edge* edge : node->out_edges()) {
+      OutputDependency dep;
+      dep.src_output = edge->src_output();
+      dep.dst = edge->dst();
+      dep.dst_input = edge->dst_input();
+      output_dependencies.push_back(dep);
+    }
+    NodeDef node_def = node->def();
+
+    // Remove the node now so we can insert a new node with the same
+    // name inside the action.
+    graph->RemoveNode(node);
+
+    TF_RETURN_IF_ERROR(action(node_def, device_name, host_devices,
+                              input_dependencies, output_dependencies, graph));
+  }
+
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h
new file mode 100644
index 0000000..40aacce
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h
@@ -0,0 +1,98 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helper functions for TPU rewrite passes.
+
+#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
+#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
+
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+class DistributedTPURewriteHelpers {
+ public:
+  // Given a user-assigned device string, system_spec_string, parse it into
+  // system_spec. Verify that the device type is either TPU_SYSTEM or
+  // unassigned, and in the latter case set it to TPU_SYSTEM:0. Having set the
+  // type, verify that the spec matches a unique device in device_set, and
+  // return that device in system_device. The normal use case is for
+  // system_spec_string to identify the TPU_SYSTEM on replica 0, task 0 of the
+  // job that contains the TPU hardware.
+  // TODO(b/110910013): Possibly remove the tpu system device.
+  static Status GetSystemDevice(const string& system_spec_string,
+                                const DeviceSet& device_set,
+                                DeviceNameUtils::ParsedName* system_spec,
+                                Device** system_device);
+
+  // Given a parsed system spec (e.g., the one returned above from
+  // GetSystemDeviceName), return in host_devices the TPU_SYSTEM:0 device on
+  // every host in the spec's job. If the spec does not include an explicit job,
+  // "localhost" is used.  Returns an error if system_spec matches devices from
+  // a multiple jobs or replicas.
+  static Status GetHostSystemDevices(
+      const DeviceNameUtils::ParsedName& system_spec,
+      const DeviceSet& device_set, std::vector<Device*>* host_system_devices);
+
+  // Given a parsed system spec (e.g., the one returned above from
+  // GetSystemDeviceName), sets `*tpu_devices` to a per-host vector of the TPU
+  // devices on every host in the spec's job. If the spec does not include an
+  // explicit job, "localhost" is used. Sets `*num_tpus_per_host` to the number
+  // of TPU devices in each host, and verifies that each host in the job has
+  // the same number of TPU devices.
+  // Returns an error if system_spec matches devices from a multiple jobs or
+  // replicas.
+  static Status GetTPUDevices(const DeviceNameUtils::ParsedName& system_spec,
+                              const DeviceSet& device_set,
+                              int* num_tpus_per_host,
+                              std::vector<std::vector<Device*>>* tpu_devices);
+
+  // Perform 'action' on every node in 'graph' of type
+  // 'node_type'. This function is designed for use with configuration
+  // Ops that have no inputs or outputs. The arguments passed to 'action' are:
+  // 'configuration_node_name': the name of the node that matched
+  // 'configuration_device_name': the name of the device that the
+  // matching node is placed on
+  // 'host_devices': the set of TPU_SYSTEM devices on hosts with TPUs that are
+  // in the same system as the node that matched.
+  // 'input_dependencies': the set of nodes that have control edges to
+  // the matching node.
+  // 'output_dependencies': the set of output port, destination node, input port
+  // triples that have edges from the matching node. Input port is
+  // Graph::kControlSlot for a control edge.
+  // 'graph': the graph being mutated.
+  struct OutputDependency {
+    int src_output;
+    Node* dst;
+    int dst_input;
+  };
+  static Status ForConfigurationNodeMatchingType(
+      const string& node_type, Graph* graph, const DeviceSet& device_set,
+      const std::function<
+          Status(const NodeDef& configuration_node_def,
+                 const string& configuration_device_name,
+                 const std::vector<Device*>& host_devices,
+                 const std::vector<Node*>& input_dependencies,
+                 const std::vector<OutputDependency>& output_dependencies,
+                 Graph* graph)>& action);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
new file mode 100644
index 0000000..0e5a91c
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -0,0 +1,288 @@
+# TPU Kernel Implementations
+load(
+    "//tensorflow/core/platform:build_config.bzl",
+    "tf_proto_library_cc",
+)
+
+package(
+    default_visibility = [
+        "//tensorflow/core/tpu:__subpackages__",
+        "//tensorflow/stream_executor/tpu:__subpackages__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "tpu_compile_op_options",
+    srcs = ["tpu_compile_op_options.cc"],
+    hdrs = ["tpu_compile_op_options.h"],
+)
+
+cc_library(
+    name = "tpu_configuration_ops",
+    srcs = ["tpu_configuration_ops.cc"],
+    hdrs = ["tpu_configuration_ops.h"],
+    deps = [
+        ":tpu_mesh_state_interface",
+        "//tensorflow/c:tf_status",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:refcount",
+        "//tensorflow/core/tpu:tpu_config_c_api",
+        "//tensorflow/core/tpu:tpu_configuration",
+        "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/core/tpu:tpu_library_loader",
+        "//tensorflow/stream_executor/tpu:proto_helper",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "tpu_compile_c_api_hdrs",
+    hdrs = ["tpu_compile_c_api.h"],
+    deps = [
+        ":tpu_mesh_state_c_api",
+        "//tensorflow/c:tf_datatype",
+        "//tensorflow/stream_executor/tpu:proto_helper",
+        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
+    ],
+)
+
+tf_proto_library_cc(
+    name = "tpu_executable_info_proto",
+    srcs = ["tpu_executable_info.proto"],
+    cc_api_version = 2,
+    protodeps = [
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo_proto",
+        "//tensorflow/core:protos_all",
+    ],
+)
+
+tf_proto_library_cc(
+    name = "tpu_compile_proto",
+    srcs = ["tpu_compile.proto"],
+    cc_api_version = 2,
+    protodeps = [
+        ":tpu_executable_info_proto",
+        "//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo_proto",
+        "//tensorflow/core:protos_all",
+        "//tensorflow/core/protobuf/tpu:compile_metadata_proto",
+    ],
+)
+
+cc_library(
+    name = "tpu_compilation_cache_key",
+    srcs = [],
+    hdrs = [
+        "tpu_compilation_cache_key.h",
+    ],
+    deps = ["@com_google_absl//absl/types:optional"],
+)
+
+cc_library(
+    name = "tpu_compile_op_support",
+    srcs = ["tpu_compile_op_support.cc"],
+    hdrs = ["tpu_compile_op_support.h"],
+    deps = [
+        ":tpu_compilation_cache_key",
+        ":tpu_compile_c_api_hdrs",
+        ":tpu_compile_proto_cc",
+        ":tpu_executable_info_proto_cc",
+        "//tensorflow/cc:ops",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/compiler/xla:shape_tree",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:compile_only_client",
+        "//tensorflow/compiler/xla/service:computation_layout",
+        "//tensorflow/compiler/xla/service:dump",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/compiler/xla/service:hlo_module_group",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/framework:protos_all_cc",
+        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+        "//tensorflow/stream_executor/tpu:proto_helper",
+        "//tensorflow/stream_executor/tpu:status_helper",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "tpu_compilation_cache_entry",
+    hdrs = [
+        "tpu_compilation_cache_entry.h",
+    ],
+    deps = [
+        ":tpu_executable_info_proto_cc",
+        ":tpu_program",
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
+        "//tensorflow/core/lib/core:refcount",
+    ],
+)
+
+cc_library(
+    name = "tpu_compilation_cache_lookup",
+    srcs = ["tpu_compilation_cache_lookup.cc"],
+    hdrs = [
+        "tpu_compilation_cache_lookup.h",
+    ],
+    deps = [
+        ":tpu_compilation_cache_entry",
+        ":tpu_compilation_cache_external",
+        ":tpu_compilation_cache_proto_cc",
+        "//tensorflow/core/lib/core:refcount",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/core/profiler/lib:traceme",
+    ],
+)
+
+cc_library(
+    name = "tpu_mesh_state_c_api",
+    hdrs = ["tpu_mesh_state_c_api.h"],
+)
+
+cc_library(
+    name = "tpu_mesh_state_interface",
+    srcs = [],
+    hdrs = ["tpu_mesh_state_interface.h"],
+    deps = [
+        ":tpu_compile_c_api_hdrs",
+        ":tpu_mesh_state_c_api",
+        "//tensorflow/compiler/xla/service",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+        "//tensorflow/core/tpu:tpu_config_c_api",
+    ],
+)
+
+cc_library(
+    name = "tpu_program",
+    srcs = ["tpu_program.cc"],
+    hdrs = ["tpu_program.h"],
+    deps = [
+        ":tpu_compile_c_api_hdrs",
+        ":tpu_compile_op_support",
+        ":tpu_compile_proto_cc",
+        ":tpu_executable_info_proto_cc",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:xla_proto_cc",
+        "//tensorflow/compiler/xla/client:compile_only_client",
+        "//tensorflow/compiler/xla/service:computation_placer",
+        "//tensorflow/compiler/xla/service:hlo_module_group",
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+        "//tensorflow/stream_executor/tpu:proto_helper",
+        "//tensorflow/stream_executor/tpu:status_helper",
+        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
+cc_library(
+    name = "tpu_compilation_cache_external",
+    srcs = ["tpu_compilation_cache_external.cc"],
+    hdrs = [
+        "tpu_compilation_cache_external.h",
+    ],
+    deps = [
+        ":tpu_compilation_cache_entry",
+        ":tpu_compilation_cache_key",
+        ":tpu_compilation_cache_metrics",  # buildcleaner: keep
+        ":tpu_compilation_cache_metrics_hdrs",
+        ":tpu_compilation_cache_proto_cc",
+        ":tpu_compile_c_api_hdrs",
+        ":tpu_compile_op_support",
+        ":tpu_mesh_state_interface",
+        ":tpu_program",
+        ":tpu_util",
+        ":trace_util_hdrs",
+        "//tensorflow/compiler/xla/service",
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/platform:refcount",
+        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+        "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "tpu_compilation_cache_metrics_hdrs",
+    hdrs = ["tpu_compilation_cache_metrics.h"],
+    deps = [
+        "//tensorflow/core/platform:types",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
+    name = "tpu_compilation_cache_metrics",
+    srcs = ["tpu_compilation_cache_metrics.cc"],
+    deps = [
+        ":tpu_compilation_cache_metrics_hdrs",
+    ],
+)
+
+cc_library(
+    name = "trace_util_hdrs",
+    srcs = [],
+    hdrs = ["trace_util.h"],
+    deps = [
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
+    name = "tpu_util_hdrs",
+    srcs = [],
+    hdrs = ["tpu_util.h"],
+    deps = [
+        ":tpu_compilation_cache_key",
+        "//tensorflow/cc:ops",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/client:compile_only_client",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
+    name = "tpu_util",
+    srcs = ["tpu_util.cc"],
+    hdrs = ["tpu_util.h"],
+    deps = [
+        ":tpu_compilation_cache_key",
+        "//tensorflow/cc:ops",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/client:compile_only_client",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/strings",
+    ],
+    alwayslink = 1,
+)
+
+tf_proto_library_cc(
+    name = "tpu_compilation_cache_proto",
+    srcs = ["tpu_compilation_cache.proto"],
+    cc_api_version = 2,
+)
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto b/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto
new file mode 100644
index 0000000..8308cba
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto
@@ -0,0 +1,25 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+// Target type for compilation cache fetch operation.
+enum CompilationCacheFetchTarget {
+  INVALID = 0;
+  MAIN = 1;
+  SHARDING = 2;
+  UNSHARDING = 3;
+}
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h
new file mode 100644
index 0000000..d16b2d5
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h
@@ -0,0 +1,84 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_ENTRY_H_
+#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_ENTRY_H_
+
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_program.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class CompilationCacheEntry {
+ public:
+  explicit CompilationCacheEntry(
+      std::unique_ptr<const TpuProgram> tpu_program)
+      : tpu_program_(std::move(tpu_program)) {}
+
+  // Constructor for an empty entry.
+  CompilationCacheEntry()
+      : tpu_program_(nullptr) {}
+
+  const TPUExecutableInfoProto* get_executable_info() const {
+    return &tpu_program_->executable_info();
+  }
+
+  const TPUHostTransferInfoProto* get_host_transfer_info() const {
+    return &tpu_program_->host_transfer_info();
+  }
+
+  const xla::HloProto* get_hlo_metadata() const {
+    return &tpu_program_->hlo_metadata();
+  }
+
+  // TODO(henrytan,jiawenhao): When should we expect more than one
+  // XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then.
+  const XLA_TpuProgram* get_tpu_program() const {
+    CHECK_EQ(tpu_program_->program_count(), 1);
+    return tpu_program_->tpu_programs()[0];
+  }
+
+ private:
+  std::unique_ptr<const TpuProgram> tpu_program_;
+};
+
+// Base class for a reference to a cached proto. A unique_ptr to a
+// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
+// and ensures the underlying proto is not garbage-collected until the client
+// discards the ptr.
+class CompilationCacheEntryRef {
+ public:
+  virtual ~CompilationCacheEntryRef() = default;
+
+  // Returns a CompilationCacheEntry that should not be used beyond the lifetime
+  // of the CompilationCacheEntryRef.
+  virtual CompilationCacheEntry get() = 0;
+};
+
+// Base class that holds references to compiled protos so that the protos are
+// not garbage-collected before being used by execute ops. Use
+// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
+// ref holder object.
+class CompilationRefHolder : public ResourceBase {
+ public:
+  ~CompilationRefHolder() override = default;
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_ENTRY_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc
new file mode 100644
index 0000000..8dbf608
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc
@@ -0,0 +1,791 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
+
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/random.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+#include "tensorflow/core/tpu/kernels/tpu_program.h"
+#include "tensorflow/core/tpu/kernels/tpu_util.h"
+#include "tensorflow/core/tpu/kernels/trace_util.h"
+
+namespace tensorflow {
+namespace tpu {
+
+namespace {
+
+using CompilationEntry = TpuCompilationCacheInterface::CompilationEntry;
+
+int64 get_uid() {
+  uint64 unsigned_rand = random::New64() & INT64_MAX;
+  return static_cast<int64>(unsigned_rand);
+}
+
+void PopulateEntry(const std::string& key, CompilationEntry* entry,
+                   std::unique_ptr<TpuProgram> tpu_program) {
+  // Make the unique keys for each cached proto.
+  for (int i = 0; i < tpu_program->program_count(); ++i) {
+    entry->proto_key.push_back(ProtoKeyForComputation(key, i));
+  }
+
+  entry->tpu_program = std::move(tpu_program);
+  entry->initialized = true;
+}
+
+std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
+  if (!key.has_guaranteed_const) {
+    return key.prefix;
+  }
+  return absl::StrCat(key.prefix, "|", key.session_handle, "|",
+                      key.guaranteed_const_fingerprint());
+}
+
+// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
+// data to compute the fingerprint.
+std::string GuaranteedConstFingerprint(
+    const string& fingerprint_in_metadata,
+    const OpInputList& guaranteed_constants) {
+  if (fingerprint_in_metadata.empty()) {
+    uint64_t fingerprint = 0;
+    for (const auto& constant : guaranteed_constants) {
+      fingerprint = TpuCompile_CreateGuaranteedConstFingerprint(
+          fingerprint, constant.tensor_data().data(),
+          constant.tensor_data().size());
+    }
+    return std::to_string(fingerprint);
+  } else {
+    return fingerprint_in_metadata;
+  }
+}
+
+std::string CreateShapePrefix(
+    const std::vector<tensorflow::TensorShape>& dynamic_shapes) {
+  std::string shapes_prefix;
+  for (const TensorShape& shape : dynamic_shapes) {
+    for (int64 size : shape.dim_sizes()) {
+      absl::StrAppend(&shapes_prefix, size, ",");
+    }
+    absl::StrAppend(&shapes_prefix, ";");
+  }
+  return shapes_prefix;
+}
+
+// Include compilation configurations of the arguments that are not captured
+// by the called graph.
+std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
+  std::string config_prefix;
+  for (const auto& arg : metadata.args()) {
+    if (arg.is_same_data_across_replicas()) {
+      absl::StrAppend(&config_prefix, ":s");
+      // Same.
+    } else {
+      // Different.
+      absl::StrAppend(&config_prefix, ":");
+    }
+    if (arg.enable_xla_sharding() ==
+        tpu::TPUCompileMetadataProto::Arg::ALLOWED) {
+      // Enabled.
+      absl::StrAppend(&config_prefix, "e");
+    }
+    if (arg.unrestricted_layout()) {
+      // Unrestricted.
+      absl::StrAppend(&config_prefix, ":u");
+    }
+    absl::StrAppend(&config_prefix, ",type(", arg.dtype(), ")");
+    if (arg.has_shape()) {
+      absl::StrAppend(&config_prefix, ",shape(");
+      for (const auto& dim : arg.shape().dim()) {
+        absl::StrAppend(&config_prefix, dim.size(), ",");
+      }
+      absl::StrAppend(&config_prefix, ")");
+    }
+  }
+  return config_prefix;
+}
+
+}  // namespace
+
+TpuCompilationCacheInterface::TpuCompilationCacheInterface(
+    int64_t max_cache_size)
+    : max_cache_size_(max_cache_size) {
+  if (max_cache_size < 0) {
+    LOG(FATAL) << "`max_cache_size` value must be greater than equal to 0";
+  }
+  VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes.";
+}
+
+TpuCompilationCacheInterface::~TpuCompilationCacheInterface() {
+  VLOG(1) << "TpuCompilationCacheInterface::~TpuCompilationCacheInterface()";
+  // A buggy client may be holding onto a reference, or a client might have
+  // crashed while holding onto a reference. In either case, discard all
+  // outstanding client references to avoid leaking storage.
+  for (const auto& entry : entries_by_uid_) {
+    while (entry.second->external_references > 0) {
+      TF_CHECK_OK(Release(entry.first));
+    }
+  }
+  while (!entries_by_last_use_.empty()) {
+    UnloadAndDestroy(MarkOldestEntryForEviction());
+  }
+  // By the time the cache is deleted all reference holders should have already
+  // been deleted, since they were holding references to the cache. So all
+  // entries should be gone at this point.
+  CHECK_EQ(cache_store_.size(), 0);
+  CHECK_EQ(entries_by_uid_.size(), 0);
+  CHECK_EQ(entries_by_proto_key_.size(), 0);
+  CHECK_EQ(cache_size_, 0);
+  CHECK_EQ(marked_for_eviction_size_, 0);
+}
+
+std::string TpuCompilationCacheInterface::FindCacheKey(
+    const TpuCompilationCacheKey& subgraph_key) const {
+  if (!subgraph_key.has_guaranteed_const) {
+    return subgraph_key.prefix;
+  }
+  auto iter = session_key_map_.find(
+      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
+  if (iter != session_key_map_.end()) {
+    return iter->second;
+  }
+  iter = fingerprint_key_map_.find(strings::StrCat(
+      subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
+  if (iter != session_key_map_.end()) {
+    return iter->second;
+  }
+  VLOG(1) << "No matching cache key found for key "
+          << ConstructCompilationCacheKey(subgraph_key);
+  return "";
+}
+
+void TpuCompilationCacheInterface::InsertEntry(
+    const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key,
+    CompilationEntry* entry) {
+  entry->parent = this;
+  entry->subgraph_key = cache_key;
+  entry->uid = get_uid();
+  TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
+  entry->cache_entry_debug_string = subgraph_key.prefix;
+  VLOG(1) << "Cache Initializing Entry Session Debug "
+          << entry->cache_entry_debug_string;
+
+  if (!subgraph_key.has_guaranteed_const) {
+    return;
+  }
+  session_key_map_.insert(std::make_pair(
+      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle),
+      cache_key));
+  fingerprint_key_map_.insert(std::make_pair(
+      strings::StrCat(subgraph_key.prefix,
+                      subgraph_key.guaranteed_const_fingerprint()),
+      cache_key));
+}
+
+CompilationEntry* TpuCompilationCacheInterface::InitializeEntry(
+    const string& key,
+    const std::function<Status(TpuProgram*)>& initialize_program,
+    const TpuCompilationCacheKey& subgraph_key) {
+  CompilationEntry* main_entry = new CompilationEntry();
+
+  // Add the entry to the cache, with size zero since there are no compiled
+  // programs in it. Once the subgraph has been compiled,
+  // UpdateEntryAfterCompilation will be called to potentially mark old entries
+  // that don't fit any more for eviction.
+  //
+  // At this point there is one reference to entry, which is owned by the caller
+  // who created the entry. A second reference, owned by the cache, will be
+  // added below since we leave the entry in the 'marked for eviction' state
+  // here.
+  InsertEntry(key, subgraph_key, main_entry);
+
+  // Initialize the programs outside the lock so that other cache operations
+  // can proceed during the (potentially lengthy) initialization.
+  Status initialization_status;
+
+  auto tpu_program = absl::make_unique<TpuProgram>();
+  {
+    mu_.Unlock();
+    {
+      profiler::TraceMe compile_programs_traceme(
+          "TPU compilation cache compile",
+          /*level=*/2);
+      initialization_status = initialize_program(tpu_program.get());
+    }
+    mu_.Lock();
+  }
+
+  main_entry->initialization_status = initialization_status;
+
+  // Add the entry to the uid index.
+  auto uid_inserted = entries_by_uid_.insert(
+      std::pair<int64, CompilationEntry*>(main_entry->uid, main_entry));
+  CHECK(uid_inserted.second);
+
+  if (initialization_status.ok()) {
+    // Compute the entries total size once all members are initialized.
+    main_entry->total_size = tpu_program->program_size();
+  }
+
+  // TODO(henrytan): handle sharding/unsharding.
+  PopulateEntry(key, main_entry, std::move(tpu_program));
+
+  for (int64 i = 0; i < main_entry->proto_key.size(); ++i) {
+    auto entry_inserted = entries_by_proto_key_.insert(
+        std::pair<string, std::pair<CompilationEntry*, int>>(
+            main_entry->proto_key[i], std::make_pair(main_entry, i)));
+    CHECK(entry_inserted.second);
+  }
+
+  // Add the size to marked_for_eviction_size_ since it will be adjusted down
+  // again when the newly-created entry gets unmarked.
+  marked_for_eviction_size_ += main_entry->total_size;
+  return main_entry;
+}
+
+/*static*/ TpuCompilationCacheKey
+TpuCompilationCacheInterface::CreateCompilationCacheKey(
+    absl::string_view function_name, uint64 function_library_fingerprint,
+    absl::string_view mlir_module,
+    const tensorflow::OpInputList& guaranteed_constants,
+    const std::vector<tensorflow::TensorShape>& dynamic_shapes,
+    const tensorflow::tpu::TPUCompileMetadataProto& metadata,
+    const TpuMeshStateInterface& mesh_state) {
+  VLOG(1) << "FunctionLibraryFingerprint:" << function_library_fingerprint;
+  std::string shapes_prefix = CreateShapePrefix(dynamic_shapes);
+  VLOG(1) << "shapes_prefix = " << shapes_prefix;
+  std::string config_prefix = CreateConfigPrefix(metadata);
+  VLOG(1) << "config_prefix = " << config_prefix;
+  std::vector<int32_t> flattened_device_ids;
+  if (metadata.has_device_assignment()) {
+    for (const auto& device :
+         metadata.device_assignment().computation_devices()) {
+      flattened_device_ids.insert(flattened_device_ids.end(),
+                                  device.replica_device_ids().begin(),
+                                  device.replica_device_ids().end());
+    }
+  }
+  // TODO(henrytan): return the debug_string.
+  const char* prefix =
+      TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty{
+          config_prefix.data(),
+          shapes_prefix.data(),
+          function_name.data(),
+          mlir_module.data(),
+          flattened_device_ids.data(),
+          flattened_device_ids.size(),
+          guaranteed_constants.size(),
+          function_library_fingerprint,
+          metadata.num_cores_per_replica(),
+          metadata.num_replicas(),
+          mesh_state.data(),
+      });
+  auto buffer_cleanup = gtl::MakeCleanup([prefix]() { delete[] prefix; });
+  TpuCompilationCacheKey key;
+  key.prefix = prefix;
+
+  // Guaranteed constants can be different across sessions. Use session_handle
+  // and guaranteed_const fingerprint to guarantee no collision.
+  if (guaranteed_constants.size() > 0) {
+    key.has_guaranteed_const = true;
+    key.session_handle = metadata.session_handle();
+    // Both `metadata` and `guaranteed_constants` lifetime are captured by
+    // reference based on the assumption that these variables lifetime is
+    // managed through the `TPUCompileOpKernelImpl` that outlives the
+    // lifetime of the compilation cache lookups.
+    string fingerprint;
+    key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
+                                        fingerprint]() mutable {
+      if (fingerprint.empty()) {
+        fingerprint = GuaranteedConstFingerprint(
+            metadata.guaranteed_const_fingerprint(), guaranteed_constants);
+      }
+      return fingerprint;
+    };
+  }
+  return key;
+}
+
+TpuCompilationRefHolder* TpuCompilationCacheInterface::MakePerStepRefHolder() {
+  return new RefHolder(this);
+}
+
+Status TpuCompilationCacheInterface::MarkEntryForEviction(int64 subgraph_uid) {
+  profiler::TraceMe key_release_traceme(
+      "TPU compilation cache possibly evict uid",
+      /*level=*/2);
+  CompilationEntry* deleted_entry = nullptr;
+  {
+    absl::MutexLock lock(&mu_);
+    auto iter = entries_by_uid_.find(subgraph_uid);
+    if (iter == entries_by_uid_.end()) {
+      // If already evicted, return ok.
+      return Status::OK();
+    }
+
+    // Mark entry for eviction.
+    CompilationEntry* subgraph_to_evict = iter->second;
+    // If there are external references, should not use this API.
+    if (subgraph_to_evict->external_references != 0) {
+      return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
+                              " external_references greater than zero. Should "
+                              "use TpuCompilationCache::Release.");
+    }
+
+    VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction";
+    entries_by_last_use_.erase(subgraph_to_evict->last_use);
+    cache_size_ -= subgraph_to_evict->total_size;
+    marked_for_eviction_size_ += subgraph_to_evict->total_size;
+
+    // Evict if refcount exactly one, otherwise only discard cache's reference
+    // to the entry while the actual eviction will happen when refholder's
+    // references go away.
+    deleted_entry = DiscardEntryRef(subgraph_to_evict);
+
+    VLOG(1) << "After possibly evicting entry " << subgraph_uid
+            << " refs cache is " << cache_store_.size() << " entries ("
+            << cache_size_ + marked_for_eviction_size_
+            << " bytes), marked for eviction "
+            << (cache_store_.size() - entries_by_last_use_.size())
+            << " entries (" << marked_for_eviction_size_ << " bytes).";
+  }
+
+  // Unload from device cache if entry is evicted from host cache.
+  UnloadAndDestroy(deleted_entry);
+  return Status::OK();
+}
+
+Status TpuCompilationCacheInterface::Release(int64 subgraph_uid) {
+  profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
+                                        /*level=*/2);
+
+  CompilationEntry* deleted_entry = nullptr;
+  {
+    absl::MutexLock lock(&mu_);
+    auto iter = entries_by_uid_.find(subgraph_uid);
+
+    if (iter == entries_by_uid_.end()) {
+      return errors::NotFound("No cache entry found for uid ", subgraph_uid);
+    }
+
+    CHECK_GT(iter->second->external_references, 0);
+    --iter->second->external_references;
+
+    deleted_entry = DiscardEntryRef(iter->second);
+
+    VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
+            << cache_store_.size() << " entries ("
+            << cache_size_ + marked_for_eviction_size_
+            << " bytes), marked for eviction "
+            << (cache_store_.size() - entries_by_last_use_.size())
+            << " entries (" << marked_for_eviction_size_ << " bytes).";
+  }
+  UnloadAndDestroy(deleted_entry);
+  return Status::OK();
+}
+
+void TpuCompilationCacheInterface::UnloadAndDestroy(CompilationEntry* entry) {
+  if (!entry) return;
+
+  CHECK(entry->RefCountIsOne());
+  entry->tpu_program->UnloadAndDestroyPrograms();
+  entry->Unref();
+}
+
+size_t TpuCompilationCacheInterface::RemoveEntry(const string& key) {
+  auto erased = cache_store_.erase(key);
+  TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
+  auto parsed_key_or_status = ParseCompilationCacheKey(key);
+  CHECK(parsed_key_or_status.status().ok());
+  const TpuCompilationCacheKey parsed_key =
+      parsed_key_or_status.ConsumeValueOrDie();
+  if (!parsed_key.has_guaranteed_const) {
+    return erased;
+  }
+  session_key_map_.erase(
+      strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
+  fingerprint_key_map_.erase(strings::StrCat(
+      parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
+  return erased;
+}
+
+ABSL_MUST_USE_RESULT CompilationEntry*
+TpuCompilationCacheInterface::DiscardEntryRef(CompilationEntry* entry) {
+  if (entry->RefCountIsOne()) {
+    // The last reference to this entry is going away, so really delete it from
+    // the cache in such a way that it can't be restored by being looked up
+    // again.
+
+    // Sanity-check that it has been marked for eviction.
+    CHECK(entries_by_last_use_.find(entry->last_use) ==
+          entries_by_last_use_.end());
+    // Update the counter tracking how much space is taken up by entries that
+    // are marked for eviction.
+    marked_for_eviction_size_ -= entry->total_size;
+
+    // Remove the entry from the cache.
+    auto erased = RemoveEntry(entry->subgraph_key);
+
+    if (erased == 0) {
+      LOG(FATAL) << "Tried to discard nonexistent cache entry";
+    }
+    erased = entries_by_uid_.erase(entry->uid);
+    CHECK_EQ(erased, 1);
+    for (const string& key : entry->proto_key) {
+      erased = entries_by_proto_key_.erase(key);
+      CHECK_EQ(erased, 1);
+    }
+    // The actual deletion will happen outside the lock in UnloadAndDestroy().
+    return entry;
+  }
+  entry->Unref();
+  return nullptr;
+}
+
+void TpuCompilationCacheInterface::DiscardEntryRefs(
+    gtl::ArraySlice<CompilationEntry*> entries) {
+  std::vector<CompilationEntry*> removed_entries;
+  {
+    absl::MutexLock lock(&mu_);
+
+    for (auto entry : entries) {
+      removed_entries.push_back(DiscardEntryRef(entry));
+    }
+
+    VLOG(1) << "After discarding entry refs cache is " << cache_store_.size()
+            << " entries (" << cache_size_ + marked_for_eviction_size_
+            << " bytes), marked for eviction "
+            << (cache_store_.size() - entries_by_last_use_.size())
+            << " entries (" << marked_for_eviction_size_ << " bytes).";
+  }
+  for (auto removed_entry : removed_entries) {
+    UnloadAndDestroy(removed_entry);
+  }
+}
+
+ABSL_MUST_USE_RESULT CompilationEntry*
+TpuCompilationCacheInterface::MarkOldestEntryForEviction() {
+  CompilationEntry* entry_to_mark = entries_by_last_use_.begin()->second;
+  VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction";
+  entries_by_last_use_.erase(entry_to_mark->last_use);
+  cache_size_ -= entry_to_mark->total_size;
+  marked_for_eviction_size_ += entry_to_mark->total_size;
+  // Discard the cache's reference to entry. If steps are holding onto
+  // references to entry it won't be deleted until the last step holding it
+  // completes. It stays in the cache in the meantime and can be resurrected
+  // by a call to CompileIfKeyAbsent if that occurs before the last reference
+  // expires.
+  return DiscardEntryRef(entry_to_mark);
+}
+
+void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
+    CompilationEntry* entry, std::vector<CompilationEntry*>* removed_entries) {
+  // The entry was previously marked for eviction (or is newly created) so
+  // unmark it. Add a reference (owned by the cache), update the cache size, and
+  // mark something old for eviction if necessary.
+  entry->Ref();
+  marked_for_eviction_size_ -= entry->total_size;
+  cache_size_ += entry->total_size;
+
+  // Mark the least-recently-used non-marked entry for eviction. Never mark the
+  // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1
+  // which means there's only one entry not already marked for eviction), so
+  // that an entry persists in the cache even if it is larger than the allocated
+  // cache size.
+  while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
+    if (auto entry_to_evict = MarkOldestEntryForEviction()) {
+      removed_entries->push_back(entry_to_evict);
+    }
+  }
+}
+
+Status TpuCompilationCacheInterface::ToSubEntryRef(
+    CompilationCacheEntryRef* entry,
+    CompilationCacheFetchTarget fetch_target) const {
+  return static_cast<EntryRefImpl*>(entry)->ToSubEntryRef(fetch_target);
+}
+
+TpuCompilationCacheInterface::EntryRefImpl::EntryRefImpl(
+    TpuCompilationCacheInterface* parent, CompilationEntry* entry, int index)
+    : parent_(parent), entry_(entry), index_(index) {
+  if (entry_ == nullptr) {
+    return;
+  }
+  if (entry_->main_entry == nullptr) {
+    entry_->Ref();
+  } else {
+    // This is a sharding/unsharding entry nested in a main entry. Only refcount
+    // the main entry.
+    entry_->main_entry->Ref();
+  }
+}
+
+TpuCompilationCacheInterface::EntryRefImpl::~EntryRefImpl() {
+  if (entry_ == nullptr) {
+    return;
+  }
+  if (entry_->main_entry == nullptr) {
+    parent_->DiscardEntryRefs({entry_});
+  } else {
+    parent_->DiscardEntryRefs({entry_->main_entry});
+  }
+}
+
+CompilationCacheEntry TpuCompilationCacheInterface::EntryRefImpl::get() {
+  if (entry_ == nullptr) {
+    // Create an empty entry if the entry is nullptr. This corresponds to
+    // non-existing sharding/unsharding entries.
+    return CompilationCacheEntry();
+  }
+  return CompilationCacheEntry(std::move(entry_->tpu_program));
+}
+
+Status TpuCompilationCacheInterface::EntryRefImpl::ToSubEntryRef(
+    CompilationCacheFetchTarget fetch_target) {
+  CompilationEntry* target = nullptr;
+  switch (fetch_target) {
+    case CompilationCacheFetchTarget::MAIN:
+      target = entry_;
+      break;
+    case CompilationCacheFetchTarget::SHARDING:
+      target = entry_->sharding_entry.get();
+      break;
+    case CompilationCacheFetchTarget::UNSHARDING:
+      target = entry_->unsharding_entry.get();
+      break;
+    default:
+      return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
+  }
+
+  if (target == nullptr) {
+    // Cache entry does not have an unsharding subentry. Unref and replace
+    // with nullptr.
+    parent_->DiscardEntryRefs({entry_});
+  }
+  // Otherwise, since the refcount is always on the main entry, we don't need
+  // ref/unref.
+  entry_ = target;
+  return Status::OK();
+}
+
+Status TpuCompilationCacheInterface::Lookup(
+    int64 uid, int proto_index,
+    std::unique_ptr<CompilationCacheEntryRef>* entry) {
+  entry->reset();
+
+  profiler::TraceMe proto_lookup_traceme(
+      "TPU compilation cache proto lookup by uid",
+      /*level=*/2);
+
+  absl::MutexLock lock(&mu_);
+  const auto iter = entries_by_uid_.find(uid);
+  if (iter == entries_by_uid_.end()) {
+    return errors::NotFound("No subgraph found for uid ", uid);
+  }
+  CompilationEntry* cache_entry = iter->second;
+  if (proto_index < 0 ||
+      proto_index >= cache_entry->tpu_program->program_size()) {
+    return errors::NotFound("No proto found for core index ", proto_index,
+                            " in subgraph with uid ", uid);
+  }
+  *entry = std::unique_ptr<CompilationCacheEntryRef>(
+      new EntryRefImpl(this, cache_entry, proto_index));
+  return Status::OK();
+}
+
+Status TpuCompilationCacheInterface::Lookup(
+    const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry) {
+  entry->reset();
+
+  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
+                                         /*level=*/2);
+
+  absl::MutexLock lock(&mu_);
+  const auto iter = entries_by_proto_key_.find(proto_key);
+  if (iter == entries_by_proto_key_.end()) {
+    return errors::NotFound("No proto found for key ", proto_key);
+  }
+  CompilationEntry* cache_entry = iter->second.first;
+  int proto_index = iter->second.second;
+  *entry = std::unique_ptr<CompilationCacheEntryRef>(
+      new EntryRefImpl(this, cache_entry, proto_index));
+  return Status::OK();
+}
+
+Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
+    const TpuCompilationCacheKey& subgraph_key,
+    const SessionMetadata* session_metadata,
+    TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
+    std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+    std::vector<CompilationEntry*>* removed_entries,
+    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
+    const std::function<Status(TpuProgram*)>& compile_function) {
+  profiler::TraceMe subgraph_lookup_traceme(
+      "TPU compilation cache subgraph lookup",
+      /*level=*/2);
+
+  // NOTE: In spite of the fact that we use MutexLock, we do not hold the lock
+  // for the lifetime of the object, see InitializeEntry() call below.
+  absl::MutexLock lock(&mu_);
+
+  std::string cache_key = FindCacheKey(subgraph_key);
+  auto iter = cache_store_.find(cache_key);
+  bool is_new_key = iter == cache_store_.end();
+
+  const std::string session_name = SessionNameFromMetadata(session_metadata);
+
+  CompilationEntry* entry = nullptr;
+  if (is_new_key) {
+    cache_key = ConstructCompilationCacheKey(subgraph_key);
+    TpuCompilationCacheMetrics::IncrementCacheLookupCount(
+        /*is_cache_hit=*/false, session_name);
+    const string msg =
+        strings::StrCat("TPU host compilation cache miss: cache_key(",
+                        cache_key, "), session_name(", session_name, ")");
+
+    TRACESTRING(msg);
+    LOG(INFO) << msg;
+
+    // Check if caller has disabled compilation. Set using
+    // internal::ScopedTpuCompileDisabler.
+    if (!IsTpuCompilationEnabled()) {
+      const string error_msg = strings::StrCat(
+          "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
+          "disabled, session_name(",
+          session_name, ") Debug String: ", subgraph_key.debug_string);
+      if (VLOG_IS_ON(2)) {
+        VLOG(2) << "Cache Missed. Current cache entries: ";
+        for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) {
+          // TODO(henrytan): add DebugKey as cache_entry_debug_string to
+          // TpuCompilationCacheKey.
+          VLOG(2) << "Cache Debug Info: ";
+          VLOG(2) << it->second->cache_entry_debug_string;
+        }
+      }
+
+      LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
+      return errors::NotFound(error_msg);
+    }
+
+    // The single ref on the newly-created entry is owned by the caller.
+    VLOG(1) << "Before adding new entry for key " << cache_key
+            << " with session_name( " << session_name << ");"
+            << "; cache is " << cache_store_.size() << " entries ("
+            << cache_size_ + marked_for_eviction_size_ << " bytes), "
+            << " marked for eviction "
+            << (cache_store_.size() - entries_by_last_use_.size())
+            << " entries (" << marked_for_eviction_size_ << " bytes).";
+    // Note that InitializeEntry() will Release/Reacquire mu_.
+    entry = InitializeEntry(cache_key, compile_function, subgraph_key);
+    TRACELITERAL("TPU host compilation cache: compilation done.");
+
+    LOG(INFO) << strings::StrCat(
+        "TPU host compilation cache: compilation done for cache_key(",
+        cache_key, "), session_name(", session_name, ")");
+    // If session_name is present, log some additional stats related to HBM
+    // here, so that they can be associated directly to the session.
+    if (!session_name.empty()) {
+      entry->tpu_program->LogProgramMemorySummary();
+    }
+  } else {
+    TpuCompilationCacheMetrics::IncrementCacheLookupCount(true, session_name);
+    const string msg =
+        strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
+                        "), session_name(", session_name, ")");
+    TRACESTRING(msg);
+    VLOG(1) << msg;
+    VLOG(1) << "Before refreshing entry for key " << cache_key
+            << " with session_name( " << session_name << "); cache is "
+            << cache_store_.size() << " entries ("
+            << cache_size_ + marked_for_eviction_size_ << " bytes), "
+            << " marked for eviction "
+            << (cache_store_.size() - entries_by_last_use_.size())
+            << " entries (" << marked_for_eviction_size_ << " bytes).";
+    entry = iter->second;
+    // Make a new reference that is owned by the caller.
+    entry->Ref();
+    // Block if necessary until the subgraph has been initialized.
+    mu_.Await(absl::Condition(
+        +[](CompilationEntry* e) { return e->initialized; }, entry));
+  }
+
+  // Let the caller know the uid of the entry.
+  *uid = entry->uid;
+  // Let the caller know the keys for each of the cached protos.
+  *proto_key = entry->proto_key;
+  *may_modify_variables = entry->tpu_program->may_modify_variables();
+  *hlo_metadata = entry->hlo_metadata;
+
+  // If the caller didn't supply a per_step_ref_holder then the caller is going
+  // to manually release the reference later via a call to Release().
+  if (per_step_ref_holder == nullptr) {
+    ++entry->external_references;
+  } else {
+    // The caller wants its reference to be handed off to a per-step holder that
+    // will discard the reference when the step completes.
+    RefHolder* cast_ref_holder = static_cast<RefHolder*>(per_step_ref_holder);
+    TF_RET_CHECK(cast_ref_holder != nullptr);
+    cast_ref_holder->AddRef(entry);
+  }
+
+  // Remove the old LRU-table entry if it wasn't already marked for eviction.
+  auto erased = entries_by_last_use_.erase(entry->last_use);
+  // Update the LRU table indicating this entry is the most recently used.
+  entry->last_use = use_counter_++;
+  entries_by_last_use_[entry->last_use] = entry;
+  if (erased == 0) {
+    // The entry had been marked for eviction, or is newly created.
+    LookupEntryMarkedForEviction(entry, removed_entries);
+  }
+
+  // Log a little more verbosely when a key is added.
+  if (VLOG_IS_ON(1) || is_new_key) {
+    LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
+              << " entry for key " << cache_key << " with session_name "
+              << session_name << " cache is " << cache_store_.size()
+              << " entries (" << cache_size_ + marked_for_eviction_size_
+              << " bytes), "
+              << " marked for eviction "
+              << (cache_store_.size() - entries_by_last_use_.size())
+              << " entries (" << marked_for_eviction_size_ << " bytes).";
+  }
+  return entry->initialization_status;
+}
+
+tensorflow::Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
+    const TpuCompilationCacheKey& cache_key,
+    const tensorflow::SessionMetadata* session_metadata,
+    TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
+    std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
+    const std::function<tensorflow::Status(TpuProgram*)>& compile_function) {
+  std::vector<CompilationEntry*> removed_entries;
+  auto status = CompileIfKeyAbsentHelper(
+      cache_key, session_metadata, per_step_ref_holder, uid, proto_key,
+      may_modify_variables, &removed_entries, hlo_metadata, compile_function);
+  for (auto entry : removed_entries) {
+    UnloadAndDestroy(entry);
+  }
+  return status;
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
new file mode 100644
index 0000000..b6cdbe9
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
@@ -0,0 +1,394 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_INTERFACE_H_
+#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_INTERFACE_H_
+
+#include <functional>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/refcount.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
+#include "tensorflow/core/tpu/kernels/tpu_program.h"
+
+namespace tensorflow {
+namespace tpu {
+
+const char kCompilationCacheResourceName[] = "tpu_compilation_cache";
+const char kCompilationCacheUnloaderResourceName[] =
+    "tpu_compilation_cache_unloader";
+
+// Base class that holds references to compiled protos so that the protos are
+// not garbage-collected before being used by execute ops. Use
+// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
+// ref holder object.
+class TpuCompilationRefHolder : public ResourceBase {
+ public:
+  ~TpuCompilationRefHolder() override = default;
+};
+
+class TpuCompilationCacheInterface : public ResourceBase {
+ public:
+  using Status = ::stream_executor::port::Status;
+
+  // An entry in the compilation cache. The entry is deleted once it has been
+  // marked for eviction from the cache _and_ all steps that use it have
+  // completed. When the entry is first created, it is uninitialized and a
+  // client-supplied compilation function is run outside the cache's lock to
+  // generate the programs to be stored in the entry. Any other client that
+  // requests the entry will block until it has been initialized. Each entry has
+  // a last_use value that set from a monotonically-increasing counter in the
+  // cache whenever the entry is referenced. When the cache becomes full,
+  // entries are marked for eviction in LRU order.
+  //
+  // The bridge can request XLA to generate separate sharding and unsharding
+  // programs along with the main program; we use nested fields sharding_entry,
+  // unsharding_entry to store them under the main entry, and these two fields
+  // must be either both present or both absent. They have a back pointer
+  // main_entry to refer to the main program. These nested entries share the
+  // same cache key and the same lifetime as the main entry, so we use the
+  // refcount on the main entry to track the access to any of them.
+  //             /-------------------------------\
+  //             v                                \
+  //   main_entry (refcount) -> sharding_entry -> main_entry
+  //             ^           \
+  //             |            \-> unsharding_entry -> main_entry
+  //             \--------------------------------------/
+  struct CompilationEntry : public core::RefCounted  {
+    TpuCompilationCacheInterface* parent = nullptr;  // Not owned.
+    bool initialized = false;
+
+    // The Status returned by the compilation function when the entry is
+    // initialized. This status will be returned to any client that requests the
+    // entry.
+    Status initialization_status;
+
+    // The uid describing this entry.
+    int64 uid;
+    std::vector<string> proto_key;
+
+    // Counter to keep track of LRU entries for the eviction policy.
+    int64 last_use = -1;
+
+    // The unique key describing this entry.
+    std::string subgraph_key;
+
+    // Entries representing the associated sharding and unsharding programs,
+    // which share the same life time of the owning main entry, so we always use
+    // the main entry's ref count.
+    std::unique_ptr<CompilationEntry> sharding_entry;
+    std::unique_ptr<CompilationEntry> unsharding_entry;
+
+    // The number of 'external' client-held references to the entry.
+    int external_references = 0;
+
+    std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadata;
+
+    // The sum of the SpaceUsed of each of the elements of programs; an estimate
+    // of how much RAM the entry consumes, used to determine when entries must
+    // be marked for eviction.
+    int64 total_size = 0;
+
+    // Only used for the nested sharding/unsharding entries to point to the
+    // owning main entry.
+    CompilationEntry* main_entry = nullptr;
+
+    // Debug info in case we miss.
+    string cache_entry_debug_string;
+
+    // Compiled Tpu program.
+    std::unique_ptr<TpuProgram> tpu_program;
+  };
+
+  explicit TpuCompilationCacheInterface(int64_t max_cache_size);
+  ~TpuCompilationCacheInterface() override;
+  TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete;
+  TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&)
+      = delete;
+
+  Status CompileIfKeyAbsent(
+      const TpuCompilationCacheKey& cache_key,
+      const SessionMetadata* session_metadata,
+      TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
+      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+      std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
+      const std::function<tensorflow::Status(TpuProgram*)>& compile_function);
+
+  static TpuCompilationCacheKey CreateCompilationCacheKey(
+      absl::string_view function_name, uint64 function_library_fingerprint,
+      absl::string_view mlir_module,
+      const tensorflow::OpInputList& guaranteed_constants,
+      const std::vector<tensorflow::TensorShape>& dynamic_shapes,
+      const tensorflow::tpu::TPUCompileMetadataProto& metadata,
+      const TpuMeshStateInterface& mesh_state);
+
+  string DebugString() const override { return "TpuCompilationCacheInterface"; }
+
+  // Makes a reference holder for this cache, that can be stored in the per-step
+  // resource manager and will ensure that compiled entries persist until the
+  // end of a step.
+  TpuCompilationRefHolder* MakePerStepRefHolder();
+
+  // Differences between MarkEntryForEviction and Release:
+  // There are two modes of managing cache entries:
+  // 1) LRU eviction + pinning; 2) manual.
+  // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
+  // Otherwise it is manual mode (mainly used by XRT).
+  // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
+  // entries when callers know that they do not need them anymore.
+  // Release should only be used in mode 2) to explicitly remove an entry.
+
+  // Mark the entry indexed by `subgraph_uid` for eviction. This should only be
+  // called if per_step_ref_holder was NOT nullptr in the corresponding call to
+  // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
+  // subgraph_uid).
+  Status MarkEntryForEviction(int64 subgraph_uid);
+
+  // Manually discards a reference to the compiled subgraph. This should only be
+  // called if per_step_ref_holder was nullptr in the corresponding call to
+  // CompileIfKeyAbsent(subgraph_key, ...).
+  Status Release(int64 subgraph_uid);
+
+  // Looks up an executable corresponding to the model-parallel core index of
+  // the subgraph represented by key. On success a pointer to an EntryRef
+  // holding the program is returned in entry.
+  Status Lookup(const string& proto_key,
+                std::unique_ptr<CompilationCacheEntryRef>* entry);
+
+  // Looks up an executable corresponding to the model-parallel core index of
+  // the subgraph represented by uid. On success a pointer to an EntryRef
+  // holding the program is returned in entry.
+  Status Lookup(int64 uid, int proto_index,
+                std::unique_ptr<CompilationCacheEntryRef>* entry);
+
+  // Mutates the main entry ref to point to the entry's subentry
+  // (for sharding/unsharding) or main entry (unchanged) representing the
+  // fetch target. The entry ref needs to point to the main entry before this
+  // call.
+  //
+  // If the requested subentry does not exist, the ref will point to a nullptr
+  // entry.
+  Status ToSubEntryRef(CompilationCacheEntryRef* entry,
+                       CompilationCacheFetchTarget fetch_target) const;
+
+ private:
+  // Wrapper for a cache entry that holds a reference to the entry until the
+  // wrapper is deleted. This wrapper is the concrete type of
+  // CompilationCacheEntryRef returned by Lookup.
+  class EntryRefImpl : public CompilationCacheEntryRef {
+   public:
+    EntryRefImpl(TpuCompilationCacheInterface* parent, CompilationEntry* entry,
+                 int index);
+    ~EntryRefImpl() override;
+
+    CompilationCacheEntry get() override;
+
+    // Mutates this ref to point to the entry's subentry (for
+    // sharding/unsharding) or main entry (unchanged) as specified by
+    // fetch_target. The refcount is kept unchanged, since we only track the
+    // refcount of the main entry. The entry ref needs to point to the main
+    // entry before this call.
+    //
+    // If the requested subentry does not exist, the ref will point to a nullptr
+    // entry, and the original entry will be unref'ed.
+    Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);
+
+   private:
+    TpuCompilationCacheInterface* parent_;  // Not owned.
+    // A reference to entry_ is acquired in the constructor and released via
+    // parent->DiscardEntryRefs in the destructor.
+    CompilationEntry* entry_;
+    // The program in entry_ that is returned by the get method.
+    int index_;
+  };
+
+  // Private implementation of the generic CompilationRefHolder that knows about
+  // CompiledSubgraph entries.
+  class RefHolder : public TpuCompilationRefHolder {
+   public:
+    explicit RefHolder(TpuCompilationCacheInterface* parent) : parent_(parent) {
+      parent_->Ref();
+    }
+    ~RefHolder() override {
+      // Release our reference to the parent.
+      parent_->Unref();
+    }
+
+    // Adds entry to the list of entries that will be released when the
+    // RefHolder is destroyed. Each entry is released via a call to
+    // parent_->DiscardEntryRefs.
+    void AddRef(CompilationEntry* entry) {
+      entries_.push_back(entry);
+    }
+
+    string DebugString() const override {
+      return "TpuCompilationCacheInterface::RefHolder";
+    }
+
+   private:
+    TpuCompilationCacheInterface* parent_;  // Not owned.
+    std::vector<CompilationEntry*> entries_;
+  };
+
+  // The bulk of implementation of CompileIfKeyAbsent() with the exception
+  // of unloading programs that corresponds to possibly removed cache
+  // entries. The split helps to manage locking since we prefer to perform
+  // unloading without holding extra locks.
+  Status CompileIfKeyAbsentHelper(
+      const TpuCompilationCacheKey& subgraph_key,
+      const SessionMetadata* session_metadata,
+      TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
+      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+      std::vector<CompilationEntry*>* removed_entries,
+      std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
+      const std::function<Status(TpuProgram*)>& compile_function);
+
+  // This is called by the cache when entry is marked for eviction; by
+  // a RefHolder (via DiscardEntryRefs) when a step completes; and by
+  // an EntryRefImpl when it is destroyed. Releases one reference to entry
+  // if more than 1 remains. If only one reference is left, the entry is removed
+  // from cache_ and is returned to the caller; which must eventually call
+  // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
+  // to avoid holding the lock during program unloading.
+  ABSL_MUST_USE_RESULT CompilationEntry* DiscardEntryRef(
+      CompilationEntry* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  // Convenience method called by ~RefHolder without mu_ held. Calls
+  // DiscardEntryRef on every element of entries.
+  void DiscardEntryRefs(
+    gtl::ArraySlice<CompilationEntry*> entries);
+
+  // Marks the oldest unmarked entry for eviction. Requires that there is at
+  // least one such entry. In case the evicted entry had only 1 reference it
+  // is removed from the cache and returned to the caller which must eventually
+  // call UnloadAndDestroy.
+  CompilationEntry* MarkOldestEntryForEviction()
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Updates datastructures to indicate that entry, which had been marked for
+  // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
+  // entry is newly created, or an entry that has been marked for eviction but
+  // not yet evicted is looked up.
+  //
+  // First the entry is unmarked for eviction, i.e. the cache gains a reference
+  // to entry, entry's last_use field is set to be the most recent value of
+  // use_counter_ and entries_by_last_use_ is updated accordingly.
+  //
+  // Next, the size of the cache is examined to see if any other entries need to
+  // be marked for eviction now that entry has been unmarked. While the total
+  // size of unmarked cached entries is greater than max_cache_size_, entries
+  // are marked for eviction in LRU order. The most recently used entry is never
+  // marked for eviction, so an entry larger than the max cache size will remain
+  // in the cache until it is replaced by something else. In case some entries
+  // actually were removed from the cache, they are a returned to the caller via
+  // removed_entries. The caller must eventually delete them by calling
+  // UnloadAndDestroy.
+  void LookupEntryMarkedForEviction(
+    CompilationEntry* entry, std::vector<CompilationEntry*>* removed_entries)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Removes the entry with given key from cache.
+  size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Inserts the given key and entry to cache.
+  void InsertEntry(const std::string& key,
+                   const TpuCompilationCacheKey& subgraph_key,
+                   CompilationEntry* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Returns the cache key matching given subgraph_key.
+  std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Creates a new entry by running initialize_programs and places it in the
+  // cache to be looked up by key. The new entry is in the 'marked for eviction'
+  // state (not present in entries_by_last_use_) and the caller is expected to
+  // call LookupEntryMarkedForEviction after InitializeEntry.
+  //
+  // **InitializeEntry releases mu_ during the call to initialize_programs.**
+  CompilationEntry* InitializeEntry(
+    const string& key,
+    const std::function<Status(TpuProgram*)>& initialize_program,
+    const TpuCompilationCacheKey& subgraph_key)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Unloads the program associated with the entry from all local devices
+  // and deletes the entry itself. It is assumed no one else has a reference
+  // to it and all related keys had already been removed from the cache.
+  // The call can perform device IO so no locks should be held while calling it.
+  void UnloadAndDestroy(CompilationEntry* entry) ABSL_LOCKS_EXCLUDED(mu_);
+
+  // The maximum size of entries that are stored in the cache before entries are
+  // marked for eviction.
+  const int64 max_cache_size_;
+
+  mutable absl::Mutex mu_;
+  // The total size of entries that are stored and not marked for eviction.
+  int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
+
+  // The total size of entries that are marked for eviction.
+  int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
+
+  // The value to assign to the last_use field of the next entry that is looked
+  // up.
+  int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
+
+  // session_key_map_ and fingerprint_key_map_ are used for looking up the
+  // cache_ key matching a given subgraph key. When doing a lookup, check
+  // session_key_map_ first to avoid unnecessay fingerprint computation.
+  // Map from key prefix + session_handle to a cache_ key.
+  std::unordered_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
+
+  // Map from key prefix + fingerprint to a cache_ key.
+  std::unordered_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
+
+  // All the subgraph entries that can be looked up in the cache. An entry is
+  // marked for eviction iff it is present in cache_ and not in
+  // entries_by_last_use_.
+  std::unordered_map<string, CompilationEntry*> cache_store_
+      ABSL_GUARDED_BY(mu_);
+
+  // All the subgraph entries that can be looked up in the cache, indexed by
+  // uid.
+  absl::node_hash_map<int64, CompilationEntry*> entries_by_uid_
+      ABSL_GUARDED_BY(mu_);
+
+  // All the protos that can be looked up in the cache, indexed by proto
+  // key. The value of the map is a subgraph and the index of the proto compiled
+  // for that subgraph.
+  std::unordered_map<string, std::pair<CompilationEntry*, int>>
+      entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
+
+  // Map from last_use to entry, used to mark entries for eviction in LRU
+  // order. If an entry's last_use counter is not present as a key in
+  // entries_by_last_use_ then the entry has been marked for eviction.
+  std::map<int64, CompilationEntry*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_INTERFACE_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h
new file mode 100644
index 0000000..49c2eb6
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_
+
+#include <functional>
+#include <string>
+
+#include "absl/types/optional.h"
+
+namespace tensorflow {
+namespace tpu {
+
+struct TpuCompilationCacheKey {
+  // Prefix of the key.
+  std::string prefix;
+
+  // A boolean flag to specify if `guaranteed_const` is used. Guarantee const is
+  // normally used in TPU inference to avoid re-copying unchanged variables onto
+  // the TPU device. It promises the value is identical for every execution in
+  // the same session even if the actual value changes in later executions.
+  bool has_guaranteed_const = false;
+
+  // Unique session identifier. It is set when `has_guaranteed_const` is true.
+  std::string session_handle;
+
+  // Fingerprint of `guaranteed_const` value. It is set when the value of the
+  // `has_guaranteed_const` is true. Produce the value when necessary.
+  std::function<std::string()> guaranteed_const_fingerprint;
+
+  // A more verbose key for debugging purpose.
+  std::string debug_string;
+
+  explicit TpuCompilationCacheKey() {}
+  explicit TpuCompilationCacheKey(const std::string& p) : prefix(p) {}
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc
new file mode 100644
index 0000000..f4f8dbf
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc
@@ -0,0 +1,93 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
+
+#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
+
+namespace tensorflow {
+namespace tpu {
+
+namespace {
+class CompilationCacheFetchTargetUtility {
+ public:
+  CompilationCacheFetchTargetUtility()
+      : names_({"Invalid", "Main", "Sharding", "Unsharding"}) {}
+
+  std::string name(CompilationCacheFetchTarget target) const {
+    return names_[static_cast<int>(target)];
+  }
+
+ private:
+  const std::vector<std::string> names_;
+};
+
+std::string GetName(CompilationCacheFetchTarget target) {
+  static const auto* util = new CompilationCacheFetchTargetUtility();
+  return util->name(target);
+}
+
+}  // namespace
+
+TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
+    TpuCompilationCacheInterface* cache)
+    : cache_(cache) {}
+
+TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
+  cache_->Unref();
+}
+
+Status TpuCompilationCacheLocalLookup::Lookup(
+    const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
+    CompilationCacheFetchTarget fetch_target) {
+  profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
+                                         /*level=*/2);
+  Status s = cache_->Lookup(proto_key, entry);
+  VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
+          << s;
+  if (!s.ok()) {
+    return s;
+  }
+  s = cache_->ToSubEntryRef(entry->get(), fetch_target);
+
+  VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
+          << s;
+  return s;
+}
+
+Status TpuCompilationCacheLocalLookup::Lookup(
+    int64 uid, int proto_index,
+    std::unique_ptr<CompilationCacheEntryRef>* entry,
+    CompilationCacheFetchTarget fetch_target) {
+  profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
+                                         /*level=*/2);
+  Status s = cache_->Lookup(uid, proto_index, entry);
+  VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
+          << " in local subgraph cache status " << s;
+  if (!s.ok()) {
+    return s;
+  }
+  s = cache_->ToSubEntryRef(entry->get(), fetch_target);
+  VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
+          << s;
+  return s;
+}
+
+string TpuCompilationCacheLocalLookup::DebugString() const {
+  return "TpuCompilationCacheLocalLookup";
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
new file mode 100644
index 0000000..138777a
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
@@ -0,0 +1,99 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
+#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
+
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Base class allowing Execute Ops to look up ISA protos. Different subclasses
+// are used when the execute Op is in the same address space as the compile Op,
+// and when they need to communicate over RPC.
+class TpuCompilationCacheLookup : public ResourceBase {
+ public:
+  ~TpuCompilationCacheLookup() override = default;
+
+  // Looks up an executable corresponding to the model-parallel core index of
+  // the subgraph represented by key. On success a wrapper for the proto is
+  // returned in program. The wrapper is guaranteed to be valid only during the
+  // execution of the Op requesting the proto.
+  //
+  // Only one of the main, sharding, unsharding entries is fetched, as specified
+  // in fetch_target.
+  //
+  // If the compilation does not create sharding/unsharding programs, but the
+  // fetch_target requests one of them, then after this call
+  //   (*entry)->get().get_executable() will return nullptr.
+  virtual Status Lookup(const string& proto_key,
+                        std::unique_ptr<CompilationCacheEntryRef>* entry,
+                        CompilationCacheFetchTarget fetch_target) = 0;
+
+  virtual Status Lookup(const string& proto_key,
+                        std::unique_ptr<CompilationCacheEntryRef>* entry) {
+    return Lookup(proto_key, std::move(entry),
+                  CompilationCacheFetchTarget::MAIN);
+  }
+
+  // Looks up an executable corresponding to the model-parallel core index of
+  // the subgraph represented by uid. On success a wrapper for the proto is
+  // returned in program. The wrapper is guaranteed to be valid only during the
+  // execution of the Op requesting the proto.
+  virtual Status Lookup(int64 uid, int proto_index,
+                        std::unique_ptr<CompilationCacheEntryRef>* entry,
+                        CompilationCacheFetchTarget fetch_target) = 0;
+
+  virtual Status Lookup(int64 uid, int proto_index,
+                        std::unique_ptr<CompilationCacheEntryRef>* entry) {
+    return Lookup(uid, proto_index, std::move(entry),
+                  CompilationCacheFetchTarget::MAIN);
+  }
+};
+
+// Forward declaration to break cycle dependency graph.
+class TpuCompilationCacheInterface;
+
+// Class for looking up ISA protos when the execute and compile Op are in the
+// same address space. The proto is simply looked up in the compilation cache,
+// without any serialization taking place.
+class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
+ public:
+  explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
+  ~TpuCompilationCacheLocalLookup() override;
+
+  Status Lookup(const string& proto_key,
+                std::unique_ptr<CompilationCacheEntryRef>* entry,
+                CompilationCacheFetchTarget fetch_target) override;
+
+  Status Lookup(int64 uid, int proto_index,
+                std::unique_ptr<CompilationCacheEntryRef>* entry,
+                CompilationCacheFetchTarget fetch_target) override;
+
+  string DebugString() const override;
+
+ private:
+  // The subgraph compilation cache, in the same process address space where the
+  // lookups are happening.
+  TpuCompilationCacheInterface* cache_;
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.cc
new file mode 100644
index 0000000..ba4e2cc
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.cc
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
+
+namespace tensorflow {
+namespace tpu {
+
+/* static */
+void TpuCompilationCacheMetrics::IncrementCacheLookupCount(
+    bool is_cache_hit, absl::string_view session_name) {
+  // A placeholder for tracking metrics.
+}
+
+/* static */
+void TpuCompilationCacheMetrics::SetCacheEntryCount(int64 count) {
+  // A placeholder for tracking metrics.
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h
new file mode 100644
index 0000000..e30a7a4
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Tracks Tpu compilation cache metrics.
+class TpuCompilationCacheMetrics {
+ public:
+  // Increments the number of cache lookup count.
+  static void IncrementCacheLookupCount(bool is_cache_hit,
+                                        absl::string_view session_name);
+
+  // Sets the total count of cache entries.
+  static void SetCacheEntryCount(int64 count);
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compile.proto b/tensorflow/core/tpu/kernels/tpu_compile.proto
new file mode 100644
index 0000000..5b70de6
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile.proto
@@ -0,0 +1,144 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
+import "tensorflow/compiler/xla/service/hlo.proto";
+import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/protobuf/tpu/compile_metadata.proto";
+import "tensorflow/core/tpu/kernels/tpu_executable_info.proto";
+
+message PerCoreVariableIndices {
+  // For each resource variable output, what was the index of the corresponding
+  // input and was it updated? The indices are sorted by input order.
+  repeated TPUExecutableInfoProto.UpdateIndexPair variable_indices = 1;
+}
+
+message PerCoreArgShapes {
+  // Argument shapes for each Tpu core.
+  repeated xla.ShapeProto shapes = 1;
+}
+
+message PerCoreOutputShapes {
+  // Output shapes for each Tpu core.
+  repeated xla.ShapeProto shapes = 1;
+}
+
+message OutputDescriptionProto {
+  // Type and shape of the output. The shape is the unflattened shape.
+  // When `type` is DT_RESOURCE, `shape` is the shape of the resource
+  // variable's value.
+  tensorflow.DataType type = 1;
+  tensorflow.TensorShapeProto shape = 2;
+
+  // Constant output value, if known to be constant at JIT compilation time.
+  // 'Tensor' is in host memory.
+  bool is_constant = 3;
+  tensorflow.TensorProto constant_value = 4;
+
+  // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
+  // the index of the input that contains the resource.
+  int32 input_index = 5;
+
+  // Whether this output is a TensorList.
+  bool is_tensor_list = 6;
+}
+
+// Describes a variable write side effect of the computation.
+message ResourceUpdateProto {
+  // Index of the input that contains the variable resource to write to.
+  int32 input_index = 1;
+
+  // Type and shape of the tensor to be written back.
+  // The `shape` field has the same meaning as the Argument::shape field.
+  tensorflow.DataType type = 2;
+  tensorflow.TensorShapeProto shape = 3;
+
+  // Was the value of the variable modified by the computation?
+  // (Always true, unless `return_updated_values_for_all_resources` is true.)
+  bool modified = 4;
+
+  // If the resource is a TensorArray, the set of gradients read or written.
+  map<string, bool> tensor_array_gradients_accessed = 5;
+}
+
+// Describes the result of a XLA Compiler compilation.
+message XlaCompilationResultProto {
+  // Vector that maps from the parameters of the XLA computation to their
+  // original argument positions. To handle compile-time constant inputs, the
+  // parameters to the XLA computation may be a subset of the original
+  // arguments. The relative ordering of parameters are maintained.
+  repeated int32 input_mappings = 1;
+
+  // Input shapes of the computation. If we are flattening inputs, these are
+  // the flattened shapes.
+  repeated xla.ShapeProto xla_input_shapes = 2;
+
+  // Output shape in XLA format. The output shape is always a tuple. If we
+  // are flattening outputs, these are the flattened shapes.
+  xla.ShapeProto xla_output_shape = 3;
+
+  // TensorFlow shapes of outputs, together with the values of any
+  // constant arguments. Vector indexed by Tensorflow _Retval number,
+  // containing both constant and non-constant results.
+  repeated OutputDescriptionProto outputs = 4;
+
+  // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
+  // matching RecvAtHost/SendFromHost Ops in the outer graph.
+  tf2xla.HostComputeMetadata host_compute_metadata = 5;
+
+  // Resources whose values were updated by the computation, ordered
+  // by return value position (which is the same as the order the resources
+  // were passed as arguments). Resource updates follow the non-constant
+  // results in the outputs of XLA computation.
+  repeated ResourceUpdateProto resource_updates = 6;
+
+  // The XLA computation built from the tensorflow subgraph.
+  xla.HloModuleProto computation = 7;
+}
+
+// TpuAotCompilationRequestProto represents a compilation request for performing
+// ahead-of-time (AOT) compilation of XLA Computations into XLA HLO IR.
+message TpuAotCompilationRequestProto {
+  // A set of HLO module built to run concurrently
+  // across different devices.
+  xla.HloModuleGroupProto hlo_module_group = 1;
+
+  // Compilation metadata.
+  TPUCompileMetadataProto metadata = 2;
+
+  // DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
+  // represents the device ids assigned to a set of replicated computations.
+  // See xla::DeviceAssignment class comment for more details.
+  xla.DeviceAssignmentProto device_assignment = 3;
+
+  // Per TPU core program arguments shapes.
+  repeated PerCoreArgShapes per_core_arg_shapes = 4;
+
+  // Per TPU core program outputs shapes.
+  repeated PerCoreOutputShapes per_core_output_shapes = 5;
+
+  // Per TPU core information containing what was the index of the corresponding
+  // input and if whether it was updated. The indices are sorted by input order.
+  repeated PerCoreVariableIndices per_core_variable_indices = 6;
+
+  // XLA compiler compilation result.
+  XlaCompilationResultProto compilation_result = 7;
+}
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_c_api.h b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h
new file mode 100644
index 0000000..53e79aa
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h
@@ -0,0 +1,119 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
+
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+enum TpuCoreTypeEnum {
+  kTensorCore,
+  kEmbeddingV1,
+  kEmbeddingV2,
+};
+
+typedef struct XLA_TpuProgram XLA_TpuProgram;
+
+// Property for creating compilation cache key.
+struct CompilationCacheKeyProperty {
+  const char* config_prefix;
+  const char* shapes_prefix;
+  const char* function_name;
+  const char* mlir_module;
+  const int32_t* device_ids;
+  size_t device_ids_size;
+  int32_t guaranteed_constants_size;
+  uint64_t function_library_fingerprint;
+  int32_t num_cores_per_replica;
+  int32_t num_replicas;
+  const XLA_TpuMeshState* mesh_state;
+};
+
+extern "C" {
+
+// Creates a new TPU program.
+XLA_TpuProgram* TpuProgram_New();
+
+// Destroys the `tpu_program`.
+void TpuProgram_Free(XLA_TpuProgram* tpu_program);
+
+
+// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
+// destroyed, it is in an unusable state.
+void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
+                                 SE_Status* status);
+
+// Gets TPU program size in bytes from the `tpu_program`.
+int64_t TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);
+
+// Logs the summary of current memory state snapshot of the `tpu_program`.
+bool TpuProgram_LogProgramMemorySummary(const XLA_TpuProgram* tpu_program);
+
+// Gets TPU program executable info from the `tpu_program`.
+void TpuProgram_GetExecutableInfo(const XLA_TpuProgram* tpu_program,
+                                  TpuSerializedProto* executable_info);
+
+// Gets host transfer info proto.
+void TpuProgram_GetHostTransferInfo(
+    const XLA_TpuProgram* tpu_program,
+    TpuSerializedProto* host_transfer_info);
+
+// Gets HLO metadata proto.
+void TpuProgram_GetHloMetadata(const XLA_TpuProgram* tpu_program,
+                               TpuSerializedProto* hlo_metadata);
+
+// Returns the number of available TPU core count.
+int TpuTopology_AvailableCoreCount(const XLA_TpuMeshState* mesh_state,
+                                   TpuCoreTypeEnum tpu_core_type);
+
+// Creates a unique compilation cache `key` used for `put` and `get` operations.
+// Returned buffer is heap-allocated and must be owned.
+const char* TpuCompile_CreateCompilationCacheKey(
+    CompilationCacheKeyProperty property);
+
+// Creates a guaranteed const fingerprint. Guarantee const is normally used in
+// TPU inference to avoid re-copying unchanged variables onto the TPU device.
+// It promises the value is identical for every execution in the same session
+// even if the actual value changes in later executions.
+uint64_t TpuCompile_CreateGuaranteedConstFingerprint(uint64_t fingerprint,
+                                                     const char* data,
+                                                     size_t size);
+
+// Checks if whether a TPU compilation is enabled.
+bool TpuCompile_IsTpuCompilationEnabled();
+
+// Executes the computations using XLA TPU compiler and returns TPU programs
+// ready for execution.
+void TpuCompile_CompileAheadOfTime(
+    TpuSerializedProto aot_compilation_request,
+    XLA_TpuProgram** tpu_programs[],
+    size_t* count, SE_Status* status);
+
+// Builds `DeviceAssignment` from `TpuCompileMetadata` serialized proto.
+void TpuCompile_BuildXLADeviceAssignment(
+    TpuSerializedProto serialized_tpu_compile_metadata,
+    const XLA_TpuMeshState* mesh_state,
+    TpuSerializedProto* serialized_device_assignment, SE_Status* status);
+
+// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
+void TpuCompile_ToTpuShapeRepresentation(
+    TpuSerializedProto serialized_xla_shape, int data_type,
+    bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape,
+    SE_Status* status);
+
+}  // extern "C"
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_options.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_options.cc
new file mode 100644
index 0000000..49a2a08
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_options.cc
@@ -0,0 +1,42 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
+
+namespace tensorflow {
+namespace internal {
+
+namespace {
+static bool tpu_compilation_cancellation_terminates_process = true;
+static bool tpu_compilation_failure_closes_chips = true;
+}  // namespace
+
+void SetTpuCompilationCancellationTerminatesProcess(bool b) {
+  tpu_compilation_cancellation_terminates_process = b;
+}
+
+bool TpuCompilationCancellationTerminatesProcess() {
+  return tpu_compilation_cancellation_terminates_process;
+}
+
+void SetTpuCompilationFailureClosesChips(bool value) {
+  tpu_compilation_failure_closes_chips = value;
+}
+
+bool TpuCompilationFailureClosesChips() {
+  return tpu_compilation_failure_closes_chips;
+}
+
+}  // namespace internal
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_options.h b/tensorflow/core/tpu/kernels/tpu_compile_op_options.h
new file mode 100644
index 0000000..b81fe4a
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_options.h
@@ -0,0 +1,42 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_
+
+#include <string>
+
+namespace tensorflow {
+namespace internal {
+
+// Setter and getter that determine how TPUCompile responds to cancelled
+// compilation.  By default this is true, meaning cancelled compilation will
+// abort the process, since that's the only mechanism we have available.
+//
+// Setting this to false allows the process to remain alive, and should only be
+// used in tests.
+void SetTpuCompilationCancellationTerminatesProcess(bool b);
+bool TpuCompilationCancellationTerminatesProcess();
+
+// Setter and getter that determine whether TPU compilation failure will cause
+// chips to close. By default this is true, it is suitable for training. For
+// inference, we never want servers to die and thus chips will keep alive.
+// See b/109873767.
+void SetTpuCompilationFailureClosesChips(bool value);
+bool TpuCompilationFailureClosesChips();
+
+}  // namespace internal
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc
new file mode 100644
index 0000000..d42c604
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc
@@ -0,0 +1,439 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+
+#include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/dump.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+
+namespace tensorflow {
+namespace tpu {
+
+using stream_executor::port::Status;
+using stream_executor::port::StatusOr;
+using xla::ComputationLayout;
+using xla::DebugOptions;
+using xla::DeviceAssignment;
+using xla::HloModuleConfig;
+using xla::HloSharding;
+using xla::InvalidArgument;
+using xla::ProgramShape;
+using xla::Shape;
+using xla::ShapeTree;
+using xla::ShapeUtil;
+
+Status ValidateResultShape(const Shape& client_shape,
+                           const Shape& result_shape) {
+  TF_RETURN_IF_ERROR(
+      xla::ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
+  if (!xla::ShapeUtil::Compatible(client_shape, result_shape)) {
+    return InvalidArgument(
+        "Shape used to set computation result layout %s is not compatible "
+        "with result shape %s",
+        xla::ShapeUtil::HumanStringWithLayout(client_shape),
+        xla::ShapeUtil::HumanString(result_shape));
+  }
+  return Status::OK();
+}
+
+StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
+    const ProgramShape& program_shape, absl::Span<const Shape> argument_shapes,
+    absl::optional<const Shape> result_layout,
+    absl::optional<const DeviceAssignment> device_assignment, int replica_count,
+    int num_partitions, const DebugOptions* debug_options, const int* seed,
+    const int* launch_id, const bool* alias_passthrough_params,
+    const xla::FusionConfigCollection* fusion_config_collection,
+    const std::vector<std::vector<bool>>* fusion_config) {
+  auto config = absl::make_unique<HloModuleConfig>(program_shape);
+  ComputationLayout* computation_layout =
+      config->mutable_entry_computation_layout();
+  if (program_shape.parameters_size() != argument_shapes.size()) {
+    return InvalidArgument("computation takes %d parameters, but %u given",
+                           program_shape.parameters_size(),
+                           argument_shapes.size());
+  }
+  for (int i = 0; i < argument_shapes.size(); ++i) {
+    // Verify that shape of arguments matches the shape of the arguments in the
+    // ProgramShape.
+    if (!ShapeUtil::Compatible(argument_shapes[i],
+                               program_shape.parameters(i))) {
+      return InvalidArgument(
+          "Argument does not match shape of computation parameter %d: want "
+          "%s, got %s",
+          i, ShapeUtil::HumanString(program_shape.parameters(i)),
+          ShapeUtil::HumanString(argument_shapes[i]));
+    }
+    TF_RETURN_IF_ERROR(
+        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+            argument_shapes[i]));
+  }
+
+  if (result_layout.has_value()) {
+    TF_RETURN_IF_ERROR(
+        ValidateResultShape(result_layout.value(), program_shape.result()));
+    TF_RETURN_IF_ERROR(
+        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+            result_layout.value()));
+  } else {
+    // If the result layout is not set, then choose the default.
+    computation_layout->mutable_result_layout()->SetToDefaultLayout();
+  }
+
+  config->set_replica_count(replica_count);
+  config->set_num_partitions(num_partitions);
+  if (seed != nullptr) {
+    config->set_seed(*seed);
+  }
+  if (launch_id != nullptr) {
+    config->set_launch_id(*launch_id);
+  }
+  if (debug_options != nullptr) {
+    config->set_debug_options(*debug_options);
+  } else {
+    config->set_debug_options(xla::GetDebugOptionsFromFlags());
+  }
+
+  // TODO(henrytan): set intra_op_parallelism_threads.
+  // Reference:
+  // tensorflow/compiler/xla/service/service.cc?l=324.
+
+  if (device_assignment.has_value()) {
+    config->set_static_device_assignment(device_assignment.value());
+  }
+
+  if (alias_passthrough_params != nullptr) {
+    config->set_alias_passthrough_params(*alias_passthrough_params);
+  }
+
+  if (fusion_config_collection != nullptr && fusion_config != nullptr &&
+      *fusion_config_collection != xla::FusionConfigCollection::kOff) {
+    config->set_fusion_config_collection(*fusion_config_collection);
+    *config->mutable_fusion_config() = *fusion_config;
+  }
+
+  return std::move(config);
+}
+
+StatusOr<std::unique_ptr<xla::HloModuleConfig>> CreateModuleConfig(
+    const xla::ProgramShape& program_shape,
+    absl::Span<const Shape> argument_shapes,
+    absl::optional<const Shape> result_layout,
+    absl::optional<const DeviceAssignment> device_assignment, int replica_count,
+    int num_partitions, const DebugOptions* debug_options) {
+  return CreateModuleConfig(program_shape, argument_shapes, result_layout,
+                            device_assignment, replica_count, num_partitions,
+                            debug_options, /*seed=*/nullptr,
+                            /*launch_id=*/nullptr,
+                            /*alias_passthrough_params=*/nullptr,
+                            /*fusion_config_collection=*/nullptr,
+                            /*fusion_config=*/nullptr);
+}
+
+ShapeTree<HloSharding> GetSubtree(
+    const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) {
+  ShapeTree<HloSharding> element_shape_tree(
+      xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(),
+                                           element_index),
+      HloSharding::Replicate());
+
+  xla::ShapeIndex src_index;
+  src_index.push_back(element_index);
+  element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {});
+  return element_shape_tree;
+}
+
+Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding,
+                        int64 device) {
+  if (shape.IsTuple()) {
+    ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape);
+    std::vector<Shape> arg_shapes;
+    for (int64 i = 0; i < xla::ShapeUtil::TupleElementCount(shape); ++i) {
+      Shape element_shape = xla::ShapeUtil::GetTupleElementShape(shape, i);
+      HloSharding element_sharding = tuple_shape_tree.element({i});
+      if (element_shape.IsTuple()) {
+        element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i));
+      }
+      if (element_sharding.UsesDevice(device)) {
+        arg_shapes.push_back(
+            GetPerDeviceShape(element_shape, element_sharding, device));
+      }
+    }
+    return xla::ShapeUtil::MakeTupleShape(arg_shapes);
+  }
+
+  if (sharding.IsTileMaximal()) {
+    return shape;
+  }
+
+  std::vector<int64> dimensions;
+  std::vector<int64> offset = sharding.TileOffsetForDevice(shape, device);
+  std::vector<int64> limit = sharding.TileLimitForDevice(shape, device);
+  for (int64 i = 0; i < limit.size(); ++i) {
+    dimensions.push_back(limit[i] - offset[i]);
+  }
+  if (shape.has_layout()) {
+    return xla::ShapeUtil::MakeShapeWithLayout(shape.element_type(), dimensions,
+                                               shape.layout().minor_to_major());
+  }
+  return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions);
+}
+
+Status AddVariableUpdatesToCores(
+    const TPUCompileMetadataProto& metadata,
+    const XlaCompiler::CompilationResult& compilation_result,
+    const std::vector<ShardingAndIndex>& arg_core_mapping,
+    std::vector<bool>* may_modify_variables,
+    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
+    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices) {
+  // Add all variables to the corresponding core.
+  may_modify_variables->resize(metadata.num_cores_per_replica(), false);
+  int resource_update_pos = 0;
+  for (int i = 0; i < metadata.args_size(); ++i) {
+    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
+    if (proto_arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
+      const auto& sharding = proto_arg.sharding();
+      bool updated = false;
+      if (resource_update_pos < compilation_result.resource_updates.size()) {
+        const XlaCompiler::ResourceUpdate& update =
+            compilation_result.resource_updates[resource_update_pos];
+        if (update.input_index == i) {
+          updated = true;
+          int pos = compilation_result.outputs.size() + resource_update_pos;
+          xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
+              compilation_result.xla_output_shape, pos);
+          auto add_to_core = [&](int64 core, const xla::Shape& per_core_shape) {
+            (*per_core_output_shapes)[core].push_back(per_core_shape);
+            (*may_modify_variables)[core] =
+                (*may_modify_variables)[core] || update.modified;
+          };
+          if (sharding.type() == xla::OpSharding::MAXIMAL) {
+            add_to_core(sharding.tile_assignment_devices(0), shape);
+          } else if (sharding.type() == xla::OpSharding::OTHER) {
+            auto sharding_or =
+                xla::HloSharding::FromProto(proto_arg.sharding());
+            TF_RET_CHECK(sharding_or.ok());
+            for (int64 core : proto_arg.sharding().tile_assignment_devices()) {
+              xla::Shape per_core_shape =
+                  GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
+              add_to_core(core, per_core_shape);
+            }
+          } else {
+            TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
+            for (int64 core = 0; core < metadata.num_cores_per_replica();
+                 ++core) {
+              add_to_core(core, shape);
+            }
+          }
+          ++resource_update_pos;
+        }
+      }
+      if (sharding.type() == xla::OpSharding::MAXIMAL) {
+        (*per_core_variable_indices)[sharding.tile_assignment_devices(0)]
+            .push_back(
+                std::pair<int, bool>(arg_core_mapping[i].indices[0], updated));
+      } else if (sharding.type() == xla::OpSharding::OTHER) {
+        for (int core : sharding.tile_assignment_devices()) {
+          (*per_core_variable_indices)[core].push_back(
+              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
+        }
+      } else {
+        TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
+        for (int64 core = 0; core < metadata.num_cores_per_replica(); ++core) {
+          (*per_core_variable_indices)[core].push_back(
+              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
+        }
+      }
+    }
+  }
+  return Status::OK();
+}
+
+Status ComputeOutputShapesForEachCore(
+    const tpu::TPUCompileMetadataProto& metadata,
+    const XlaCompiler::CompilationResult& compilation_result,
+    std::vector<std::vector<xla::Shape>>* per_core_output_shapes) {
+  for (int i = 0; i < metadata.retvals_size(); ++i) {
+    const tpu::TPUCompileMetadataProto::Retval& retval = metadata.retvals(i);
+    TF_RET_CHECK(!compilation_result.outputs[i].is_constant)
+        << "TPU compilation output " << i
+        << " has a compile-time constant value. "
+           "This should never happen.";
+
+    xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
+        compilation_result.xla_output_shape, i);
+    auto add_shape_to_core = [&](int core, xla::Shape per_core_shape) {
+      (*per_core_output_shapes)[core].push_back(std::move(per_core_shape));
+    };
+    if (retval.sharding().type() == xla::OpSharding::MAXIMAL) {
+      add_shape_to_core(retval.sharding().tile_assignment_devices(0),
+                        std::move(shape));
+    } else if (retval.sharding().type() == xla::OpSharding::OTHER) {
+      auto sharding_or = xla::HloSharding::FromProto(retval.sharding());
+      TF_RET_CHECK(sharding_or.ok());
+      for (int64 core : retval.sharding().tile_assignment_devices()) {
+        xla::Shape per_core_shape =
+            GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
+        add_shape_to_core(core, std::move(per_core_shape));
+      }
+    } else {
+      TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED)
+          << "Not all of the constant tensors were consumed.";
+      for (int core = 0; core < per_core_output_shapes->size(); ++core) {
+        add_shape_to_core(core, shape);
+      }
+    }
+  }
+  return Status::OK();
+}
+
+Status CreateHloModules(
+    const TPUCompileMetadataProto& metadata,
+    const tensorflow::XlaCompiler::CompilationResult& compilation_result,
+    const absl::optional<xla::DeviceAssignment>& device_assignment,
+    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules) {
+  TF_RET_CHECK(
+      compilation_result.computation->proto().has_host_program_shape());
+
+  auto debug_options = xla::DebugOptions();
+  debug_options.set_xla_step_marker_location(metadata.step_marker_location());
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<xla::HloModuleConfig> module_config,
+      CreateModuleConfig(
+          xla::ProgramShape(
+              compilation_result.computation->proto().host_program_shape()),
+          compilation_result.xla_input_shapes,
+          compilation_result.xla_output_shape, device_assignment,
+          metadata.num_replicas(), metadata.num_cores_per_replica(),
+          &debug_options));
+
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<xla::HloModule> hlo_module,
+      xla::HloModule::CreateFromProto(compilation_result.computation->proto(),
+                                      *module_config));
+  DumpHloModuleIfEnabled(*hlo_module, "before_optimizations");
+  hlo_modules->push_back(std::move(hlo_module));
+
+  return Status::OK();
+}
+
+XlaCompilationResultProto SerializeCompilationResult(
+    const XlaCompiler::CompilationResult& compilation_result) {
+  XlaCompilationResultProto compilation_result_proto;
+  for (int input_mapping : compilation_result.input_mapping) {
+    compilation_result_proto.add_input_mappings(input_mapping);
+  }
+
+  for (const Shape& input_shape : compilation_result.xla_input_shapes) {
+    *(compilation_result_proto.add_xla_input_shapes()) = input_shape.ToProto();
+  }
+  *(compilation_result_proto.mutable_xla_output_shape()) =
+      compilation_result.xla_output_shape.ToProto();
+
+  for (const XlaCompiler::OutputDescription& output_description :
+       compilation_result.outputs) {
+    auto* new_output = compilation_result_proto.add_outputs();
+    new_output->set_type(output_description.type);
+    output_description.shape.AsProto(new_output->mutable_shape());
+    new_output->set_is_constant(output_description.is_constant);
+    output_description.constant_value.AsProtoField(
+        new_output->mutable_constant_value());
+    new_output->set_input_index(output_description.input_index);
+    new_output->set_is_tensor_list(output_description.is_tensor_list);
+  }
+
+  *compilation_result_proto.mutable_host_compute_metadata() =
+      compilation_result.host_compute_metadata;
+
+  for (const XlaCompiler::ResourceUpdate& resource_update :
+       compilation_result.resource_updates) {
+    auto* new_resource_update = compilation_result_proto.add_resource_updates();
+    new_resource_update->set_input_index(resource_update.input_index);
+    new_resource_update->set_type(resource_update.type);
+    resource_update.shape.AsProto(new_resource_update->mutable_shape());
+    new_resource_update->set_modified(resource_update.modified);
+    for (const std::string& gradient_access :
+         resource_update.tensor_array_gradients_accessed) {
+      new_resource_update->mutable_tensor_array_gradients_accessed()->insert(
+          {gradient_access, true});
+    }
+  }
+
+  if (compilation_result.computation != nullptr) {
+    *compilation_result_proto.mutable_computation() =
+        compilation_result.computation->proto();
+  }
+
+  return compilation_result_proto;
+}
+
+StatusOr<TpuAotCompilationRequestProto> CreateTpuAotCompilationRequest(
+    const xla::HloModuleGroup& module_group,
+    const XlaCompiler::CompilationResult& compilation_result,
+    const TPUCompileMetadataProto& metadata,
+    const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+    const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
+    const std::vector<std::vector<std::pair<int, bool>>>&
+        per_core_variable_indices,
+    const absl::optional<xla::DeviceAssignment>& device_assignment) {
+  VLOG(1) << "CreateTpuAotCompilationRequest.";
+  TpuAotCompilationRequestProto aot_request;
+  *(aot_request.mutable_hlo_module_group()) = module_group.ToProto();
+  *(aot_request.mutable_metadata()) = metadata;
+  if (device_assignment.has_value()) {
+    xla::DeviceAssignmentProto device_assignment_proto;
+    Status status = device_assignment->Serialize(&device_assignment_proto);
+    if (!status.ok()) {
+      return status;
+    }
+    *(aot_request.mutable_device_assignment()) = device_assignment_proto;
+  }
+
+  for (const auto& arg_shapes : per_core_arg_shapes) {
+    auto* new_shape_list = aot_request.add_per_core_arg_shapes();
+    for (const auto& arg_shape : arg_shapes) {
+      *new_shape_list->add_shapes() = arg_shape.ToProto();
+    }
+  }
+
+  for (const auto& output_shapes : per_core_output_shapes) {
+    auto* new_shape_list = aot_request.add_per_core_output_shapes();
+    for (const auto& output_shape : output_shapes) {
+      *new_shape_list->add_shapes() = output_shape.ToProto();
+    }
+  }
+
+  for (const auto& variable_indices : per_core_variable_indices) {
+    auto* new_list = aot_request.add_per_core_variable_indices();
+    for (const auto& variable_index : variable_indices) {
+      auto* core_index = new_list->add_variable_indices();
+      core_index->set_index(variable_index.first);
+      core_index->set_updated(variable_index.second);
+    }
+  }
+
+  XlaCompilationResultProto compilation_result_proto =
+      SerializeCompilationResult(compilation_result);
+  *aot_request.mutable_compilation_result() = compilation_result_proto;
+
+  VLOG(1) << "TpuAotCompilationRequest:\n" << aot_request.DebugString();
+  return aot_request;
+}
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h
new file mode 100644
index 0000000..0f21e45
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h
@@ -0,0 +1,122 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
+
+namespace tensorflow {
+namespace tpu {
+
+namespace se = ::stream_executor;
+
+// Describes the position of an argument or return value after the computation
+// has been partitioned into cores.
+struct ShardingAndIndex {
+  // Sharding across cores.
+  ::xla::OpSharding sharding;
+  // Argument/return value number. If sharding is single-core, `indices` has a
+  // single element; otherwise, it has num_cores elements.
+  std::vector<int> indices;
+};
+
+// TODO(b/158279168): Dedup with internal version.
+// Return the per-device shape for a `shape` with a given `sharding`.
+xla::Shape GetPerDeviceShape(const xla::Shape& shape,
+                             const xla::HloSharding& sharding,
+                             int64 device);
+
+stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
+CreateModuleConfig(
+    const xla::ProgramShape& program_shape,
+    absl::Span<const xla::Shape> argument_shapes,
+    absl::optional<const xla::Shape> result_layout,
+    absl::optional<const xla::DeviceAssignment> device_assignment,
+    int replica_count, int num_partitions,
+    const xla::DebugOptions* debug_options, const int* seed,
+    const int* launch_id, const bool* alias_passthrough_params,
+    const xla::FusionConfigCollection* fusion_config_collection,
+    const std::vector<std::vector<bool>>* fusion_config);
+
+stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
+CreateModuleConfig(
+    const xla::ProgramShape& program_shape,
+    absl::Span<const xla::Shape> argument_shapes,
+    absl::optional<const xla::Shape> result_layout,
+    absl::optional<const xla::DeviceAssignment> device_assignment,
+    int replica_count,
+    int num_partitions, const xla::DebugOptions* debug_options);
+
+xla::ShapeTree<xla::HloSharding> GetSubtree(
+    const xla::ShapeTree<xla::HloSharding>& tuple_shape_tree,
+    int element_index);
+
+xla::Shape GetPerDeviceShape(const xla::Shape& shape,
+                             const xla::HloSharding& sharding,
+                             int64 device);
+
+Status AddVariableUpdatesToCores(
+    const TPUCompileMetadataProto& metadata,
+    const XlaCompiler::CompilationResult& compilation_result,
+    const std::vector<ShardingAndIndex>& arg_core_mapping,
+    std::vector<bool>* may_modify_variables,
+    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
+    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices);
+
+se::port::Status ComputeOutputShapesForEachCore(
+    const tpu::TPUCompileMetadataProto& metadata,
+    const XlaCompiler::CompilationResult& compilation_result,
+    std::vector<std::vector<xla::Shape>>* per_core_output_shapes);
+
+se::port::Status CreateHloModules(
+    const TPUCompileMetadataProto& metadata,
+    const XlaCompiler::CompilationResult& compilation_result,
+    const absl::optional<xla::DeviceAssignment>& device_assignment,
+    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules);
+
+se::port::StatusOr<TpuAotCompilationRequestProto>
+CreateTpuAotCompilationRequest(
+    const xla::HloModuleGroup& module_group,
+    const XlaCompiler::CompilationResult& compilation_result,
+    const TPUCompileMetadataProto& metadata,
+    const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+    const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
+    const std::vector<std::vector<std::pair<int, bool>>>&
+        per_core_variable_indices,
+    const absl::optional<xla::DeviceAssignment>& device_assignment);
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
new file mode 100644
index 0000000..7fa345d
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
@@ -0,0 +1,298 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"
+
+#include <cstdint>
+
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/refcount.h"
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
+#include "tensorflow/core/tpu/tpu_config_c_api.h"
+#include "tensorflow/core/tpu/tpu_configuration.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+
+namespace tensorflow {
+namespace {
+
+Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
+                                tpu::TpuMeshStateInterface** state) {
+  if (!rmgr->Lookup(rmgr->default_container(),
+                    tpu::kTpuMeshCommonStateResourceName, state)
+           .ok()) {
+    return errors::FailedPrecondition(
+        "The TPU system has not been initialized.");
+  }
+  return Status::OK();
+}
+
+// Attempt to delete resource_name from resource_manager's default_container.
+// Returns OK if the deletion succeeded, or if the resource was not found. Else
+// return the deletion error.
+template <class ResourceT>
+Status DeleteIfExists(ResourceMgr* resource_manager,
+                      const char* resource_name) {
+  VLOG(1) << "Removing resource " << resource_name << " if it exists";
+  Status status = resource_manager->Delete<ResourceT>(
+      resource_manager->default_container(), resource_name);
+  if (status.ok()) {
+    VLOG(1) << "Removed existing resource " << resource_name;
+    return Status::OK();
+  }
+  if (status.code() == error::NOT_FOUND) {
+    VLOG(1) << "No resource " << resource_name << " to remove";
+    return Status::OK();
+  }
+  VLOG(1) << "Error removing resource " << resource_name << " : " << status;
+  return status;
+}
+
+}  // namespace
+
+void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "ConfigureDistributedTpuOp";
+  XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
+
+  std::vector<int32_t> num_devices_per_host;
+  int chips_per_host = -1;
+  for (int i = 0; i < ctx->num_inputs(); ++i) {
+    const Tensor& input_tensor = ctx->input(i);
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsScalar(input_tensor.shape()),
+        errors::InvalidArgument("Input ", i, " should be a scalar but has ",
+                                input_tensor.dims(), " dimensions"));
+    if (chips_per_host == -1) {
+      chips_per_host = input_tensor.scalar<int32_t>()();
+    } else {
+      OP_REQUIRES(
+          ctx, chips_per_host == input_tensor.scalar<int32>()(),
+          errors::Internal("Host ", i, " has ", input_tensor.scalar<int32>()(),
+                           " TPU chips but host 0 has ", chips_per_host));
+    }
+    num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
+  }
+
+  TF_Status* status = TF_NewStatus();
+  size_t host_config_output_size;
+  char* host_config_output;
+
+  auto* rmgr = GetTPUConfigResourceMgr();
+  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
+                          rmgr, tpu::kTpuMeshCommonStateResourceName));
+
+  ConfigureDistributedTpuOp_DoWork(
+      num_devices_per_host.size(), num_devices_per_host.data(),
+      &host_config_output_size, &host_config_output, status);
+
+  OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
+                                   tpu::kTpuMeshCommonStateResourceName,
+                                   tpu::TpuMeshStateInterface::Create()));
+
+  Tensor* ctx_output;
+  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
+  ctx_output->scalar<tstring>()() =
+      std::string(host_config_output, host_config_output_size);
+
+  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+  TF_DeleteStatus(status);
+  TpuConfigurationApi_FreeCharArray(host_config_output);
+
+  VLOG(1) << "ConfigureDistributedTpuOp done";
+}
+
+void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "WaitForDistributedTpuOp";
+  XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp");
+
+  size_t num_devices_per_host = -1;
+  size_t num_hosts = ctx->num_inputs();
+
+  for (int i = 0; i < ctx->num_inputs(); ++i) {
+    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
+    OP_REQUIRES(
+        ctx, host_ordinal_to_global_device_id_tensor.dims() == 1,
+        errors::InvalidArgument("Input ", i, " should be a vector but has ",
+                                host_ordinal_to_global_device_id_tensor.dims(),
+                                " dimensions"));
+  }
+
+  std::vector<std::vector<int32_t>> mapping;
+  std::vector<int32_t*> mapping_arg;
+
+  mapping.resize(ctx->num_inputs());
+
+  for (int i = 0; i < ctx->num_inputs(); ++i) {
+    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
+    const auto host_ordinal_to_global_device_id =
+        host_ordinal_to_global_device_id_tensor.flat<int>();
+    if (num_devices_per_host == -1) {
+      num_devices_per_host =
+          host_ordinal_to_global_device_id_tensor.dim_size(0);
+    } else {
+      OP_REQUIRES(ctx,
+                  num_devices_per_host ==
+                      host_ordinal_to_global_device_id_tensor.dim_size(0),
+                  errors::Internal(
+                      "Host ", i, " has ",
+                      host_ordinal_to_global_device_id_tensor.dim_size(0),
+                      " TPU devices but host 0 has ", num_devices_per_host));
+    }
+    for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0);
+         ++j) {
+      int32_t global_device_id = host_ordinal_to_global_device_id(j);
+      mapping[i].push_back(global_device_id);
+    }
+    mapping_arg.push_back(mapping[i].data());
+  }
+
+  TF_Status* status = TF_NewStatus();
+  size_t tpu_topology_output_size;
+  char* tpu_topology_output;
+
+  tpu::TpuMeshStateInterface* mesh_state;
+  auto* rmgr = GetTPUConfigResourceMgr();
+  OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
+  core::ScopedUnref mesh_state_unref(mesh_state);
+
+  WaitForDistributedTpuOp_DoWork(
+      num_hosts, num_devices_per_host,
+      const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
+      &tpu_topology_output_size, &tpu_topology_output, status);
+
+  Tensor* ctx_output;
+  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
+  ctx_output->scalar<tstring>()() =
+      std::string(tpu_topology_output, tpu_topology_output_size);
+
+  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+  TF_DeleteStatus(status);
+  TpuConfigurationApi_FreeCharArray(tpu_topology_output);
+
+  VLOG(1) << "WaitForDistributedTpuOp done";
+}
+
+void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "ShutdownDistributedTpuOp";
+  XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");
+
+  TF_Status* status = TF_NewStatus();
+  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
+                          GetTPUConfigResourceMgr(),
+                          tpu::kTpuMeshCommonStateResourceName));
+  ShutdownDistributedTpuOp_DoWork(status);
+  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+  TF_DeleteStatus(status);
+
+  VLOG(1) << "ShutdownDistributedTpuOp done";
+}
+
+void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "InitializeHostForDistributedTpuOp";
+  XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");
+
+  auto tpu_host_config = ctx->input(0).scalar<tstring>()();
+
+  size_t device_id_output_size;
+  int32_t* device_id_output;
+  TF_Status* status = TF_NewStatus();
+
+  InitializeHostForDistributedTpuOp_DoWork(
+      tpu_host_config.size(), tpu_host_config.data(),
+      enable_whole_mesh_compilations_, &device_id_output_size,
+      &device_id_output, status);
+
+  Tensor* ctx_output;
+  OP_REQUIRES_OK(
+      ctx, ctx->allocate_output(
+               0, TensorShape({static_cast<long long>(device_id_output_size)}),
+               &ctx_output));
+
+  for (size_t i = 0; i < device_id_output_size; ++i) {
+    ctx_output->flat<int32>()(i) = device_id_output[i];
+  }
+
+  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+  TF_DeleteStatus(status);
+  TpuConfigurationApi_FreeInt32Array(device_id_output);
+
+  VLOG(1) << "InitializeHostForDistributedTpuOp done";
+}
+
+void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "SetGlobalTPUArrayOp";
+  XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp");
+
+  auto tpu_topology = ctx->input(0).scalar<tstring>()();
+  TF_Status* status = TF_NewStatus();
+
+  SetGlobalTPUArrayOp_DoWork(tpu_topology.size(), tpu_topology.data(), status);
+
+  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+  TF_DeleteStatus(status);
+
+  VLOG(1) << "SetGlobalTPUArrayOp done";
+}
+
+void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "DisconnectDistributedTpuChipsOp";
+  XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp");
+
+  TF_Status* status = TF_NewStatus();
+  int32_t number_of_chips_output = 0;
+
+  DisconnectDistributedTpuChipsOp_DoWork(&number_of_chips_output, status);
+
+  Tensor* ctx_output;
+  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
+  ctx_output->scalar<int32_t>()() = number_of_chips_output;
+
+  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+  TF_DeleteStatus(status);
+
+  VLOG(1) << "DisconnectDistributedTpuChipsOp done";
+}
+
+// These ops execute on the TPU_SYSTEM device only.
+REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU")
+                            .Device(DEVICE_TPU_SYSTEM)
+                            .HostMemory("output"),
+                        ConfigureDistributedTpuOp);
+REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU")
+                            .Device(DEVICE_TPU_SYSTEM)
+                            .HostMemory("inputs")
+                            .HostMemory("topology"),
+                        WaitForDistributedTpuOp);
+REGISTER_KERNEL_BUILDER(
+    Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM),
+    ShutdownDistributedTpuOp);
+REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU")
+                            .Device(DEVICE_TPU_SYSTEM)
+                            .HostMemory("input")
+                            .HostMemory("tpu_ids"),
+                        InitializeHostForDistributedTpuOp);
+REGISTER_KERNEL_BUILDER(
+    Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"),
+    SetGlobalTPUArrayOp);
+REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem")
+                            .Device(DEVICE_TPU_SYSTEM)
+                            .HostMemory("number_of_tpu_chips"),
+                        DisconnectDistributedTpuChipsOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.h b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h
new file mode 100644
index 0000000..f75a47e
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h
@@ -0,0 +1,156 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+// The ConfigureDistributedTpu op is used to start an TPUDriver from
+// TensorFlow. It should be run on a TPU_SYSTEM device and returns the
+// connection host:port for the CompilationCacheServer. The
+// CompilationCacheServer will remain live until the device's Resource Manager
+// is cleared or a ShutdownDistributedTpuOp is run on the same device.
+class ConfigureDistributedTpuOp : public OpKernel {
+ public:
+  explicit ConfigureDistributedTpuOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    OP_REQUIRES(
+        ctx, ctx->num_inputs() > 0,
+        errors::Internal("_ConfigureDistributedTPU needs at least one input"));
+  }
+  void Compute(OpKernelContext* ctx) override;
+  ~ConfigureDistributedTpuOp() override {}
+
+ private:
+  // ConfigureDistributedTpuOp is neither copyable nor movable.
+  ConfigureDistributedTpuOp(const ConfigureDistributedTpuOp&) = delete;
+  ConfigureDistributedTpuOp& operator=(const ConfigureDistributedTpuOp&) =
+      delete;
+};
+
+// The WaitForDistributedTpuOp op is used to block execution until
+// the distributed Tpu system has started up. It must be run on
+// the same TPU_SYSTEM device that ConfigureDistributedTpuOp was run
+// on, after all of the InitializeHostForDistributedTpuOp Ops have
+// completed.
+class WaitForDistributedTpuOp : public OpKernel {
+ public:
+  explicit WaitForDistributedTpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx,
+                   ctx->GetAttr("startup_timeout_sec", &startup_timeout_sec_));
+    OP_REQUIRES(ctx, startup_timeout_sec_ > 0,
+                errors::InvalidArgument("startup_timeout_sec ",
+                                        startup_timeout_sec_, " must be >0"));
+  }
+  void Compute(OpKernelContext* ctx) override;
+  ~WaitForDistributedTpuOp() override {}
+
+ private:
+  // The time to wait for all hosts to start up.
+  int startup_timeout_sec_;
+
+  // WaitForDistributedTpuOp is neither copyable nor movable.
+  WaitForDistributedTpuOp(const WaitForDistributedTpuOp&) = delete;
+  WaitForDistributedTpuOp& operator=(const WaitForDistributedTpuOp&) = delete;
+};
+
+// The ShutdownDistributedTpu op is used to stop a running TPUDriver from
+// TensorFlow. It should be run on the TPU_SYSTEM device where
+// ConfigureDistributedTpuOp was run.
+class ShutdownDistributedTpuOp : public OpKernel {
+ public:
+  explicit ShutdownDistributedTpuOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override;
+
+  ~ShutdownDistributedTpuOp() override {}
+
+ private:
+  // ShutdownDistributedTpuOp is neither copyable nor movable.
+  ShutdownDistributedTpuOp(const ShutdownDistributedTpuOp&) = delete;
+  ShutdownDistributedTpuOp& operator=(const ShutdownDistributedTpuOp&) = delete;
+};
+
+// The InitializeHostForDistributedTpu op is used to initialize the
+// TPUPlatform on a host in a distributed TPU system. It should be
+// run on every host containing TPU devices before any other Ops that use
+// TPU are run.
+class InitializeHostForDistributedTpuOp : public OpKernel {
+ public:
+  explicit InitializeHostForDistributedTpuOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    ctx->GetAttr("enable_whole_mesh_compilations",
+                 &enable_whole_mesh_compilations_)
+        .IgnoreError();
+  }
+
+  void Compute(OpKernelContext* ctx) override;
+
+  ~InitializeHostForDistributedTpuOp() override {}
+
+ private:
+  // InitializeHostForDistributedTpuOp is neither copyable nor movable.
+  InitializeHostForDistributedTpuOp(const InitializeHostForDistributedTpuOp&) =
+      delete;
+  InitializeHostForDistributedTpuOp& operator=(
+      const InitializeHostForDistributedTpuOp&) = delete;
+
+  bool enable_whole_mesh_compilations_ = false;
+};
+
+// The SetGlobalTPUArray op is used to initialize the TPUPlatform on a
+// host in a distributed TPU system. It should be run on every host
+// containing TPU devices before any other Ops that use TPU are run.
+class SetGlobalTPUArrayOp : public OpKernel {
+ public:
+  explicit SetGlobalTPUArrayOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override;
+
+  ~SetGlobalTPUArrayOp() override {}
+
+ private:
+  // SetGlobalTPUArrayOp is neither copyable nor movable.
+  SetGlobalTPUArrayOp(const SetGlobalTPUArrayOp&) = delete;
+  SetGlobalTPUArrayOp& operator=(const SetGlobalTPUArrayOp&) = delete;
+};
+
+// The DisconnectDistributedTpuChips op is used to disconnect all the chips on a
+// host from a running TPUDriver instance. It should be run on every host
+// containing TPU devices before the ShutdownDistributedTpuOp is run on
+// the TPU_SYSTEM.
+class DisconnectDistributedTpuChipsOp : public OpKernel {
+ public:
+  explicit DisconnectDistributedTpuChipsOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override;
+
+  ~DisconnectDistributedTpuChipsOp() override {}
+
+ private:
+  // DisconnectDistributedTpuChipsOp is neither copyable nor movable.
+  DisconnectDistributedTpuChipsOp(const DisconnectDistributedTpuChipsOp&) =
+      delete;
+  DisconnectDistributedTpuChipsOp& operator=(
+      const DisconnectDistributedTpuChipsOp&) = delete;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_executable_info.proto b/tensorflow/core/tpu/kernels/tpu_executable_info.proto
new file mode 100644
index 0000000..359dad0
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_executable_info.proto
@@ -0,0 +1,94 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+
+import "tensorflow/compiler/xla/service/hlo.proto";
+import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+
+// A serialization of TPUExecutable. Only includes fields necessary to load
+// and execute a program on a worker node.
+message TPUExecutableInfoProto {
+  reserved 1;
+
+  // The shapes of the inputs and outputs.
+  repeated xla.ShapeProto input_shapes = 2;
+  reserved 7;  // was input_shape
+  xla.ShapeProto output_shape = 3;
+
+  message UpdateIndexPair {
+    int32 index = 1;
+    bool updated = 2;
+  }
+
+  message ShapeIndex {
+    repeated int32 index = 1;
+  }
+
+  // Dynamic output indices indicate which outputs have dynamic dimensions.
+  repeated ShapeIndex dynamic_output_indices = 11;
+
+  // For each resource variable output, what was the index of the corresponding
+  // input and was it updated? The indices are sorted by input order.
+  repeated UpdateIndexPair variable_indices = 10;
+
+  // The shapes of the outputs when represented as Tensors. These may not
+  // match the output_shape values because we may flatten tensors to avoid
+  // excess padding.
+  repeated TensorShapeProto output_tensor_shapes = 8;
+
+  reserved 4;
+
+  // Optional session module for passing XLA computations between TPUCompileOp
+  // and TPUExecuteOp. This is needed to support the
+  // --xla_dump_hlo_snapshots flag.
+  xla.HloSnapshot session_module = 5;
+
+  // The physical device ids assigned to the replicated cores.
+  xla.DeviceAssignmentProto device_assignment = 6;
+}
+
+// Metadata for a data transfer between device and host.
+message TPUHostTransferProto {
+  enum TransferDirection {
+    NONE = 0;
+    DEVICE_TO_HOST = 1;
+    HOST_TO_DEVICE = 2;
+  }
+  // Channel identifier assigned by compiler and used in host commands.
+  int64 channel = 1;
+  // Direction of the transfer operation.
+  TransferDirection direction = 2;
+  // Channel identifier prodided by XLA client.
+  string key = 3;
+  // Depth of nested loops for this transfer operation.
+  int64 nested_while_level = 4;
+  // Shape of the data to be transferred (including layout).
+  xla.ShapeProto shape = 5;
+  // Address of the device buffer in HBM (byte offset).
+  int64 buffer_offset = 6;
+  // Original data type for this host transfer before X64 rewrite.
+  xla.PrimitiveType original_type = 7;
+  // If this host transfer is a splitted X64 transfer, sepcifies whether this
+  // transfer is for lower bits.
+  bool is_lower_bits = 8;
+}
+
+message TPUHostTransferInfoProto {
+  repeated TPUHostTransferProto host_transfers = 1;
+}
diff --git a/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h b/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h
new file mode 100644
index 0000000..cb6a82e
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
+
+typedef struct XLA_TpuMeshState XLA_TpuMeshState;
+
+// Creates a new TPU mesh state object.
+XLA_TpuMeshState* TpuMeshState_Create();
+
+// Deletes the given TPU `mesh_state` object. Once deleted the object is
+// unusable.
+void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);
+
+// Returns a pointer to an opaque mesh data structure used internally.
+void* TpuMeshState_MeshCommonState(XLA_TpuMeshState* mesh_state);
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h b/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h
new file mode 100644
index 0000000..34202a7
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_MESH_STATE_INTERFACE_H_
+#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_MESH_STATE_INTERFACE_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
+
+namespace tensorflow {
+
+class TpuMeshCommonState;
+
+namespace tpu {
+
+const char kTpuMeshCommonStateResourceName[] = "tpu_mesh_common_state";
+
+class TpuMeshStateInterface : public tensorflow::ResourceBase {
+ public:
+  explicit TpuMeshStateInterface(XLA_TpuMeshState* handle)
+      : mesh_state_(handle) {
+  }
+
+  ~TpuMeshStateInterface() override {
+    if (mesh_state_ != nullptr) {
+      TpuMeshState_Free(mesh_state_);
+    }
+  }
+
+  static TpuMeshStateInterface* Create() {
+    return new TpuMeshStateInterface(TpuMeshState_Create());
+  }
+
+  const XLA_TpuMeshState* data() const { return mesh_state_; }
+
+  tensorflow::TpuMeshCommonState* mesh_common_state() const {
+    return static_cast<tensorflow::TpuMeshCommonState*>(
+        TpuMeshState_MeshCommonState(mesh_state_));
+  }
+
+  // Returns whether we should include the device assignment as a static field
+  // to the TPU program. This also determines whether we should include the
+  // device assignment as part of the compilation cache key.
+  bool NeedsStaticDeviceAssignment(
+      const TPUCompileMetadataProto& metadata,
+      TpuCoreTypeEnum tpu_core_type) const {
+    // Static device assignment enables XLA to perform certain optimization when
+    // all cores are used in the replicated computation.
+    return metadata.num_cores_per_replica() * metadata.num_replicas() ==
+           TpuTopology_AvailableCoreCount(mesh_state_,
+                                          tpu_core_type);
+  }
+
+  string DebugString() const override { return "TpuMeshStateInterface"; }
+
+ private:
+  XLA_TpuMeshState* mesh_state_;
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_MESH_STATE_INTERFACE_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_program.cc b/tensorflow/core/tpu/kernels/tpu_program.cc
new file mode 100644
index 0000000..7d89ad1
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_program.cc
@@ -0,0 +1,201 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_program.h"
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+
+namespace tensorflow {
+namespace tpu {
+
+namespace {
+
+namespace se_tpu = ::stream_executor::tpu;
+
+using stream_executor::port::StatusOr;
+using xla::Shape;
+
+StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
+    std::unique_ptr<xla::HloModuleGroup> module_group,
+    const XlaCompiler::CompilationResult& compilation_result,
+    const TPUCompileMetadataProto& metadata,
+    const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+    const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
+    const std::vector<std::vector<std::pair<int, bool>>>&
+        per_core_variable_indices,
+    const absl::optional<xla::DeviceAssignment>& device_assignment) {
+  VLOG(1) << "Run CompileAheadOfTime.";
+  TF_ASSIGN_OR_RETURN(TpuAotCompilationRequestProto aot_request,
+                      CreateTpuAotCompilationRequest(
+                          *module_group, compilation_result, metadata,
+                          per_core_arg_shapes, per_core_output_shapes,
+                          per_core_variable_indices, device_assignment));
+  se_tpu::SerializedProto serialized_aot_request =
+      se_tpu::SerializeProto(aot_request);
+  auto cleanup = gtl::MakeCleanup([serialized_aot_request] {
+    se_tpu::SerializedProto_Free(serialized_aot_request);
+  });
+
+  XLA_TpuProgram** xla_tpu_programs = nullptr;
+  size_t count = 0;
+  StatusHelper status;
+  VLOG(1) << "Run TpuCompile_CompileAheadOfTime.";
+  TpuCompile_CompileAheadOfTime(serialized_aot_request, &xla_tpu_programs,
+                                &count, status.c_status);
+  VLOG(1) << "Run CompileAheadOfTime completed.";
+  if (!status.status().ok()) {
+    return status.status();
+  }
+  std::vector<XLA_TpuProgram*> tpu_programs(count, nullptr);
+  for (size_t i = 0; i < count; ++i) {
+    tpu_programs[i] = xla_tpu_programs[i];
+  }
+  delete[] xla_tpu_programs;
+  return tpu_programs;
+  return Status::OK();
+}
+
+StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
+    const TPUCompileMetadataProto& metadata,
+    const XlaCompiler::CompilationResult& compilation_result,
+    const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+    const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
+    const std::vector<std::vector<std::pair<int, bool>>>&
+        per_core_variable_indices,
+    const absl::optional<xla::DeviceAssignment>& device_assignment) {
+  VLOG(1) << "Compile Tpu programs.";
+  std::vector<std::unique_ptr<xla::HloModule>> hlo_modules;
+  auto status = CreateHloModules(metadata, compilation_result,
+                                 device_assignment, &hlo_modules);
+  if (!status.ok()) {
+    return status;
+  }
+
+  return CompileAheadOfTime(
+      absl::make_unique<xla::HloModuleGroup>(hlo_modules[0]->name(),
+                                             absl::MakeSpan(hlo_modules)),
+      compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes,
+      per_core_variable_indices, device_assignment);
+}
+
+}  // namespace
+
+int64_t TpuProgram::program_size() const {
+  int64_t total_size = 0;
+  for (XLA_TpuProgram* tpu_program : tpu_programs_) {
+    total_size += TpuProgram_GetProgramSize(tpu_program);
+  }
+  return total_size;
+}
+
+bool TpuProgram::LogProgramMemorySummary() {
+  bool success = true;
+  for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
+    success &= TpuProgram_LogProgramMemorySummary(tpu_program);
+  }
+  return success;
+}
+
+void TpuProgram::UnloadAndDestroyPrograms() {
+  for (XLA_TpuProgram* tpu_program : tpu_programs_) {
+    StatusHelper status;
+    TpuProgram_UnloadAndDestroy(tpu_program, status.c_status);
+    auto s = status.status();
+    if (!s.ok()) {
+      LOG(ERROR) << "TpuProgram::UnloadPrograms(): " << s.ToString();
+    }
+  }
+  tpu_programs_.clear();
+}
+
+/*static*/ Status TpuProgram::Build(
+    const TPUCompileMetadataProto& metadata,
+    const tensorflow::XlaCompiler::CompilationResult& compilation_result,
+    const std::vector<ShardingAndIndex>& arg_core_mapping,
+    const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+    const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
+    TpuProgram* tpu_program) {
+  std::vector<std::vector<xla::Shape>> per_core_output_shapes(
+      metadata.num_cores_per_replica());
+  TF_RETURN_IF_ERROR(ComputeOutputShapesForEachCore(
+      metadata, compilation_result, &per_core_output_shapes));
+
+  std::vector<std::vector<std::pair<int, bool>>> per_core_variable_indices(
+      metadata.num_cores_per_replica());
+  std::vector<bool> may_modify_variables;
+  TF_RETURN_IF_ERROR(AddVariableUpdatesToCores(
+      metadata, compilation_result, arg_core_mapping, &may_modify_variables,
+      &per_core_output_shapes, &per_core_variable_indices));
+  TF_RET_CHECK(per_core_arg_shapes.size() == metadata.num_cores_per_replica());
+  TF_RET_CHECK(per_core_output_shapes.size() == per_core_arg_shapes.size());
+  TF_RET_CHECK(per_core_output_shapes.size() ==
+               per_core_variable_indices.size());
+  tpu_program->set_may_modify_variables(may_modify_variables);
+
+  // With shardable input/output pairs, XLA could generate separate
+  // sharding/unsharding programs along with the main program. The
+  // sharding/unsharding programs will be in nested entries of the AOT
+  // compilation result.
+  auto status_or = CompileAheadOfTime(
+      metadata, compilation_result, per_core_arg_shapes, per_core_output_shapes,
+      per_core_variable_indices, xla_device_assignment);
+
+  TF_ASSIGN_OR_RETURN(std::vector<XLA_TpuProgram*> xla_tpu_programs,
+                      std::move(status_or));
+  // SPMD could return 1 result for all partitions.
+  TF_RET_CHECK(xla_tpu_programs.size() == 1 ||
+               xla_tpu_programs.size() == metadata.num_cores_per_replica());
+  tpu_program->set_tpu_programs(xla_tpu_programs);
+
+  // TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1.
+  TpuSerializedProto serialized_executable_info;
+  TpuProgram_GetExecutableInfo(xla_tpu_programs[0],
+                               &serialized_executable_info);
+  TPUExecutableInfoProto executable_info =
+      se_tpu::DeserializeProto<TPUExecutableInfoProto>(
+          serialized_executable_info);
+  tpu_program->set_executable_info(executable_info);
+  StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
+
+  TPUHostTransferInfoProto host_transfer_info;
+  TpuSerializedProto serialized_host_transfer_info;
+  TpuProgram_GetHostTransferInfo(xla_tpu_programs[0],
+                                 &serialized_host_transfer_info);
+  if (serialized_host_transfer_info.size > 0) {
+    host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
+        serialized_host_transfer_info);
+    StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
+  }
+  tpu_program->set_host_transfer_info(host_transfer_info);
+
+  TpuSerializedProto serialized_hlo_metadata;
+  TpuProgram_GetHloMetadata(xla_tpu_programs[0], &serialized_hlo_metadata);
+  xla::HloProto hlo_metadata =
+      se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
+  tpu_program->set_hlo_metadata(hlo_metadata);
+  StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
+
+  return Status::OK();
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_program.h b/tensorflow/core/tpu/kernels/tpu_program.h
new file mode 100644
index 0000000..aee55bd
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_program.h
@@ -0,0 +1,161 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_PROGRAM_H_
+#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_PROGRAM_H_
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuAotCompilationOptions : public xla::AotCompilationOptions {
+ public:
+  explicit TpuAotCompilationOptions(int64 replica_count)
+      : num_cores_(0), replica_count_(replica_count) {}
+
+  // Returns the ID of the platform to which these options apply.
+  se::Platform::Id PlatformId() const override {
+    LOG(FATAL) << "Not implemented.";
+    return nullptr;
+  };
+
+  void set_num_cores(int64 tpu_cores) { num_cores_ = tpu_cores; }
+  int64 replica_count() const override { return replica_count_; }
+  int64 num_cores() const override { return num_cores_; }
+
+  void set_allow_separate_sharding_programs(bool allow) {
+    allow_separate_sharding_programs_ = allow;
+  }
+  bool allow_separate_sharding_programs() const {
+    return allow_separate_sharding_programs_;
+  }
+
+  const std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
+  shardable_value_update_pairs() const {
+    return shardable_value_update_pairs_;
+  }
+  void set_shardable_value_update_pairs(
+      std::vector<xla::HloModuleConfig::ShardableValueUpdatePair> pairs) {
+    shardable_value_update_pairs_ = std::move(pairs);
+  }
+
+ private:
+  int64 num_cores_;
+  int64 replica_count_;
+
+  // Whether to allow the compiler to create separte sharding and unsharding
+  // programs, and modify the original program's input/output sharded size. This
+  // is used for XLA-chosen sharding on parameters without an on-device loop:
+  // the caller can invoke sharding first, then (repeatedly) invoke the sharded
+  // main program, and finally invoke the unsharding program when it needs the
+  // full output.
+  bool allow_separate_sharding_programs_ = false;
+
+  // The list of input/output pairs in the main program that could be sharded.
+  std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
+      shardable_value_update_pairs_;
+};
+
+// An executable capable of being fed to a TPU device.
+class TpuProgram {
+ public:
+  using Status = ::stream_executor::port::Status;
+
+  virtual ~TpuProgram() = default;
+
+  static Status Build(
+      const TPUCompileMetadataProto& metadata,
+      const tensorflow::XlaCompiler::CompilationResult& compilation_result,
+      const std::vector<ShardingAndIndex>& arg_core_mapping,
+      const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+      const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
+      TpuProgram* tpu_program);
+
+  size_t program_count() const {
+    return tpu_programs_.size();
+  }
+
+  int64_t program_size() const;
+
+  bool LogProgramMemorySummary();
+
+  void UnloadAndDestroyPrograms();
+
+  const std::vector<bool>& may_modify_variables() const {
+    return may_modify_variables_;
+  }
+  void set_may_modify_variables(const std::vector<bool>& may_modify_variables) {
+    may_modify_variables_ = may_modify_variables;
+  }
+
+  const tf2xla::HostComputeMetadata& host_compute_metadata() const {
+    return host_compute_metadata_;
+  }
+  void set_host_compute_metadata(
+      const tf2xla::HostComputeMetadata& host_compute_metadata) {
+    host_compute_metadata_ = host_compute_metadata;
+  }
+
+  const std::vector<XLA_TpuProgram*>& tpu_programs() const {
+    return tpu_programs_;
+  }
+  void set_tpu_programs(std::vector<XLA_TpuProgram*> tpu_programs) {
+    tpu_programs_ = tpu_programs;
+  }
+
+  const TPUExecutableInfoProto& executable_info() const {
+    return executable_info_;
+  }
+  void set_executable_info(const TPUExecutableInfoProto& executable_info) {
+    executable_info_ = executable_info;
+  }
+
+  const TPUHostTransferInfoProto& host_transfer_info() const {
+    return host_transfer_info_;
+  }
+  void set_host_transfer_info(
+      const TPUHostTransferInfoProto& host_transfer_info) {
+    host_transfer_info_ = host_transfer_info;
+  }
+
+  const xla::HloProto& hlo_metadata() const { return hlo_metadata_; }
+  void set_hlo_metadata(const xla::HloProto& hlo_metadata) {
+    hlo_metadata_ = hlo_metadata;
+  }
+
+ private:
+  std::vector<bool> may_modify_variables_;
+  tf2xla::HostComputeMetadata host_compute_metadata_;
+
+  std::vector<XLA_TpuProgram*> tpu_programs_;  // Not owned.
+  TPUExecutableInfoProto executable_info_;
+  TPUHostTransferInfoProto host_transfer_info_;
+  xla::HloProto hlo_metadata_;
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_PROGRAM_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_util.cc b/tensorflow/core/tpu/kernels/tpu_util.cc
new file mode 100644
index 0000000..5c286de
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_util.cc
@@ -0,0 +1,100 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_util.h"
+
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/platform/random.h"
+
+namespace tensorflow {
+namespace tpu {
+
+std::string SessionNameFromMetadata(const SessionMetadata* session_metadata) {
+  return session_metadata ? session_metadata->name() : "";
+}
+
+std::string ProtoKeyForComputation(const std::string& key, int core) {
+  return absl::StrCat(key, ":", core);
+}
+
+xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
+    const std::string& key) {
+  const std::vector<std::string> splits = absl::StrSplit(key, '|');
+  if (splits.size() == 1) {
+    // No guaranteed_const.
+    return TpuCompilationCacheKey(key);
+  } else if (splits.size() != 3) {
+    return errors::InvalidArgument("Invalid TPU compilation cache key:", key);
+  }
+
+  TpuCompilationCacheKey parsed_key(splits.at(0));
+  parsed_key.has_guaranteed_const = true;
+  parsed_key.session_handle = splits.at(1);
+  const string fingerprint = splits.at(2);
+  parsed_key.guaranteed_const_fingerprint = [fingerprint] {
+    return fingerprint;
+  };
+  return parsed_key;
+}
+
+xla::CompileOnlyClient::AotXlaComputationInstance
+BuildAotXlaComputationInstance(
+    const XlaCompiler::CompilationResult& compilation_result) {
+  xla::CompileOnlyClient::AotXlaComputationInstance instance;
+  instance.computation = compilation_result.computation.get();
+  for (const xla::Shape& shape : compilation_result.xla_input_shapes) {
+    instance.argument_layouts.push_back(&shape);
+  }
+  instance.result_layout = &compilation_result.xla_output_shape;
+  return instance;
+}
+
+Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape) {
+  if (tensor.dtype() != DT_INT64 ||
+      !TensorShapeUtils::IsVector(tensor.shape())) {
+    return errors::InvalidArgument("Shape tensor must be an int64 vector.");
+  }
+  const int64 rank = tensor.NumElements();
+  auto tensor_dims = tensor.flat<int64>();
+  std::vector<int64> dims(rank);
+  for (int64 i = 0; i < rank; ++i) {
+    dims[i] = tensor_dims(i);
+  }
+  return TensorShapeUtils::MakeShape(dims, shape);
+}
+
+Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
+                                   std::vector<TensorShape>* shapes) {
+  shapes->resize(dynamic_shapes.size());
+  for (int i = 0; i < dynamic_shapes.size(); ++i) {
+    TF_RETURN_IF_ERROR(
+        ShapeTensorToTensorShape(dynamic_shapes[i], &(*shapes)[i]));
+  }
+  return Status::OK();
+}
+
+Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
+                                   std::vector<TensorShape>* shapes) {
+  shapes->resize(dynamic_shapes.end() - dynamic_shapes.begin());
+  size_t i = 0;
+  for (auto& dynamic_shape : dynamic_shapes) {
+    TF_RETURN_IF_ERROR(
+        ShapeTensorToTensorShape(dynamic_shape.tensor(), &(*shapes)[i]));
+    ++i;
+  }
+  return Status::OK();
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_util.h b/tensorflow/core/tpu/kernels/tpu_util.h
new file mode 100644
index 0000000..0ca94d0
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_util.h
@@ -0,0 +1,67 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Utility to get session_name from `SessionMetadata`. `SessionMetadata` may
+// be null.
+std::string SessionNameFromMetadata(const SessionMetadata* session_metadata);
+
+// Generates cache proto key for a given computation on a TPU core.
+std::string ProtoKeyForComputation(const std::string& key, int core);
+
+// Returns a TpuCompilationCacheKey parsed from given key or an error.
+xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
+    const std::string& key);
+
+xla::CompileOnlyClient::AotXlaComputationInstance
+BuildAotXlaComputationInstance(
+    const XlaCompiler::CompilationResult& compilation_result);
+
+// Returns true if TPU compilation is enabled.
+bool IsTpuCompilationEnabled();
+
+// Converts an int64 host memory `tensor` to a `shape`.
+Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape);
+
+Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
+                                   std::vector<TensorShape>* shapes);
+Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
+                                   std::vector<TensorShape>* shapes);
+
+// Given a tensor of `shape` and `type`, as what shape should it be stored on
+// the TPU device? This function tranposes or flattens the excessively-padded
+// tensors to rank 1, but leaves other tensor shapes alone.
+xla::StatusOr<xla::Shape> TpuShapeRepresentation(const TensorShape& shape,
+                                                 DataType type,
+                                                 bool use_fast_memory);
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
diff --git a/tensorflow/core/tpu/kernels/trace_util.h b/tensorflow/core/tpu/kernels/trace_util.h
new file mode 100644
index 0000000..4e0b7c9
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/trace_util.h
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
+
+#ifdef PLATFORM_GOOGLE
+#include "base/tracer.h"
+#else
+#undef TRACESTRING
+#define TRACESTRING(x)
+#undef TRACELITERAL
+#define TRACELITERAL(x)
+#endif
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
diff --git a/tensorflow/core/tpu/tpu_library_loader.cc b/tensorflow/core/tpu/tpu_library_loader.cc
index 3bc835c..c89de14 100644
--- a/tensorflow/core/tpu/tpu_library_loader.cc
+++ b/tensorflow/core/tpu/tpu_library_loader.cc
@@ -64,13 +64,20 @@
 }
 
 Status InitializeTpuLibrary(void* library_handle) {
+  bool shared_object_loaded = true;
   if (library_handle == nullptr) {
     library_handle = dlopen(nullptr, RTLD_LAZY);
+    shared_object_loaded = false;
   }
 
   TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(library_handle));
   TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
 
+  if (shared_object_loaded) {
+    // Initialize TPU platform when the platform code is loaded from a library.
+    InitializeApiFn()->TfTpu_InitializeFn();
+  }
+
   return Status::OK();
 }
 
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
new file mode 100644
index 0000000..52ca40f
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -0,0 +1,234 @@
+# Description: StreamExecutor Interface for TPUs
+
+package(
+    default_visibility = ["//tensorflow/core/tpu:__subpackages__"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "tpu_executor_c_api_hdrs",
+    hdrs = ["tpu_executor_c_api.h"],
+    deps = [
+        "//tensorflow/c:tf_attrtype",
+        "//tensorflow/c:tf_datatype",
+        "//tensorflow/c:tf_status",
+    ],
+)
+
+cc_library(
+    name = "tpu_node_context_c_api_hdrs",
+    hdrs = ["tpu_node_context_c_api.h"],
+    deps = [
+        ":tpu_executor_c_api_hdrs",
+    ],
+)
+
+cc_library(
+    name = "status_helper",
+    hdrs = ["status_helper.h"],
+    deps = [
+        ":tpu_executor_c_api_hdrs",
+        "//tensorflow/core/platform:status",
+    ],
+)
+
+cc_library(
+    name = "c_api_conversions",
+    hdrs = ["c_api_conversions.h"],
+    deps = [
+        ":tpu_executor_c_api_hdrs",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/service:shaped_buffer",
+        "//tensorflow/stream_executor:stream",
+        "@com_google_absl//absl/container:inlined_vector",
+    ],
+)
+
+cc_library(
+    name = "proto_helper",
+    srcs = ["proto_helper.cc"],
+    hdrs = ["proto_helper.h"],
+    deps = ["//tensorflow/core:lib"],
+)
+
+cc_library(
+    name = "tpu_stream",
+    hdrs = ["tpu_stream.h"],
+    deps = [
+        ":tpu_executor_c_api_hdrs",
+        "//tensorflow/stream_executor:stream",
+        "//tensorflow/stream_executor/lib",
+    ],
+)
+
+cc_library(
+    name = "tpu_timer",
+    hdrs = ["tpu_timer.h"],
+    deps = [
+        ":tpu_executor_c_api_hdrs",
+        "//tensorflow/core/platform:types",
+        "//tensorflow/stream_executor:stream",
+    ],
+)
+
+cc_library(
+    name = "tpu_executor",
+    srcs = ["tpu_executor.cc"],
+    hdrs = ["tpu_executor.h"],
+    deps = [
+        ":c_api_conversions",
+        ":status_helper",
+        ":tpu_executor_c_api_hdrs",
+        ":tpu_executor_interface",
+        ":tpu_platform",
+        ":tpu_platform_interface",
+        ":tpu_stream",
+        ":tpu_timer",
+        "//tensorflow/c:tf_status",
+        "//tensorflow/core:lib",
+        "//tensorflow/stream_executor:stream",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/container:flat_hash_map",
+    ],
+)
+
+cc_library(
+    name = "tpu_executor_hdrs",
+    hdrs = ["tpu_executor.h"],
+    deps = [
+        ":tpu_executor_c_api_hdrs",
+        ":tpu_executor_interface",
+        ":tpu_platform_hdrs",
+        ":tpu_platform_interface",
+        "//tensorflow/core/platform:types",
+        "//tensorflow/stream_executor:stream_header",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/container:flat_hash_map",
+    ],
+)
+
+cc_library(
+    name = "tpu_platform_hdrs",
+    hdrs = ["tpu_platform.h"],
+    deps = [
+        ":tpu_executor_c_api_hdrs",
+        ":tpu_platform_interface",
+        "//tensorflow/core/platform:types",
+        "//tensorflow/stream_executor:stream_header",
+        "@com_google_absl//absl/container:flat_hash_map",
+    ],
+)
+
+cc_library(
+    name = "tpu_node_context",
+    srcs = ["tpu_node_context.cc"],
+    hdrs = ["tpu_node_context.h"],
+    deps = [
+        ":status_helper",
+        ":tpu_executor_c_api_hdrs",
+        ":tpu_node_context_c_api_hdrs",
+        ":tpu_platform_interface",
+        ":tpu_transfer_manager",
+        "//tensorflow/compiler/xla/service",
+        "//tensorflow/compiler/xla/service:backend",
+        "//tensorflow/compiler/xla/service:platform_util",
+        "//tensorflow/compiler/xla/service:stream_pool",
+        "//tensorflow/compiler/xla/service:transfer_manager",
+        "//tensorflow/core:framework",
+        "//tensorflow/stream_executor:device_memory_allocator",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
+cc_library(
+    name = "tpu_platform",
+    srcs = ["tpu_platform.cc"],
+    hdrs = ["tpu_platform.h"],
+    deps = [
+        ":status_helper",
+        ":tpu_executor_c_api_hdrs",
+        ":tpu_executor_hdrs",
+        ":tpu_platform_interface",
+        "//tensorflow/c:tf_status",
+        "//tensorflow/core/platform:types",
+        "//tensorflow/stream_executor:stream",
+        "@com_google_absl//absl/container:flat_hash_map",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "tpu_transfer_manager",
+    srcs = ["tpu_transfer_manager_registration.cc"],
+    deps = [
+        ":tpu_platform",
+        ":tpu_transfer_manager_base",
+        "//tensorflow/compiler/xla/service:transfer_manager",
+    ],
+)
+
+cc_library(
+    name = "tpu_transfer_manager_base",
+    srcs = ["tpu_transfer_manager.cc"],
+    hdrs = ["tpu_transfer_manager.h"],
+    deps = [
+        ":c_api_conversions",
+        ":proto_helper",
+        ":status_helper",
+        ":tpu_executor_c_api_hdrs",
+        ":tpu_platform",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/service:shaped_buffer",
+        "//tensorflow/compiler/xla/service:transfer_manager",
+        "//tensorflow/stream_executor:stream",
+    ],
+)
+
+cc_library(
+    name = "tpu_computation_placer",
+    srcs = ["tpu_computation_placer.cc"],
+    hdrs = ["tpu_computation_placer.h"],
+    deps = [
+        ":tpu_executor_c_api_hdrs",
+        ":tpu_platform",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/service:computation_placer",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "tpu_platform_interface",
+    srcs = ["tpu_platform_interface.cc"],
+    hdrs = ["tpu_platform_interface.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core/platform:types",
+        "//tensorflow/stream_executor:multi_platform_manager",
+        "//tensorflow/stream_executor:stream_executor_headers",
+    ],
+)
+
+cc_library(
+    name = "tpu_stream_interface",
+    hdrs = ["tpu_stream_interface.h"],
+    visibility = ["//visibility:public"],
+    deps = ["//tensorflow/stream_executor:stream_executor_internal"],
+)
+
+cc_library(
+    name = "tpu_executor_interface",
+    hdrs = ["tpu_executor_interface.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":tpu_platform_interface",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/stream_executor:stream_executor_internal",
+        "//tensorflow/stream_executor:stream_header",
+    ],
+)
diff --git a/tensorflow/stream_executor/tpu/c_api_conversions.h b/tensorflow/stream_executor/tpu/c_api_conversions.h
new file mode 100644
index 0000000..1bb9ece
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/c_api_conversions.h
@@ -0,0 +1,115 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
+
+#include "absl/container/inlined_vector.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+class TpuConversions {
+ public:
+  static stream_executor::DeviceMemoryBase
+  SE_DeviceMemoryBaseToDeviceMemoryBase(SE_DeviceMemoryBase se_base) {
+    stream_executor::DeviceMemoryBase base(se_base.opaque, se_base.size);
+    base.SetPayload(se_base.payload);
+    return base;
+  }
+
+  static SE_DeviceMemoryBase DeviceMemoryBaseToSE_DeviceMemoryBase(
+      const stream_executor::DeviceMemoryBase& base) {
+    SE_DeviceMemoryBase se_base;
+    se_base.opaque = const_cast<void*>(base.opaque());
+    se_base.payload = base.payload();
+    se_base.size = base.size();
+    return se_base;
+  }
+
+  static xla::Shape CShapeToXlaShape(XLA_Shape* shape) {
+    xla::ShapeProto p;
+    p.ParseFromArray(shape->bytes, shape->size);
+    return xla::Shape(p);
+  }
+
+  static void XlaShapeToCShape(const xla::Shape& xla_shape,
+                               XLA_Shape* c_shape) {
+    xla::ShapeProto p = xla_shape.ToProto();
+    std::string p_str = p.SerializeAsString();
+    c_shape->bytes = new char[p_str.size()];
+    c_shape->size = p_str.size();
+    memcpy(c_shape->bytes, p_str.data(), p_str.size());
+  }
+
+  static void XLAShapedBufferToCShapedBuffer(
+      const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) {
+    XlaShapeToCShape(buffer.on_host_shape(), &c_device_buffer->on_host_shape);
+    XlaShapeToCShape(buffer.on_device_shape(),
+                     &c_device_buffer->on_device_shape);
+    c_device_buffer->device_ordinal = buffer.device_ordinal();
+    absl::InlinedVector<SE_DeviceMemoryBase, 2> bases;
+    for (auto& pair : buffer.buffers()) {
+      bases.push_back(DeviceMemoryBaseToSE_DeviceMemoryBase(pair.second));
+    }
+    c_device_buffer->count = bases.size();
+    c_device_buffer->bases = new SE_DeviceMemoryBase[bases.size()];
+    for (int i = 0; i < bases.size(); ++i) {
+      c_device_buffer->bases[i] = bases[i];
+    }
+  }
+
+  static void XLALiteralToCLiteral(const xla::LiteralSlice& literal,
+                                   XLA_Literal* c_literal) {
+    XlaShapeToCShape(literal.shape(), &c_literal->shape);
+    auto shapes = xla::ShapeUtil::GetLeafShapes(literal.shape());
+    c_literal->buffers = new char*[shapes.size()];
+    c_literal->sizes = new size_t[shapes.size()];
+    c_literal->count = shapes.size();
+    for (int i = 0; i < shapes.size(); ++i) {
+      c_literal->buffers[i] = reinterpret_cast<char*>(
+          const_cast<void*>(literal.untyped_data(shapes[i].index)));
+      c_literal->sizes[i] = literal.size_bytes(shapes[i].index);
+    }
+  }
+
+  static xla::MutableBorrowingLiteral CLiteralToXLALiteral(
+      XLA_Literal* c_literal) {
+    xla::Shape shape = CShapeToXlaShape(&c_literal->shape);
+    LOG(INFO) << "Shape: " << shape.DebugString();
+    return xla::MutableBorrowingLiteral(
+        absl::MakeSpan(c_literal->buffers, c_literal->count), shape);
+  }
+
+  static void CShapeCleanup(XLA_Shape* c_shape) { delete[] c_shape->bytes; }
+
+  static void CLiteralCleanup(XLA_Literal* c_literal) {
+    delete[] c_literal->buffers;
+    delete[] c_literal->sizes;
+    CShapeCleanup(&c_literal->shape);
+  }
+
+  static void CShapedBufferCleanup(XLA_ShapedBuffer* c_buffer) {
+    CShapeCleanup(&c_buffer->on_device_shape);
+    CShapeCleanup(&c_buffer->on_host_shape);
+    delete[] c_buffer->bases;
+  }
+};
+
+#endif  // THIRD_PARTY_TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
diff --git a/tensorflow/stream_executor/tpu/proto_helper.cc b/tensorflow/stream_executor/tpu/proto_helper.cc
new file mode 100644
index 0000000..db663c6
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/proto_helper.cc
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+
+extern "C" {
+
+void StreamExecutor_Tpu_FreeSerializedProto(const TpuSerializedProto* proto) {
+  CHECK_NE(proto, nullptr);
+  CHECK_NE(proto->bytes, nullptr);
+  CHECK_GT(proto->size, 0);
+  delete[] proto->bytes;
+}
+
+}  // extern "C"
diff --git a/tensorflow/stream_executor/tpu/proto_helper.h b/tensorflow/stream_executor/tpu/proto_helper.h
new file mode 100644
index 0000000..3bd2b09
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/proto_helper.h
@@ -0,0 +1,85 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
+
+#include <cstddef>
+
+#include "tensorflow/core/platform/logging.h"
+
+extern "C" {
+
+typedef struct TpuSerializedProto {
+  const char* bytes;
+  size_t size;
+} TpuSerializedProto;
+
+void StreamExecutor_Tpu_FreeSerializedProto(const TpuSerializedProto* proto);
+
+}  // extern "C"
+
+namespace stream_executor {
+namespace tpu {
+
+using SerializedProto = TpuSerializedProto;
+
+// Serializes a proto and put the result in the given SerializedProto* argument.
+//
+// Users should call SerializedProto_Free on `serialized_proto` afterwards.
+template <class Proto>
+inline void SerializeProto(const Proto& proto,
+                           SerializedProto* serialized_proto) {
+  auto size = proto.ByteSizeLong();
+  auto bytes = new char[size];
+  CHECK(proto.SerializeToArray(bytes, size));
+  serialized_proto->size = size;
+  serialized_proto->bytes = bytes;
+}
+
+// Serializes a proto and return the result as a SerializedProto value.
+//
+// Users should call SerializedProto_Free on the return value afterwards.
+template <class Proto>
+inline SerializedProto SerializeProto(const Proto& proto) {
+  SerializedProto serialized_proto;
+  SerializeProto(proto, &serialized_proto);
+  return serialized_proto;
+}
+
+// Deserializes a buffer and return the corresponding proto. If the buffer is
+// empty, return an empty proto.
+template <class Proto>
+inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
+  Proto proto;
+  if (serialized_proto.bytes != nullptr) {
+    CHECK_GT(serialized_proto.size, 0);
+    CHECK(proto.ParseFromArray(serialized_proto.bytes, serialized_proto.size))
+        << "Invalid buffer, failed to deserialize buffer.";
+  }
+  return proto;
+}
+
+// Releases the memory allocated for serialized protos.
+inline void SerializedProto_Free(const SerializedProto& serialized_proto) {
+  CHECK_NE(serialized_proto.bytes, nullptr);
+  CHECK_GT(serialized_proto.size, 0);
+  delete[] serialized_proto.bytes;
+}
+
+}  // namespace tpu
+}  // namespace stream_executor
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
diff --git a/tensorflow/stream_executor/tpu/status_helper.h b/tensorflow/stream_executor/tpu/status_helper.h
new file mode 100644
index 0000000..8fcf302
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/status_helper.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
+
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+struct StatusHelper {
+  StatusHelper() : c_status(TpuStatus_New()) {}
+  ~StatusHelper() { TpuStatus_Free(c_status); }
+  bool ok() { return TpuStatus_Code(c_status) == 0; }
+  tensorflow::Status status() {
+    if (!ok()) {
+      return tensorflow::Status(
+          tensorflow::error::Code(TpuStatus_Code(c_status)),
+          TpuStatus_Message(c_status));
+    } else {
+      return tensorflow::Status::OK();
+    }
+  }
+  SE_Status* c_status;
+};
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_computation_placer.cc b/tensorflow/stream_executor/tpu/tpu_computation_placer.cc
new file mode 100644
index 0000000..660b446
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_computation_placer.cc
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h"
+
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+
+template <typename T>
+using StatusOr = TpuComputationPlacer::StatusOr<T>;
+
+TpuComputationPlacer::TpuComputationPlacer() {
+  placer_ = TpuComputationPlacer_New();
+}
+
+TpuComputationPlacer::~TpuComputationPlacer() {
+  TpuComputationPlacer_Free(placer_);
+}
+
+StatusOr<int> TpuComputationPlacer::DeviceId(int replica, int computation,
+                                             int replica_count,
+                                             int computation_count) {
+  LOG(FATAL) << "Unimplemented.";
+}
+
+StatusOr<xla::DeviceAssignment> TpuComputationPlacer::AssignDevices(
+    int replica_count, int computation_count) {
+  LOG(FATAL) << "Unimplemented.";
+}
+
+static std::unique_ptr<xla::ComputationPlacer> CreateTpuComputationPlacer() {
+  return std::make_unique<TpuComputationPlacer>();
+}
+
+static bool InitModule() {
+  xla::ComputationPlacer::RegisterComputationPlacer(
+      tensorflow::TpuPlatform::kId, CreateTpuComputationPlacer);
+  return true;
+}
+static bool module_initialized = InitModule();
diff --git a/tensorflow/stream_executor/tpu/tpu_computation_placer.h b/tensorflow/stream_executor/tpu/tpu_computation_placer.h
new file mode 100644
index 0000000..c8f4c9e
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_computation_placer.h
@@ -0,0 +1,41 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_
+
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+class TpuComputationPlacer : public xla::ComputationPlacer {
+ public:
+  template <typename T>
+  using StatusOr = xla::StatusOr<T>;
+
+  TpuComputationPlacer();
+  ~TpuComputationPlacer() override;
+
+  StatusOr<int> DeviceId(int replica, int computation, int replica_count,
+                         int computation_count) override;
+
+  StatusOr<xla::DeviceAssignment> AssignDevices(int replica_count,
+                                                int computation_count) override;
+
+ private:
+  XLA_ComputationPlacer* placer_;
+};
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_executor.cc b/tensorflow/stream_executor/tpu/tpu_executor.cc
new file mode 100644
index 0000000..92808936
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_executor.cc
@@ -0,0 +1,355 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_executor.h"
+
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_stream.h"
+#include "tensorflow/stream_executor/tpu/tpu_timer.h"
+
+using stream_executor::DeviceMemoryBase;
+
+namespace tensorflow {
+
+namespace {
+using ::stream_executor::port::Status;
+}  // namespace
+
+TpuExecutor::~TpuExecutor() { TpuExecutor_Free(executor_); }
+
+Status TpuExecutor::Init(int device_ordinal,
+                         ::stream_executor::DeviceOptions device_options) {
+  StatusHelper status;
+  SE_DeviceOptions* options =
+      TpuExecutor_NewDeviceOptions(device_options.flags());
+  TpuExecutor_Init(executor_, device_ordinal, options, status.c_status);
+  TpuExecutor_FreeDeviceOptions(options);
+  return status.status();
+}
+
+int TpuExecutor::PlatformDeviceCount() {
+  return TpuExecutor_PlatformDeviceCount(executor_);
+}
+
+void TpuExecutor::SyncAndForgetFailedStreams() {
+  TpuExecutor_SyncAndForgetFailedStreams(executor_);
+}
+
+bool TpuExecutor::SynchronizeAllActivity() {
+  return TpuExecutor_SynchronizeAllActivity(executor_);
+}
+
+Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
+  StatusHelper status;
+  TpuExecutor_BlockHostUntilDone(
+      executor_, stream_map().at(stream->implementation()), status.c_status);
+  return status.status();
+}
+
+Status TpuExecutor::BlockUntilDoneOrFailed() {
+  StatusHelper status;
+  TpuExecutor_BlockUntilDoneOrFailed(executor_, status.c_status);
+  return status.status();
+}
+
+Status TpuExecutor::GetStatus(Stream* stream) {
+  StatusHelper status;
+  TpuExecutor_GetStatus(executor_, stream_map().at(stream->implementation()),
+                        status.c_status);
+  return status.status();
+}
+
+bool TpuExecutor::AllocateStream(Stream* stream) {
+  return TpuExecutor_AllocateStream(executor_,
+                                    stream_map().at(stream->implementation()));
+}
+
+void TpuExecutor::DeallocateStream(Stream* stream) {
+  TpuExecutor_DeallocateStream(executor_,
+                               stream_map().at(stream->implementation()));
+  stream_map().erase(stream->implementation());
+}
+
+bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
+  return TpuExecutor_CreateStreamDependency(
+      executor_, stream_map().at(dependent->implementation()),
+      stream_map().at(other->implementation()));
+}
+
+Status TpuExecutor::AllocateEvent(Event* event) { return Status::OK(); }
+
+Status TpuExecutor::DeallocateEvent(Event* event) { return Status::OK(); }
+
+// AllocateTimer/DeallocateTimer have no specialization.
+bool TpuExecutor::AllocateTimer(Timer* timer) { return true; }
+
+void TpuExecutor::DeallocateTimer(Timer* timer) {}
+
+bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
+  return TpuExecutor_StartTimer(executor_,
+                                stream_map().at(stream->implementation()),
+                                timer_map_.at(timer->implementation()));
+}
+
+bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
+  return TpuExecutor_StopTimer(executor_,
+                               stream_map().at(stream->implementation()),
+                               timer_map_.at(timer->implementation()));
+}
+
+stream_executor::Event::Status TpuExecutor::PollForEventStatus(
+    stream_executor::Event* event) {
+  return stream_executor::Event::Status(TpuExecutor_PollForEventStatus(
+      executor_, event_map_.at(event->implementation())));
+}
+
+Status TpuExecutor::RecordEvent(Stream* stream,
+                                ::stream_executor::Event* event) {
+  StatusHelper status;
+  TpuExecutor_RecordEvent(executor_, stream_map().at(stream->implementation()),
+                          event_map_.at(event->implementation()),
+                          status.c_status);
+  return status.status();
+}
+
+Status TpuExecutor::WaitForEvent(Stream* stream,
+                                 ::stream_executor::Event* event) {
+  StatusHelper status;
+  TpuExecutor_WaitForEvent(executor_, stream_map().at(stream->implementation()),
+                           event_map_.at(event->implementation()),
+                           status.c_status);
+  return status.status();
+}
+
+// Implementations for Timer, Stream, Event
+// We need to map these implementations to internal equivalents -- thus we
+// allocate the internal Timer, Stream and Event operations here, and map
+// the implementations to the internal values. The "wrapper" interfaces are
+// responsible for deallocating the internal value when they are destroyed.
+
+// Called by Timer::Timer
+std::unique_ptr<::stream_executor::internal::TimerInterface>
+TpuExecutor::GetTimerImplementation() {
+  SE_Timer* tpu_timer = TpuTimer_New(executor_);
+  auto ptr = absl::make_unique<TpuTimer>(tpu_timer);
+  timer_map_[ptr.get()] = tpu_timer;
+  return ptr;
+}
+
+// Called by Stream::Stream
+std::unique_ptr<::stream_executor::internal::StreamInterface>
+TpuExecutor::GetStreamImplementation() {
+  SE_Stream* tpu_stream = TpuStream_New(executor_);
+  auto ptr = absl::make_unique<TpuStream>(tpu_stream);
+  stream_map()[ptr.get()] = tpu_stream;
+  return ptr;
+}
+
+// Called by Event::Event
+std::unique_ptr<::stream_executor::internal::EventInterface>
+TpuExecutor::CreateEventImplementation() {
+  SE_Event* tpu_event = TpuEvent_New(executor_);
+  auto ptr = absl::make_unique<TpuEvent>(tpu_event);
+  event_map_[ptr.get()] = tpu_event;
+  return ptr;
+}
+
+DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
+  SE_DeviceMemoryBase se_base =
+      TpuExecutor_Allocate(executor_, size, memory_space);
+  return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
+}
+
+void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
+  SE_DeviceMemoryBase se_base =
+      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
+  TpuExecutor_Deallocate(executor_, &se_base);
+}
+
+void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
+  SE_DeviceMemoryBase se_base =
+      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
+  TpuExecutor_Deallocate(executor_, &se_base);
+}
+
+bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
+  int64_t _free;
+  int64_t _total;
+  if (TpuExecutor_DeviceMemoryUsage(executor_, &_free, &_total)) {
+    *free = _free;
+    *total = _total;
+    return true;
+  }
+  return false;
+}
+
+absl::optional<stream_executor::AllocatorStats>
+TpuExecutor::GetAllocatorStats() {
+  SE_AllocatorStats c_stats;
+  if (TpuExecutor_GetAllocatorStats(executor_, &c_stats)) {
+    ::stream_executor::AllocatorStats stats;
+    stats.num_allocs = c_stats.num_allocs;
+    stats.bytes_in_use = c_stats.bytes_in_use;
+    stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
+    stats.largest_alloc_size = c_stats.largest_alloc_size;
+    if (c_stats.has_bytes_limit) {
+      stats.bytes_limit = c_stats.bytes_limit;
+    }
+    stats.bytes_reserved = c_stats.bytes_reserved;
+    stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
+    if (c_stats.has_bytes_reservable_limit) {
+      stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
+    }
+    stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
+    return stats;
+  }
+  return {};
+}
+
+Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) {
+  StatusHelper status;
+  TpuExecutor_WaitForInfeedReady(executor_, infeed_queue_index,
+                                 status.c_status);
+  return status.status();
+}
+
+Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) {
+  StatusHelper status;
+  TpuExecutor_WaitForOutfeedReady(executor_, outfeed_queue_index,
+                                  status.c_status);
+  return status.status();
+}
+
+void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index,
+                                 absl::Span<uint8> bytes, StatusCallback done) {
+  StatusHelper status;
+  TpuExecutor_DequeueOutfeed(executor_, outfeed_queue_index, bytes.data(),
+                             bytes.size(), status.c_status);
+  done(status.status());
+}
+
+Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index,
+                                  absl::Span<const uint8> bytes) {
+  StatusHelper status;
+  TpuExecutor_EnqueueInfeed(executor_, infeed_queue_index, bytes.data(),
+                            bytes.size(), status.c_status);
+  return status.status();
+}
+
+bool TpuExecutor::Memcpy(Stream* stream, void* host_dst,
+                         const ::stream_executor::DeviceMemoryBase& device_src,
+                         uint64 size) {
+  SE_DeviceMemoryBase se_base =
+      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
+  return TpuExecutor_MemcpyToHost(executor_,
+                                  stream_map().at(stream->implementation()),
+                                  host_dst, &se_base, size);
+}
+
+bool TpuExecutor::Memcpy(Stream* stream,
+                         ::stream_executor::DeviceMemoryBase* device_dst,
+                         const void* host_src, uint64 size) {
+  SE_DeviceMemoryBase se_base =
+      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
+  return TpuExecutor_MemcpyFromHost(executor_,
+                                    stream_map().at(stream->implementation()),
+                                    &se_base, host_src, size);
+}
+
+Status TpuExecutor::SynchronousMemcpy(
+    ::stream_executor::DeviceMemoryBase* device_dst, const void* host_src,
+    uint64 size) {
+  StatusHelper status;
+  SE_DeviceMemoryBase se_base =
+      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
+  TpuExecutor_SynchronousMemcpyFromHost(executor_, &se_base, host_src, size,
+                                        status.c_status);
+  return status.status();
+}
+
+Status TpuExecutor::SynchronousMemcpy(
+    void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
+    uint64 size) {
+  StatusHelper status;
+  SE_DeviceMemoryBase se_base =
+      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
+  TpuExecutor_SynchronousMemcpyToHost(executor_, host_dst, &se_base, size,
+                                      status.c_status);
+  return status.status();
+}
+
+Status TpuExecutor::SynchronousMemcpyDeviceToDevice(
+    ::stream_executor::DeviceMemoryBase* device_dst,
+    const ::stream_executor::DeviceMemoryBase& device_src, uint64 size) {
+  return ::stream_executor::port::UnimplementedError(
+      "This operation not supported on TPU");
+}
+
+bool TpuExecutor::MemcpyDeviceToDevice(
+    Stream* stream, ::stream_executor::DeviceMemoryBase* gpu_dst,
+    const ::stream_executor::DeviceMemoryBase& host_src, uint64 size) {
+  LOG(FATAL) << __func__ << " not supported on TpuExecutor";
+}
+
+struct HostCallbackContext {
+  std::function<Status()> callback;
+};
+
+SE_Status* HostCallbackTrampoline(void* ctx) {
+  HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
+  Status status = host_ctx->callback();
+  SE_Status* c_status =
+      TpuStatus_Create(status.code(), status.error_message().c_str());
+  delete host_ctx;
+  return c_status;
+}
+
+bool TpuExecutor::HostCallback(Stream* stream,
+                               std::function<Status()> callback) {
+  HostCallbackContext* ctx = new HostCallbackContext{callback};
+  return TpuExecutor_HostCallback(executor_,
+                                  stream_map().at(stream->implementation()),
+                                  &HostCallbackTrampoline, ctx);
+}
+
+TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
+TpuExecutor::CreateDeviceDescription() const {
+  StatusHelper status;
+  SE_DeviceDescription* description = TpuDeviceDescription_New();
+  auto cleanup = tensorflow::gtl::MakeCleanup(
+      [description]() { TpuDeviceDescription_Free(description); });
+  TpuExecutor_CreateDeviceDescription(executor_, description, status.c_status);
+  if (status.status().ok()) {
+    stream_executor::internal::DeviceDescriptionBuilder builder;
+    CHECK_NE(description->device_vendor, nullptr);
+    builder.set_device_vendor(description->device_vendor);
+    builder.set_name(description->name);
+    builder.set_clock_rate_ghz(description->clock_rate_ghz);
+    builder.set_core_count(description->core_count);
+    builder.set_ecc_enabled(description->ecc_enabled);
+    builder.set_device_memory_size(description->device_memory_size);
+    builder.set_platform_version(description->platform_version);
+    return builder.Build();
+  }
+  return status.status();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_executor.h b/tensorflow/stream_executor/tpu/tpu_executor.h
new file mode 100644
index 0000000..5f36642
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_executor.h
@@ -0,0 +1,241 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/device_options.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/temporary_device_memory.h"
+#include "tensorflow/stream_executor/timer.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+
+class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
+ public:
+  using Status = ::stream_executor::port::Status;
+  template <typename T>
+  using StatusOr = ::stream_executor::port::StatusOr<T>;
+  using StatusCallback = std::function<void(const Status&)>;
+  using Stream = ::stream_executor::Stream;
+  using Event = ::stream_executor::Event;
+  using Timer = ::stream_executor::Timer;
+  using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase;
+  using StreamInterface = ::stream_executor::internal::StreamInterface;
+  using StreamExecutorInterface =
+      ::stream_executor::internal::StreamExecutorInterface;
+
+  using EventMap =
+      absl::flat_hash_map<stream_executor::internal::EventInterface*,
+                          SE_Event*>;
+  using TimerMap =
+      absl::flat_hash_map<stream_executor::internal::TimerInterface*,
+                          SE_Timer*>;
+
+  explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform,
+                       SE_StreamExecutor* executor)
+      : platform_(platform), executor_(executor) {}
+
+  ~TpuExecutor() override;
+
+  Status Init(int device_ordinal,
+              ::stream_executor::DeviceOptions device_options) override;
+
+  DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override;
+
+  StatusOr<DeviceMemoryBase> AllocateDeviceMemoryBase(uint64 size,
+                                                      int64 memory_space);
+
+  Status AllocateEvent(Event* event) override;
+
+  bool AllocateStream(Stream* stream) override;
+
+  bool AllocateTimer(Timer* timer) override;
+
+  Status BlockHostUntilDone(::stream_executor::Stream* stream) override;
+
+  Status BlockUntilDoneOrFailed();
+
+  StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
+  CreateDeviceDescription() const override;
+
+  bool CreateStreamDependency(Stream* dependent, Stream* other) override;
+
+  void DeallocateStream(Stream* stream) override;
+
+  void Deallocate(const DeviceMemoryBase& memory);
+
+  void Deallocate(DeviceMemoryBase* memory) override;
+
+  Status DeallocateEvent(Event* event) override;
+
+  void DeallocateTimer(Timer* timer) override;
+
+  bool DeviceMemoryUsage(int64* free, int64* total) const override;
+
+  void DequeueOutfeed(int32 outfeed_queue_index, absl::Span<uint8> bytes,
+                      StatusCallback done);
+
+  Status EnqueueInfeed(int32 infeed_queue_index,
+                       absl::Span<const uint8> bytes);
+
+  absl::optional<stream_executor::AllocatorStats> GetAllocatorStats() override;
+
+  Status GetStatus(Stream* stream) override;
+
+  std::unique_ptr<::stream_executor::internal::StreamInterface>
+  GetStreamImplementation() override;
+
+  std::unique_ptr<::stream_executor::internal::TimerInterface>
+  GetTimerImplementation() override;
+
+  std::unique_ptr<::stream_executor::internal::EventInterface>
+  CreateEventImplementation() override;
+
+  bool HostCallback(Stream* stream, std::function<Status()> callback) override;
+
+  bool Memcpy(Stream* stream, void* host_dst,
+              const ::stream_executor::DeviceMemoryBase& device_src,
+              uint64 size) override;
+
+  bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst,
+              const void* host_src, uint64 size) override;
+
+  bool MemcpyDeviceToDevice(Stream* stream,
+                            ::stream_executor::DeviceMemoryBase* gpu_dst,
+                            const ::stream_executor::DeviceMemoryBase& host_src,
+                            uint64 size) override;
+
+  void SyncAndForgetFailedStreams();
+  bool SynchronizeAllActivity() override;
+
+  Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst,
+                           const void* host_src, uint64 size) override;
+  Status SynchronousMemcpy(
+      void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
+      uint64 size) override;
+  Status SynchronousMemcpyDeviceToDevice(
+      ::stream_executor::DeviceMemoryBase* device_dst,
+      const ::stream_executor::DeviceMemoryBase& device_src,
+      uint64 size) override;
+
+  int PlatformDeviceCount() override;
+
+  Event::Status PollForEventStatus(Event* event) override;
+  Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override;
+  Status WaitForEvent(Stream* stream, ::stream_executor::Event* event) override;
+
+  bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override;
+  bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override;
+
+  Status WaitForInfeedReady(int32 infeed_queue_index);
+
+  Status WaitForOutfeedReady(int32 outfeed_queue_index);
+
+  const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
+    return *platform_;
+  }
+
+  ::tensorflow::tpu::TpuPlatformInterface& platform() override {
+    return *platform_;
+  }
+
+  // TODO(henrytan): convert this to override once the base interface is changed
+  // to TpuExecutorInterface.
+  StatusOr<std::unique_ptr<
+      tensorflow::tpu::TpuExecutorInterface::TemporaryDeviceMemory>>
+  CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset,
+                              int64 size) override {
+    LOG(FATAL) << "Unimplemented.";
+  }
+
+  // -- Unimplemented (stubbed out) methods.
+  std::unique_ptr<stream_executor::internal::KernelInterface>
+  CreateKernelImplementation() override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  stream_executor::SharedMemoryConfig GetDeviceSharedMemoryConfig() override {
+    LOG(FATAL) << "not yet implemented";
+  }
+
+  void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
+                     uint64 size) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  Status MemZero(Stream* stream, DeviceMemoryBase* location,
+                 uint64 size) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern,
+                  uint64 size) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  Status SetDeviceSharedMemoryConfig(
+      stream_executor::SharedMemoryConfig config) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  void* HostMemoryAllocate(uint64 size) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  void HostMemoryDeallocate(void* mem) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  bool HostMemoryRegister(void* mem, uint64 size) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  bool HostMemoryUnregister(void* mem) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  Status SynchronousMemZero(DeviceMemoryBase* location, uint64 size) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+  Status SynchronousMemSet(DeviceMemoryBase* location, int value,
+                           uint64 size) override {
+    LOG(FATAL) << "not yet implemented";
+  }
+
+ private:
+  EventMap event_map_;
+  TimerMap timer_map_;
+
+  TpuPlatform::StreamMap& stream_map() {
+    return *(static_cast<TpuPlatform*>(platform_)->stream_map());
+  }
+
+  ::tensorflow::tpu::TpuPlatformInterface* platform_;
+  SE_StreamExecutor* executor_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
new file mode 100644
index 0000000..8bf2ecb
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
@@ -0,0 +1,293 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "tensorflow/c/tf_attrtype.h"
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/c/tf_status.h"
+
+typedef struct SE_Platform SE_Platform;
+typedef struct SE_StreamExecutor SE_StreamExecutor;
+typedef struct SE_Stream SE_Stream;
+typedef struct SE_Event SE_Event;
+typedef struct SE_Timer SE_Timer;
+typedef struct SE_Status SE_Status;
+
+typedef struct SE_PlatformId {
+  void* id;  // aka stream_executor::Platform::Id
+} SE_PlatformId;
+typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig;
+typedef struct SE_DeviceOptions SE_DeviceOptions;
+typedef SE_Status* (*SE_StatusCallbackFn)(void*);
+
+typedef struct SE_DeviceMemoryBase {
+  void* opaque;
+  uint64_t size;
+  uint64_t payload;
+} SE_DeviceMemoryBase;
+
+typedef struct SE_AllocatorStats {
+  int64_t num_allocs;
+  int64_t bytes_in_use;
+  int64_t peak_bytes_in_use;
+  int64_t largest_alloc_size;
+
+  bool has_bytes_limit;
+  int64_t bytes_limit;
+
+  int64_t bytes_reserved;
+  int64_t peak_bytes_reserved;
+
+  bool has_bytes_reservable_limit;
+  int64_t bytes_reservable_limit;
+
+  int64_t largest_free_block_bytes;
+} SE_AllocatorStats;
+
+typedef struct SE_DeviceDescription {
+  char* device_vendor;
+  char* platform_version;
+  char* driver_version;
+  char* runtime_version;
+  char* pci_bus_id;
+  char* name;
+
+  int64_t thread_dim_limit_x;
+  int64_t thread_dim_limit_y;
+  int64_t thread_dim_limit_z;
+  int64_t block_dim_limit_x;
+  int64_t block_dim_limit_y;
+  int64_t block_dim_limit_z;
+
+  int64_t threads_per_core_limit;
+  int64_t threads_per_block_limit;
+  int64_t threads_per_warp;
+
+  int64_t registers_per_core_limit;
+  int64_t registers_per_block_limit;
+
+  int64_t device_address_bits;
+  int64_t device_memory_size;
+  int64_t memory_bandwidth;
+
+  int64_t shared_memory_per_core;
+  int64_t shared_memory_per_block;
+
+  float clock_rate_ghz;
+
+  int cuda_compute_capability_major;
+  int cuda_compute_capability_minor;
+
+  int rocm_amdgpu_isa_version;
+
+  int numa_node;
+  int core_count;
+  bool ecc_enabled;
+} SE_DeviceDescription;
+
+typedef struct XLA_TransferManager XLA_TransferManager;
+
+typedef struct XLA_ComputationPlacer XLA_ComputationPlacer;
+
+// Represents an XLA shape tree.
+// Shapes are flattened in default traversal order.
+typedef struct XLA_Shape {
+  char* bytes;
+  size_t size;
+} XLA_Shape;
+
+// Represents a leaf node for a XLA shaped buffer.
+typedef struct XLA_ShapedBuffer {
+  XLA_Shape on_host_shape;
+  XLA_Shape on_device_shape;
+  int device_ordinal;
+
+  SE_DeviceMemoryBase* bases;
+  size_t count;
+} XLA_ShapedBuffer;
+
+// Represents a leaf XLA literal.
+typedef struct XLA_Literal {
+  char** buffers;
+  size_t* sizes;
+  size_t count;
+  XLA_Shape shape;
+} XLA_Literal;
+
+typedef void (*XLA_CallbackFn)(void*);
+typedef void (*XLA_StatusCallbackFn)(void*, SE_Status*);
+
+extern "C" {
+
+SE_Platform* TpuPlatform_New();
+void TpuPlatform_Free(SE_Platform* platform);
+void TpuPlatform_Initialize(SE_Platform* platform, size_t options_size,
+                            const char** options_key,
+                            const char** options_value, SE_Status* status);
+bool TpuPlatform_Initialized(SE_Platform* platform);
+SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform,
+                                           SE_StreamExecutorConfig* config,
+                                           SE_Status* status);
+SE_PlatformId TpuPlatform_Id(SE_Platform* platform);
+int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform);
+int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform);
+
+void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal,
+                      SE_DeviceOptions* device_options, SE_Status* status);
+void TpuExecutor_Free(SE_StreamExecutor* executor);
+
+int TpuExecutor_PlatformDeviceCount(SE_StreamExecutor* executor);
+
+SE_DeviceMemoryBase TpuExecutor_Allocate(SE_StreamExecutor* executor,
+                                         uint64_t size, int64_t memory_space);
+void TpuExecutor_Deallocate(SE_StreamExecutor* executor,
+                            SE_DeviceMemoryBase* memory);
+bool TpuExecutor_GetAllocatorStats(SE_StreamExecutor* executor,
+                                   SE_AllocatorStats* stats);
+bool TpuExecutor_DeviceMemoryUsage(SE_StreamExecutor* executor, int64_t* free,
+                                   int64_t* total);
+
+bool TpuExecutor_AllocateStream(SE_StreamExecutor* executor, SE_Stream* stream);
+void TpuExecutor_DeallocateStream(SE_StreamExecutor* executor,
+                                  SE_Stream* stream);
+bool TpuExecutor_CreateStreamDependency(SE_StreamExecutor* executor,
+                                        SE_Stream* dependent, SE_Stream* other);
+void TpuExecutor_GetStatus(SE_StreamExecutor* executor, SE_Stream* stream,
+                           SE_Status* status);
+
+void TpuExecutor_AllocateEvent(SE_StreamExecutor* executor, SE_Event* event,
+                               SE_Status* status);
+void TpuExecutor_DeallocateEvent(SE_StreamExecutor* executor, SE_Event* event,
+                                 SE_Status* status);
+int TpuExecutor_PollForEventStatus(SE_StreamExecutor* executor,
+                                   SE_Event* event);
+void TpuExecutor_RecordEvent(SE_StreamExecutor* executor, SE_Stream* stream,
+                             SE_Event* event, SE_Status* status);
+void TpuExecutor_WaitForEvent(SE_StreamExecutor* executor, SE_Stream* stream,
+                              SE_Event* event, SE_Status* status);
+
+bool TpuExecutor_AllocateTimer(SE_StreamExecutor* executor, SE_Timer* timer);
+void TpuExecutor_DeallocateTimer(SE_StreamExecutor* executor, SE_Timer* timer);
+bool TpuExecutor_StartTimer(SE_StreamExecutor* executor, SE_Stream* stream,
+                            SE_Timer* timer);
+bool TpuExecutor_StopTimer(SE_StreamExecutor* executor, SE_Stream* stream,
+                           SE_Timer* timer);
+
+void TpuExecutor_SynchronousMemcpyToHost(SE_StreamExecutor* executor,
+                                         void* host_dst,
+                                         const SE_DeviceMemoryBase* device_src,
+                                         uint64_t size, SE_Status* status);
+void TpuExecutor_SynchronousMemcpyFromHost(SE_StreamExecutor* executor,
+                                           SE_DeviceMemoryBase* device_dst,
+                                           const void* host_src, uint64_t size,
+                                           SE_Status* status);
+bool TpuExecutor_MemcpyToHost(SE_StreamExecutor* executor, SE_Stream* stream,
+                              void* host_dst,
+                              const SE_DeviceMemoryBase* device_src,
+                              uint64_t size);
+
+bool TpuExecutor_MemcpyFromHost(SE_StreamExecutor* executor, SE_Stream* stream,
+                                SE_DeviceMemoryBase* device_dst,
+                                const void* host_src, uint64_t size);
+
+void TpuExecutor_EnqueueInfeed(SE_StreamExecutor* executor,
+                               int32_t infeed_queue_index, const uint8_t* data,
+                               int64_t size, SE_Status* status);
+void TpuExecutor_DequeueOutfeed(SE_StreamExecutor* executor,
+                                int32_t outfeed_queue_index, uint8_t* data,
+                                int64_t size, SE_Status* status);
+void TpuExecutor_WaitForInfeedReady(SE_StreamExecutor* executor,
+                                    int32_t infeed_queue_index,
+                                    SE_Status* status);
+void TpuExecutor_WaitForOutfeedReady(SE_StreamExecutor* executor,
+                                     int32_t outfeed_queue_index,
+                                     SE_Status* status);
+
+void TpuExecutor_BlockHostUntilDone(SE_StreamExecutor* executor,
+                                    SE_Stream* stream, SE_Status* status);
+void TpuExecutor_BlockUntilDoneOrFailed(SE_StreamExecutor* executor,
+                                        SE_Status* status);
+void TpuExecutor_SyncAndForgetFailedStreams(SE_StreamExecutor* executor);
+bool TpuExecutor_SynchronizeAllActivity(SE_StreamExecutor* executor);
+
+SE_Stream* TpuStream_New(SE_StreamExecutor* parent);
+void TpuStream_Free(SE_Stream*);
+void* TpuStream_Stream(SE_Stream*);
+bool TpuStream_Status(SE_Stream*);
+
+SE_Event* TpuEvent_New(SE_StreamExecutor* parent);
+void TpuEvent_Free(SE_Event*);
+
+SE_Timer* TpuTimer_New(SE_StreamExecutor* parent);
+void TpuTimer_Free(SE_Timer*);
+int64_t TpuTimer_Nanoseconds(SE_Timer*);
+int64_t TpuTimer_Microseconds(SE_Timer*);
+
+SE_Status* TpuStatus_New();
+SE_Status* TpuStatus_Create(int32_t code, const char* msg);
+void TpuStatus_Free(SE_Status* status);
+const char* TpuStatus_Message(SE_Status* status);
+int TpuStatus_Code(SE_Status* status);
+bool TpuStatus_Ok(SE_Status* status);
+
+SE_StreamExecutorConfig* TpuStreamExecutorConfig_Default();
+void TpuStreamExecutorConfig_SetOrdinal(SE_StreamExecutorConfig*, int ordinal);
+void TpuStreamExecutorConfig_Free(SE_StreamExecutorConfig*);
+
+SE_DeviceDescription* TpuDeviceDescription_New();
+void TpuDeviceDescription_Free(SE_DeviceDescription* description);
+void TpuExecutor_CreateDeviceDescription(SE_StreamExecutor* executor,
+                                         SE_DeviceDescription* description,
+                                         SE_Status* status);
+
+SE_DeviceOptions* TpuExecutor_NewDeviceOptions(unsigned flags);
+void TpuExecutor_FreeDeviceOptions(SE_DeviceOptions* options);
+
+bool TpuExecutor_HostCallback(SE_StreamExecutor* executor, SE_Stream* stream,
+                              SE_StatusCallbackFn callback_fn, void* ctx);
+
+XLA_TransferManager* TpuTransferManager_New();
+void TpuTransferManager_Free(XLA_TransferManager* manager);
+SE_PlatformId TpuTransferManager_PlatformId(XLA_TransferManager* manager);
+void TpuTransferManager_HostShapeToDeviceShape(XLA_TransferManager* manager,
+                                               XLA_Shape* host_shape,
+                                               XLA_Shape* device_shape);
+void TpuTransferManager_TransferLiteralToDeviceAsync(
+    XLA_TransferManager* manager, SE_Stream* stream, XLA_Literal* literal,
+    XLA_ShapedBuffer* device_buffer, SE_Status* status);
+void TpuTransferManager_TransferLiteralFromDevice(
+    XLA_TransferManager* manager, SE_Stream* stream,
+    XLA_ShapedBuffer* device_buffer, XLA_Literal* literal,
+    XLA_StatusCallbackFn callback, void* ctx);
+
+int64_t TpuTransferManager_GetByteSizeRequirement(XLA_TransferManager* manager,
+                                                  XLA_Shape* shape);
+void TpuTransferManager_WriteSingleTupleIndexTable(
+    XLA_TransferManager* manager, SE_Stream* stream,
+    SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
+    SE_DeviceMemoryBase* region, SE_Status* status);
+
+XLA_ComputationPlacer* TpuComputationPlacer_New();
+void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
+}
+
+// extern "C"
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_executor_interface.h b/tensorflow/stream_executor/tpu/tpu_executor_interface.h
new file mode 100644
index 0000000..5b00f61
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_executor_interface.h
@@ -0,0 +1,64 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
+
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/timer.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tpu {
+class TpuCore;
+}  // namespace tpu
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuExecutorInterface
+    : public ::stream_executor::internal::StreamExecutorInterface {
+ public:
+  using Status = ::stream_executor::port::Status;
+  template <typename T>
+  using StatusOr = ::stream_executor::port::StatusOr<T>;
+
+  class TemporaryDeviceMemory {
+   public:
+    virtual ~TemporaryDeviceMemory() {}
+    virtual stream_executor::DeviceMemoryBase AsDeviceMemoryBase() const = 0;
+  };
+
+  virtual StatusOr<std::unique_ptr<TemporaryDeviceMemory>>
+  CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset,
+                              int64 size) {
+    LOG(FATAL) << "Unimplemented.";
+  }
+
+  virtual const TpuPlatformInterface& platform() const {
+    LOG(FATAL) << "Unimplemented.";
+  }
+
+  virtual TpuPlatformInterface& platform() { LOG(FATAL) << "Unimplemented."; }
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_node_context.cc b/tensorflow/stream_executor/tpu/tpu_node_context.cc
new file mode 100644
index 0000000..2a4954d
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_node_context.cc
@@ -0,0 +1,100 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/stream_executor/device_memory_allocator.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
+
+namespace tensorflow {
+namespace tpu {
+
+using stream_executor::port::Status;
+using stream_executor::port::StatusOr;
+
+/*static*/ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Initialize(
+    int device_ordinal) {
+  StatusHelper status;
+  XLA_TpuNodeContext* node_context =
+      TpuNodeContext_Create(device_ordinal, status.c_status);
+  if (!status.status().ok()) {
+    TpuNodeContext_Free(node_context);
+    return status.status();
+  }
+  return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
+}
+
+TpuNodeContext::~TpuNodeContext() { TpuNodeContext_Free(node_context_); }
+
+/* static */
+Status TpuNodeContext::StopChipHeartbeats() {
+  StatusHelper status;
+  TpuNodeContext_StopChipHeartbeats(status.c_status);
+  return status.status();
+}
+
+/* static */
+Status TpuNodeContext::CloseTpuHost() {
+  StatusHelper status;
+  TpuNodeContext_CloseTpuHost(status.c_status);
+  return status.status();
+}
+
+/* static */
+tensorflow::tpu::TpuPlatformInterface* TpuNodeContext::platform() {
+  return TpuPlatformInterface::GetRegisteredPlatform();
+}
+
+/* static */
+stream_executor::DeviceMemoryAllocator* TpuNodeContext::memory_allocator() {
+  static stream_executor::StreamExecutorMemoryAllocator* memory_allocator =
+      new stream_executor::StreamExecutorMemoryAllocator(
+          platform(),
+          xla::PlatformUtil::GetStreamExecutors(platform()).ValueOrDie());
+  return memory_allocator;
+}
+
+/* static */
+xla::Backend* TpuNodeContext::backend() {
+  static xla::Backend* backend =
+      xla::Backend::CreateBackend(
+          xla::BackendOptions().set_platform(platform()))
+          .ValueOrDie()
+          .release();
+  return backend;
+}
+
+/* static */
+StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
+    int device_ordinal) {
+  return backend()->BorrowStream(device_ordinal);
+}
+
+/* static */
+StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
+    stream_executor::StreamExecutor* executor) {
+  return backend()->BorrowStream(executor);
+}
+
+/* static */
+xla::TransferManager* TpuNodeContext::transfer_manager() {
+  return xla::TransferManager::GetForPlatform(platform()).ValueOrDie();
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_node_context.h b/tensorflow/stream_executor/tpu/tpu_node_context.h
new file mode 100644
index 0000000..e1e1ffc
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_node_context.h
@@ -0,0 +1,89 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_
+
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/stream_executor/device_memory_allocator.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuNodeContext final {
+ public:
+  using Status = stream_executor::port::Status;
+  template <typename T>
+  using StatusOr = stream_executor::port::StatusOr<T>;
+
+  static StatusOr<std::unique_ptr<TpuNodeContext>> Initialize(
+      int device_ordinal);
+
+  explicit TpuNodeContext(int device_ordinal, XLA_TpuNodeContext* node_context)
+      : device_ordinal_(device_ordinal), node_context_(node_context) {
+    CHECK_NE(node_context, nullptr);
+  }
+  ~TpuNodeContext();
+
+  TpuNodeContext(const TpuNodeContext&) = delete;
+  TpuNodeContext& operator=(const TpuNodeContext&) = delete;
+
+  static Status StopChipHeartbeats();
+
+  static Status CloseTpuHost();
+
+  static tensorflow::tpu::TpuPlatformInterface* platform();
+
+  static stream_executor::DeviceMemoryAllocator* memory_allocator();
+
+  static xla::TransferManager* transfer_manager();
+
+  static xla::Backend* backend();
+
+  static StatusOr<xla::StreamPool::Ptr> BorrowStream(int device_ordinal);
+
+  static StatusOr<xla::StreamPool::Ptr> BorrowStream(
+      stream_executor::StreamExecutor* executor);
+
+  stream_executor::StreamExecutor* stream_executor() {
+    LOG(FATAL) << "Not implemented yet.";
+  }
+
+  std::string tensor_core_location() { LOG(FATAL) << "Not implemented yet."; }
+
+  int index_on_host() { LOG(FATAL) << "Not implemented yet."; }
+
+  int device_ordinal() const { return device_ordinal_; }
+
+ private:
+  const int device_ordinal_;
+  XLA_TpuNodeContext* const node_context_;
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_node_context_c_api.h b/tensorflow/stream_executor/tpu/tpu_node_context_c_api.h
new file mode 100644
index 0000000..d2684e4
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_node_context_c_api.h
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
+
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;
+
+XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
+                                          SE_Status* status);
+void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
+
+void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
+void TpuNodeContext_CloseTpuHost(SE_Status* status);
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_platform.cc b/tensorflow/stream_executor/tpu/tpu_platform.cc
new file mode 100644
index 0000000..c44926d
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform.cc
@@ -0,0 +1,125 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+namespace tensorflow {
+
+PLATFORM_DEFINE_ID(TpuPlatform::kId);
+TpuPlatform* tpu_registered_platform = nullptr;
+
+using Status = ::stream_executor::port::Status;
+template <typename T>
+using StatusOr = ::stream_executor::port::StatusOr<T>;
+
+TpuPlatform::TpuPlatform() { platform_ = TpuPlatform_New(); }
+
+TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
+  return tpu_registered_platform;
+}
+
+Status TpuPlatform::Initialize(
+    const std::map<std::string, std::string>& platform_options) {
+  StatusHelper status;
+
+  size_t options_size = platform_options.size();
+  const char** options_key =
+      static_cast<const char**>(malloc(sizeof(const char*) * options_size));
+  const char** options_value =
+      static_cast<const char**>(malloc(sizeof(const char*) * options_size));
+
+  size_t i = 0;
+  for (const auto& option : platform_options) {
+    options_key[i] = option.first.c_str();
+    options_value[i] = option.second.c_str();
+    i++;
+  }
+
+  TpuPlatform_Initialize(platform_, options_size, options_key, options_value,
+                         status.c_status);
+
+  free(options_key);
+  free(options_value);
+
+  return status.status();
+}
+
+TpuPlatform::~TpuPlatform() { TpuPlatform_Free(platform_); }
+
+int TpuPlatform::VisibleDeviceCount() const {
+  return TpuPlatform_VisibleDeviceCount(platform_);
+}
+
+StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
+    const ::stream_executor::StreamExecutorConfig& config) {
+  return executor_cache_.GetOrCreate(
+      config, [&]() { return GetUncachedExecutor(config); });
+}
+
+StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
+TpuPlatform::GetUncachedExecutor(
+    const ::stream_executor::StreamExecutorConfig& config) {
+  SE_StreamExecutorConfig* c_config = TpuStreamExecutorConfig_Default();
+
+  TpuStreamExecutorConfig_SetOrdinal(c_config, config.ordinal);
+
+  StatusHelper status;
+  SE_StreamExecutor* executor =
+      TpuPlatform_GetExecutor(platform_, c_config, status.c_status);
+  TpuStreamExecutorConfig_Free(c_config);
+  if (!status.ok()) {
+    return status.status();
+  }
+  return std::make_unique<stream_executor::StreamExecutor>(
+      this, absl::make_unique<tensorflow::TpuExecutor>(this, executor),
+      config.ordinal);
+}
+
+::stream_executor::Platform::Id TpuPlatform::id() const {
+  return TpuPlatform::kId;
+}
+
+const std::string& TpuPlatform::Name() const {
+  static std::string* name = new std::string(kName);
+  return *name;
+}
+
+int64 TpuPlatform::TpuMemoryLimit() {
+  return TpuPlatform_TpuMemoryLimit(platform_);
+}
+
+}  // namespace tensorflow
+
+void RegisterTpuPlatform() {
+  tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
+  std::unique_ptr<stream_executor::Platform> platform(
+      tensorflow::tpu_registered_platform);
+  SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
+      std::move(platform)));
+}
+
+REGISTER_MODULE_INITIALIZER(tpu_platform, RegisterTpuPlatform());
+
+// Note that module initialization sequencing is not supported in the
+// open-source project, so this will be a no-op there.
+REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
+REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
+                                     tpu_platform);
diff --git a/tensorflow/stream_executor/tpu/tpu_platform.h b/tensorflow/stream_executor/tpu/tpu_platform.h
new file mode 100644
index 0000000..9a67045
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform.h
@@ -0,0 +1,121 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
+
+#include <memory>
+
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/executor_cache.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+
+class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
+ public:
+  using StreamMap =
+      absl::flat_hash_map<stream_executor::internal::StreamInterface*,
+                          SE_Stream*>;
+
+  static const ::stream_executor::Platform::Id kId;
+  static constexpr char kName[] = "TPU";
+
+  using Status = ::stream_executor::port::Status;
+  template <typename T>
+  using StatusOr = ::stream_executor::port::StatusOr<T>;
+
+  TpuPlatform();
+
+  ~TpuPlatform() override;
+
+  static TpuPlatform* GetRegisteredPlatform();
+
+  Id id() const override;
+
+  const std::string& Name() const override;
+
+  int VisibleDeviceCount() const override;
+
+  int64 TpuMemoryLimit() override;
+
+  bool Initialized() const override {
+    return TpuPlatform_Initialized(platform_);
+  }
+
+  Status Initialize(
+      const std::map<std::string, std::string>& platform_options) override;
+
+  Status Reset() override { return Reset(false); }
+
+  Status Reset(bool only_tear_down) override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
+  DescriptionForDevice(int ordinal) const override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  StatusOr<::stream_executor::StreamExecutor*> ExecutorForDevice(
+      int ordinal) override {
+    stream_executor::StreamExecutorConfig config;
+    config.ordinal = ordinal;
+    return GetExecutor(config);
+  }
+
+  StatusOr<::stream_executor::StreamExecutor*>
+  ExecutorForDeviceWithPluginConfig(
+      int ordinal,
+      const ::stream_executor::PluginConfig& plugin_config) override {
+    stream_executor::StreamExecutorConfig config;
+    config.ordinal = ordinal;
+    config.plugin_config = plugin_config;
+    return GetExecutor(config);
+  }
+
+  StatusOr<::stream_executor::StreamExecutor*> GetExecutor(
+      const ::stream_executor::StreamExecutorConfig& config) override;
+
+  StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
+  GetUncachedExecutor(
+      const ::stream_executor::StreamExecutorConfig& config) override;
+
+  void RegisterTraceListener(
+      std::unique_ptr<stream_executor::TraceListener> listener) override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  void UnregisterTraceListener(
+      stream_executor::TraceListener* listener) override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  StreamMap* stream_map() { return &stream_map_; }
+
+ private:
+  SE_Platform* platform_;
+
+  stream_executor::ExecutorCache executor_cache_;
+  StreamMap stream_map_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
new file mode 100644
index 0000000..c5b8ece
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+
+namespace tensorflow {
+namespace tpu {
+
+/* static */
+TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
+  // Prefer TpuPlatform if it's registered.
+  auto status_or_tpu_platform =
+      stream_executor::MultiPlatformManager::PlatformWithName("TPU");
+  if (status_or_tpu_platform.ok()) {
+    return static_cast<TpuPlatformInterface*>(
+        status_or_tpu_platform.ValueOrDie());
+  }
+  if (status_or_tpu_platform.status().code() != error::NOT_FOUND) {
+    LOG(WARNING) << "Error when getting the TPU platform: "
+                 << status_or_tpu_platform.status();
+    return nullptr;
+  }
+
+  // Use any other registered TPU platform.
+  auto status_or_other_tpu_platforms =
+      stream_executor::MultiPlatformManager::PlatformsWithFilter(
+          [](const stream_executor::Platform* platform) {
+            return dynamic_cast<const TpuPlatformInterface*>(platform) !=
+                   nullptr;
+          });
+  if (!status_or_other_tpu_platforms.ok()) {
+    LOG(WARNING) << "Error when getting other TPU platforms: "
+                 << status_or_tpu_platform.status();
+    return nullptr;
+  }
+  auto other_tpu_platforms = status_or_other_tpu_platforms.ValueOrDie();
+  if (!other_tpu_platforms.empty()) {
+    LOG(WARNING) << other_tpu_platforms.size()
+                 << " TPU platforms registered, selecting "
+                 << other_tpu_platforms[0]->Name();
+    return static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
+  }
+
+  LOG(WARNING) << "No TPU platform registered";
+  return nullptr;
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.h b/tensorflow/stream_executor/tpu/tpu_platform_interface.h
new file mode 100644
index 0000000..5c7aa8e
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.h
@@ -0,0 +1,44 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/platform.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuPlatformInterface : public stream_executor::Platform {
+ public:
+  using Status = stream_executor::port::Status;
+
+  // Returns a TPU platform to be used by TPU ops. If multiple TPU platforms are
+  // registered, finds the most suitable one. Returns nullptr if no TPU platform
+  // is registered or an error occurred.
+  static TpuPlatformInterface* GetRegisteredPlatform();
+
+  virtual Status Reset() { return Reset(false); }
+
+  virtual Status Reset(bool only_tear_down) = 0;
+
+  virtual int64 TpuMemoryLimit() = 0;
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_stream.h b/tensorflow/stream_executor/tpu/tpu_stream.h
new file mode 100644
index 0000000..b8fd10d
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_stream.h
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
+
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+class TpuStream : public stream_executor::internal::StreamInterface {
+ public:
+  explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
+  ~TpuStream() override { TpuStream_Free(stream_); }
+
+ private:
+  SE_Stream* stream_;
+};
+
+class TpuEvent : public ::stream_executor::internal::EventInterface {
+ public:
+  explicit TpuEvent(SE_Event* event) : event_(event) {}
+  ~TpuEvent() override { TpuEvent_Free(event_); }
+
+ private:
+  SE_Event* event_;
+};
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_stream_interface.h b/tensorflow/stream_executor/tpu/tpu_stream_interface.h
new file mode 100644
index 0000000..2e5e02d
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_stream_interface.h
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
+
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuStreamInterface : public ::stream_executor::internal::StreamInterface {
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_timer.h b/tensorflow/stream_executor/tpu/tpu_timer.h
new file mode 100644
index 0000000..246a0b7
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_timer.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+namespace tensorflow {
+
+class TpuTimer : public ::stream_executor::internal::TimerInterface {
+ public:
+  explicit TpuTimer(SE_Timer* timer) : timer_(timer) {}
+  ~TpuTimer() override { TpuTimer_Free(timer_); }
+  uint64 Microseconds() const override { return TpuTimer_Microseconds(timer_); }
+  uint64 Nanoseconds() const override { return TpuTimer_Nanoseconds(timer_); }
+
+ private:
+  SE_Timer* timer_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc b/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc
new file mode 100644
index 0000000..473585a
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc
@@ -0,0 +1,167 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+
+namespace tensorflow {
+
+using Status = stream_executor::port::Status;
+template <typename T>
+using StatusOr = stream_executor::port::StatusOr<T>;
+
+TpuTransferManager::TpuTransferManager() {
+  manager_ = TpuTransferManager_New();
+}
+
+TpuTransferManager::~TpuTransferManager() { TpuTransferManager_Free(manager_); }
+
+stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
+  return TpuPlatform::kId;
+}
+
+xla::Shape TpuTransferManager::HostShapeToDeviceShape(
+    const xla::Shape& host_shape) const {
+  XLA_Shape c_host_shape;
+  XLA_Shape c_device_shape;
+
+  TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
+
+  TpuTransferManager_HostShapeToDeviceShape(manager_, &c_host_shape,
+                                            &c_device_shape);
+  xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
+  TpuConversions::CShapeCleanup(&c_host_shape);
+  TpuConversions::CShapeCleanup(&c_device_shape);
+  return device_shape;
+}
+
+Status TpuTransferManager::TransferLiteralToDeviceAsync(
+    stream_executor::Stream* stream, const xla::LiteralSlice& literal,
+    const xla::ShapedBuffer& device_buffer,
+    const TransferMetadata* transfer_metadata) {
+  StatusHelper status;
+
+  XLA_Literal c_literal;
+  TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
+
+  XLA_ShapedBuffer c_device_buffer;
+  TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
+                                                 &c_device_buffer);
+
+  TpuTransferManager_TransferLiteralToDeviceAsync(
+      manager_,
+      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
+          stream->implementation()),
+      &c_literal, &c_device_buffer, status.c_status);
+  TpuConversions::CShapedBufferCleanup(&c_device_buffer);
+  TpuConversions::CLiteralCleanup(&c_literal);
+  return status.status();
+}
+
+struct TransferFromDeviceState {
+  std::atomic<int64_t> remaining_transfers;
+  StatusHelper status_helper;
+  std::function<void(Status)> done;
+
+  void TransferFinished(SE_Status* status) {
+    if (!TpuStatus_Ok(status) && TpuStatus_Ok(status_helper.c_status)) {
+      status_helper.c_status = status;
+    } else {
+      TpuStatus_Free(status);
+    }
+
+    if (--remaining_transfers == 0) {
+      done(status_helper.status());
+      delete this;
+    }
+  }
+};
+
+void TransferLiteralFromDeviceTrampoline(void* ctx, SE_Status* status) {
+  reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
+}
+
+void TpuTransferManager::TransferLiteralFromDevice(
+    stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
+    xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
+    const TransferMetadata* transfer_metadata) {
+  TransferFromDeviceState* state = new TransferFromDeviceState;
+  state->remaining_transfers = 1;
+  state->done = done;
+  XLA_ShapedBuffer c_device_buffer;
+  TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
+                                                 &c_device_buffer);
+  XLA_Literal c_literal;
+  TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
+
+  TpuTransferManager_TransferLiteralFromDevice(
+      manager_,
+      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
+          stream->implementation()),
+      &c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state);
+  TpuConversions::CShapedBufferCleanup(&c_device_buffer);
+  TpuConversions::CLiteralCleanup(&c_literal);
+}
+
+int64 TpuTransferManager::GetByteSizeRequirement(
+    const xla::Shape& shape) const {
+  XLA_Shape c_shape;
+  TpuConversions::XlaShapeToCShape(shape, &c_shape);
+
+  int64 size_in_bytes =
+      TpuTransferManager_GetByteSizeRequirement(manager_, &c_shape);
+
+  TpuConversions::CShapeCleanup(&c_shape);
+  return size_in_bytes;
+}
+
+Status TpuTransferManager::WriteSingleTupleIndexTable(
+    stream_executor::Stream* stream,
+    absl::Span<const stream_executor::DeviceMemoryBase> elements,
+    const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) {
+  CHECK_GT(elements.size(), 0);
+  SE_DeviceMemoryBase* elements_bases =
+      new SE_DeviceMemoryBase[elements.size()];
+  for (int i = 0; i < elements.size(); i++) {
+    elements_bases[i] =
+        SE_DeviceMemoryBase{const_cast<void*>(elements[i].opaque()),
+                            elements[i].size(), elements[i].payload()};
+  }
+  XLA_Shape c_shape;
+  TpuConversions::XlaShapeToCShape(shape, &c_shape);
+  SE_DeviceMemoryBase region_base{region->opaque(), region->size(),
+                                  region->payload()};
+  StatusHelper status;
+
+  TpuTransferManager_WriteSingleTupleIndexTable(
+      manager_,
+      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
+          stream->implementation()),
+      elements_bases, elements.size(), &c_shape, &region_base, status.c_status);
+
+  delete[] elements_bases;
+  TpuConversions::CShapeCleanup(&c_shape);
+  return status.status();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager.h b/tensorflow/stream_executor/tpu/tpu_transfer_manager.h
new file mode 100644
index 0000000..163ac81
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager.h
@@ -0,0 +1,83 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+namespace tensorflow {
+
+class TpuTransferManager : public xla::TransferManager {
+ public:
+  TpuTransferManager();
+  ~TpuTransferManager() override;
+
+  using Status = stream_executor::port::Status;
+  template <typename T>
+  using StatusOr = stream_executor::port::StatusOr<T>;
+
+  stream_executor::Platform::Id PlatformId() const override;
+
+  xla::Shape HostShapeToDeviceShape(
+      const xla::Shape& host_shape) const override;
+
+  Status TransferLiteralToDeviceAsync(
+      stream_executor::Stream* stream, const xla::LiteralSlice& literal,
+      const xla::ShapedBuffer& device_buffer,
+      const TransferMetadata* transfer_metadata) override;
+
+  void TransferLiteralFromDevice(
+      stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
+      xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
+      const TransferMetadata* transfer_metadata) override;
+
+  Status TransferLiteralToInfeed(stream_executor::StreamExecutor* executor,
+                                 const xla::LiteralSlice& literal) override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  Status TransferLiteralFromOutfeed(
+      stream_executor::StreamExecutor* executor,
+      const xla::Shape& literal_shape,
+      xla::MutableBorrowingLiteral literal) override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  Status ResetDevices(
+      absl::Span<stream_executor::StreamExecutor* const> executor) override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  int64 GetByteSizeRequirement(const xla::Shape& shape) const override;
+
+  Status WriteSingleTupleIndexTable(
+      stream_executor::Stream* stream,
+      absl::Span<const stream_executor::DeviceMemoryBase> elements,
+      const xla::Shape& shape,
+      stream_executor::DeviceMemoryBase* region) override;
+
+ private:
+  XLA_TransferManager* manager_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc b/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc
new file mode 100644
index 0000000..f7f0c6f
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
+
+namespace tensorflow {
+
+static std::unique_ptr<xla::TransferManager> CreateTpuTransferManager() {
+  return std::make_unique<TpuTransferManager>();
+}
+
+static bool InitModule() {
+  xla::TransferManager::RegisterTransferManager(TpuPlatform::kId,
+                                                CreateTpuTransferManager);
+  return true;
+}
+static bool module_initialized = InitModule();
+
+}  // namespace tensorflow