Add a TPU execution op.

PiperOrigin-RevId: 321844765
Change-Id: I3bfb52fe00f7d378a26e3247beee7daa5ba6d38b
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index af7c9ea..7a6160a 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -553,44 +553,3 @@
     ],
     alwayslink = 1,
 )
-
-cc_library(
-    name = "tpu_execute_op",
-    srcs = ["tpu_execute_op.cc"],
-    hdrs = ["tpu_execute_op.h"],
-    deps = [
-        ":tpu_compilation_cache_entry",
-        ":tpu_compilation_cache_external",
-        ":tpu_compilation_cache_local_lookup",
-        ":tpu_compilation_cache_lookup",
-        ":tpu_executable_info_proto_cc",
-        ":tpu_op_consts",
-        "//tensorflow/compiler/jit:xla_device",
-        "//tensorflow/compiler/jit:xla_launch_util",
-        "//tensorflow/compiler/jit:xla_tensor",
-        "//tensorflow/compiler/tf2xla:common",
-        "//tensorflow/compiler/tf2xla:tf2xla_util",
-        "//tensorflow/compiler/xla:debug_options_flags",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:xla_data_proto_cc",
-        "//tensorflow/compiler/xla/service:dump",
-        "//tensorflow/compiler/xla/service:executable",
-        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:framework_internal",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:stream_executor_no_cuda",
-        "//tensorflow/core/profiler/lib:traceme",
-        "//tensorflow/core/tpu:tpu_configuration",
-        "//tensorflow/core/tpu:tpu_defs",
-        "//tensorflow/core/tpu:tpu_execute",
-        "//tensorflow/stream_executor:device_memory_allocator",
-        "//tensorflow/stream_executor/tpu:tpu_node_context",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/types:span",
-    ],
-    alwayslink = 1,
-)
diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
deleted file mode 100644
index 817649e..0000000
--- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc
+++ /dev/null
@@ -1,805 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/tpu/kernels/tpu_execute_op.h"
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/memory/memory.h"
-#include "absl/types/span.h"
-#include "tensorflow/compiler/jit/xla_device.h"
-#include "tensorflow/compiler/jit/xla_launch_util.h"
-#include "tensorflow/compiler/jit/xla_tensor.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-#include "tensorflow/compiler/xla/debug_options_flags.h"
-#include "tensorflow/compiler/xla/service/dump.h"
-#include "tensorflow/compiler/xla/service/executable.h"
-#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/resource_var.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/core/platform/tracing.h"
-#include "tensorflow/core/profiler/lib/traceme.h"
-#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
-#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
-#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
-#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
-#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
-#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
-#include "tensorflow/core/tpu/tpu_configuration.h"
-#include "tensorflow/core/tpu/tpu_defs.h"
-#include "tensorflow/core/tpu/tpu_execute.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-#include "tensorflow/stream_executor/device_memory_allocator.h"
-#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
-
-namespace tensorflow {
-
-namespace {
-
-using ::tensorflow::tpu::TpuNodeContext;
-using CompilationCacheEntryRef = ::tensorflow::tpu::CompilationCacheEntryRef<
-    ::tensorflow::tpu::TpuCompilationCacheEntry>;
-using TpuCompilationCacheLookup =
-    ::tensorflow::tpu::TpuCompilationCacheLookup<CompilationCacheEntryRef>;
-
-// Looks up the input `key` in the compilation cache, populating
-// `*rendezvous_key_base` and `*entry`.
-Status GetComputationCacheEntry(
-    OpKernelContext* context, string* rendezvous_key_base,
-    std::unique_ptr<CompilationCacheEntryRef>* entry) {
-  const Tensor* key;
-  TF_RETURN_IF_ERROR(context->input("key", &key));
-  profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
-  if (!TensorShapeUtils::IsVector(key->shape()) ||
-      key->shape().dim_size(0) != 2) {
-    return errors::InvalidArgument(
-        "Key argument to TPUExecute must be a 2-element vector");
-  }
-
-  ResourceMgr* rmgr = GetTPUConfigResourceMgr();
-  TpuCompilationCacheLookup* proto_lookup;
-  TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
-                                  tpu::kCompiledProtoCacheResourceName,
-                                  &proto_lookup));
-  core::ScopedUnref lookup_unref(proto_lookup);
-  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key->vec<tstring>()(0), entry));
-  *rendezvous_key_base = key->vec<tstring>()(1);
-  return Status::OK();
-}
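
For context, a minimal sketch of how a caller could construct the 2-element key vector this lookup expects, assuming the surrounding TensorFlow build environment; `MakeTpuExecuteKey` and its arguments are hypothetical names for illustration, not part of this change:

```cpp
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

// Hypothetical helper: builds the 2-element string vector consumed by
// GetComputationCacheEntry. Element 0 is the compilation cache key and
// element 1 is the rendezvous key base, matching the reads above.
Tensor MakeTpuExecuteKey(const tstring& cache_key,
                         const tstring& rendezvous_key_base) {
  Tensor key(DT_STRING, TensorShape({2}));
  key.vec<tstring>()(0) = cache_key;
  key.vec<tstring>()(1) = rendezvous_key_base;
  return key;
}

}  // namespace tensorflow
```
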
-
-struct VariableUpdateMap {
-  // Maps input index to the updated output index. If the variable doesn't have
-  // an updated output, the corresponding output is set to -1.
-  absl::flat_hash_map<int, int> input_to_output;
-  // Maps output index to (the input index, whether the update is generated from
-  // compilation).
-  absl::flat_hash_map<int, std::pair<int, bool>> output_to_input;
-  // Part of the input indices that are from the compilation, in the compiled
-  // order.
-  std::vector<int> input_in_compiled_update_order;
-};
-
-// Creates a VariableUpdateMap from both the compilation and the fused variable
-// reads/updates.
-xla::StatusOr<VariableUpdateMap> BuildVariableUpdateMap(
-    absl::Span<const TPUExecutableInfoProto::UpdateIndexPair* const>
-        compiled_variable_updates,
-    absl::Span<int const> fused_device_var_reads_in_computation_inputs,
-    const std::vector<int>& fused_device_var_updates_in_computation_outputs,
-    int64 computation_output_count) {
-  VariableUpdateMap map;
-  auto add_pair = [&](int input, int output, bool from_compilation) -> Status {
-    TF_RET_CHECK(map.input_to_output.emplace(input, output).second)
-        << "Duplicate variable input index: " << input;
-    if (output >= 0) {
-      TF_RET_CHECK(map.output_to_input
-                       .emplace(output, std::pair{input, from_compilation})
-                       .second)
-          << "Duplicate variable output index: " << output;
-    }
-    return Status::OK();
-  };
-
-  // First add the updates produced by the compilation. Not all variables are
-  // updated; those that are not have no output in the XLA computation. The
-  // update output indices in the XLA computation start after the non-variable
-  // outputs.
-  int num_updated_variables = 0;
-  for (int i = 0; i < compiled_variable_updates.size(); ++i) {
-    const bool updated = compiled_variable_updates[i]->updated();
-    if (updated) ++num_updated_variables;
-  }
-  TF_RET_CHECK(num_updated_variables <= computation_output_count)
-      << num_updated_variables << " <= " << computation_output_count;
-  int64 compiled_variable_output_index =
-      computation_output_count - num_updated_variables;
-  for (auto update : compiled_variable_updates) {
-    map.input_in_compiled_update_order.push_back(update->index());
-    if (!update->updated()) {
-      TF_RETURN_IF_ERROR(add_pair(update->index(), -1, true));
-      continue;
-    }
-    TF_RETURN_IF_ERROR(
-        add_pair(update->index(), compiled_variable_output_index, true));
-    ++compiled_variable_output_index;
-  }
-
-  // Now add the updates from the attributes.
-  TF_RET_CHECK(fused_device_var_reads_in_computation_inputs.size() ==
-               fused_device_var_updates_in_computation_outputs.size());
-  for (int64 i = 0; i < fused_device_var_reads_in_computation_inputs.size();
-       ++i) {
-    TF_RETURN_IF_ERROR(
-        add_pair(fused_device_var_reads_in_computation_inputs[i],
-                 fused_device_var_updates_in_computation_outputs[i], false));
-  }
-  return map;
-}
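
The output-index arithmetic above is easy to get wrong, so here is a standalone sketch of the same bookkeeping using only the STL (toy values; the real function uses absl containers and TF status macros). With 5 computation outputs and compiled updates touching inputs 1 and 4 but not 2, the updated variables occupy the last two output slots:

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

int main() {
  struct Update { int index; bool updated; };
  const std::vector<Update> updates = {{1, true}, {2, false}, {4, true}};
  const int output_count = 5;

  int num_updated = 0;
  for (const Update& u : updates) num_updated += u.updated ? 1 : 0;

  std::unordered_map<int, int> input_to_output;
  int next_output = output_count - num_updated;  // first variable output slot
  for (const Update& u : updates) {
    input_to_output[u.index] = u.updated ? next_output++ : -1;
  }

  assert(input_to_output[1] == 3);
  assert(input_to_output[2] == -1);  // not updated: no XLA output
  assert(input_to_output[4] == 4);
  return 0;
}
```
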
-
-// Buffers representing the inputs to a computation.
-struct InputBuffers {
-  explicit InputBuffers(xla::Shape device_shape)
-      : buffers(std::move(device_shape)) {}
-
-  InputBuffers(const InputBuffers&) = delete;
-  InputBuffers& operator=(const InputBuffers&) = delete;
-
-  ~InputBuffers() = default;
-
-  xla::ShapedBuffer ToShapedBuffer(xla::Shape host_shape,
-                                   se::DeviceMemoryAllocator* allocator,
-                                   int device_ordinal) {
-    CHECK_NE(allocator, nullptr);
-    xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
-                                    allocator->platform(), device_ordinal);
-    shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
-        [](xla::MaybeOwningDeviceMemory* buffer) {
-          CHECK(buffer);
-          return buffer->AsDeviceMemoryBase();
-        }));
-    return shaped_buffer;
-  }
-
-  // Describes the buffer tree.
-  xla::ShapeTree<xla::MaybeOwningDeviceMemory> buffers;
-
-  // Information about resource variables passed directly to TPUExecute.
-  std::vector<VariableInfo> variables;
-
-  // Mapping from input index to offsets in 'variables'; entries are < 0 if
-  // the input does not correspond to a variable in 'variables'.
-  std::vector<int> variable_index;
-};
-
-// Builds an InputBuffers object that describes the inputs to the computation.
-xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
-    OpKernelContext* context, const xla::Shape& input_host_shape,
-    const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
-    se::Stream* stream) {
-  profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
-  OpInputList arg_list;
-  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));
-
-  if (arg_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
-    return errors::InvalidArgument(
-        "Number of parameters (", arg_list.size(),
-        ") does not match input shape: ",
-        xla::ShapeUtil::TupleElementCount(input_host_shape));
-  }
-
-  auto validate_shape = [&](int i, const Tensor& tensor) {
-    const xla::Shape& expected =
-        xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
-    VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
-    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);
-
-    if (xla_tensor == nullptr) {
-      // FromTensor failed; tensor must be empty.
-      if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
-        return errors::InvalidArgument(
-            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
-            context->op_kernel().requested_input(i), "). Expected ",
-            expected.DebugString(), "; got empty tensor");
-      }
-    } else {
-      // Compare host shapes, easier than getting the expected device shape.
-      const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
-      if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
-        return errors::InvalidArgument(
-            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
-            context->op_kernel().requested_input(i), "). Expected ",
-            expected.DebugString(), "; got ", xla_shape.DebugString());
-      }
-    }
-
-    return Status::OK();
-  };
-
-  // Iterate over the inputs, validating the shapes of non-variable inputs,
-  // and creating a VariableInfo object for each variable. We consider variable
-  // inputs in a separate phase because we must acquire variable locks in order.
-  std::vector<VariableInfo> variables;
-  std::vector<int> variable_index(arg_list.size(), -1);
-  variables.reserve(arg_list.size());
-  for (int i = 0; i < arg_list.size(); ++i) {
-    // Arguments are assumed to be variables if they have a resource type.
-    // (Non-variable resources are not supported.)
-    if (context->input_dtype(i) == DT_RESOURCE) {
-      variable_index[i] = variables.size();
-      // TODO(phawkins): we may be looking up many variables here; it would be
-      // better if we did not repeatedly acquire the resource manager's lock.
-      const ResourceHandle& handle = HandleFromInput(context, i);
-      Var* variable;
-      TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
-      variables.push_back(VariableInfo(i, handle.name(), variable));
-    } else {
-      TF_RETURN_IF_ERROR(validate_shape(i, arg_list[i]));
-    }
-  }
-
-  // Lock the variables, and validate their shapes. We hold the variable locks
-  // for the duration of the TPU execution so we can donate the variable buffers
-  // to the computation. If we copied the variable's Tensor instead, its
-  // reference count would be greater than one due to the reference the Var
-  // object holds, and we would never be able to reuse variable buffers.
-  // TODO(phawkins): add a 'reuse_buffers' attribute to TPUExecute that allows
-  // the user to elect to copy the buffers and permit concurrent access instead.
-  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));
-  for (int i = 0; i < variables.size(); ++i) {
-    TF_RETURN_IF_ERROR(
-        validate_shape(variables[i].index(), *variables[i].var()->tensor()));
-  }
-
-  se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
-  xla::TransferManager* const transfer_manager =
-      node_context->transfer_manager();
-  const int device_ordinal = node_context->device_ordinal();
-
-  auto input_buffers = absl::make_unique<InputBuffers>(
-      transfer_manager->HostShapeToDeviceShape(input_host_shape));
-
-  // Allocates a buffer for the root tuple.
-  const int64 root_size =
-      transfer_manager->GetByteSizeRequirement(input_buffers->buffers.shape());
-  TF_ASSIGN_OR_RETURN(*input_buffers->buffers.mutable_element({}),
-                      allocator->Allocate(device_ordinal, root_size));
-
-  // Helper function that sets the input buffers for 'arg_index' to 'buffers'.
-  // If 'donate_buffers' is true, donates ownership of the buffers in 'buffers'
-  // to the computation and overwrites the entries in 'buffers' with nulls.
-  auto set_input_buffers_helper = [&](int arg_index, bool donate_buffers,
-                                      xla::ShapedBuffer* buffers) {
-    buffers->buffers().ForEachMutableElement([&](const xla::ShapeIndex& index,
-                                                 se::DeviceMemoryBase* buffer) {
-      xla::ShapeIndex in_index = {arg_index};
-      for (int64 j : index) {
-        in_index.push_back(j);
-      }
-      auto* in_buffer = input_buffers->buffers.mutable_element(in_index);
-      if (donate_buffers) {
-        *in_buffer = se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
-        *buffer = se::DeviceMemoryBase();
-      } else {
-        *in_buffer = *buffer;
-      }
-    });
-  };
-
-  // Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
-  // buffers for zero-element tensors where required.
-  auto assign_input = [&](int i, const Tensor& tensor,
-                          bool may_reuse) -> xla::Status {
-    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);
-
-    // Size 0 tensors have no backing XlaTensor, but may still need to have
-    // tuple buffers allocated.
-    if (xla_tensor == nullptr) {
-      CHECK_EQ(tensor.NumElements(), 0);
-      const xla::Shape& host_shape =
-          xla::ShapeUtil::GetSubshape(input_host_shape, {i});
-      TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
-                          transfer_manager->AllocateScopedShapedBuffer(
-                              host_shape, allocator, device_ordinal));
-      set_input_buffers_helper(/*arg_index=*/i, /*donate_buffers=*/true,
-                               &buffers);
-    } else {
-      bool can_reuse_buffers = tensor.RefCountIsOne() && may_reuse;
-      set_input_buffers_helper(/*arg_index=*/i,
-                               /*donate_buffers=*/can_reuse_buffers,
-                               &xla_tensor->shaped_buffer());
-      xla_tensor->WaitForDefinitionEventOnStream(stream);
-    }
-    return Status::OK();
-  };
-
-  for (int i = 0; i < arg_list.size(); ++i) {
-    auto it = variable_updates.input_to_output.find(i);
-    if (it == variable_updates.input_to_output.end()) {
-      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], /*may_reuse=*/true));
-      continue;
-    }
-    // input i is a variable
-    bool updated = it->second >= 0;
-    if (arg_list[i].dtype() != DT_RESOURCE) {
-      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], updated));
-    } else {
-      int vi = variable_index[i];
-      TF_RETURN_IF_ERROR(
-          assign_input(i, *variables[vi].var()->tensor(), updated));
-    }
-  }
-
-  input_buffers->variables = std::move(variables);
-  input_buffers->variable_index = std::move(variable_index);
-
-  return std::move(input_buffers);
-}
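
A small standalone sketch of the `ShapeIndex` remapping performed by `set_input_buffers_helper`, with `ShapeIndex` modeled as a plain vector of ints (an illustration, not the real XLA type):

```cpp
#include <cassert>
#include <vector>

// A buffer at `index` inside argument `arg_index` lands at
// {arg_index, index...} in the input tuple, which is exactly the prefixing
// done in the lambda above.
std::vector<int> ToInputIndex(int arg_index, const std::vector<int>& index) {
  std::vector<int> in_index = {arg_index};
  in_index.insert(in_index.end(), index.begin(), index.end());
  return in_index;
}

int main() {
  // The buffer at index {0, 2} of argument 3 maps to tuple index {3, 0, 2}.
  assert((ToInputIndex(3, {0, 2}) == std::vector<int>{3, 0, 2}));
  return 0;
}
```
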
-
-struct OutputBuffers {
-  OutputBuffers(xla::ScopedShapedBuffer b, se::DeviceMemoryAllocator* allocator)
-      : owned_buffers(b.on_device_shape(), true),
-        buffers(b.release()),
-        memory_allocator(allocator) {}
-
-  ~OutputBuffers() {
-    buffers.buffers().ForEachElement([&](const xla::ShapeIndex& index,
-                                         const se::DeviceMemoryBase& buffer) {
-      if (owned_buffers.element(index) && !buffer.is_null()) {
-        Status status =
-            memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
-        LOG_IF(ERROR, !status.ok()) << "Error deallocating buffer " << status;
-      }
-    });
-  }
-
-  // Which of the buffers do we own?
-  xla::ShapeTree<bool> owned_buffers;
-
-  xla::ShapedBuffer buffers;
-
-  se::DeviceMemoryAllocator* const memory_allocator;
-};
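
A toy model of the ownership bookkeeping above, again modeling `ShapeIndex` as a vector of ints: `transfer_buffers` (defined further below) clears the owned flag when an output Tensor takes over a buffer, and the destructor deallocates only what is still marked owned:

```cpp
#include <iostream>
#include <map>
#include <vector>

// Toy model only: buffers are map entries keyed by shape index. Entries
// still marked owned are "deallocated" by the cleanup pass, mirroring the
// destructor above.
int main() {
  using ShapeIndex = std::vector<int>;
  std::map<ShapeIndex, bool> owned = {{{0}, true}, {{1}, true}};
  owned[{1}] = false;  // ownership handed off to an output Tensor
  for (const auto& entry : owned) {
    if (entry.second)
      std::cout << "deallocate buffer at {" << entry.first[0] << "}\n";
  }
  return 0;
}
```
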
-
-// Allocates Tensors for the outputs of the computation. Ownership of most
-// output buffers is passed to the output Tensors. Returns an OutputBuffers
-// object that owns the root buffer that should be passed to the XLA
-// computation, as well as any output buffers that do not have corresponding
-// output tensors. The latter
-// may happen for zero-element tensors of type int64 or complex64 which still
-// require a tuple buffer but do not have a corresponding XlaTensor.
-xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
-    OpKernelContext* context, xla::ScopedShapedBuffer scoped_buffers,
-    absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
-    const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
-    se::Stream* stream, int device_ordinal, InputBuffers* input_buffers,
-    const std::shared_ptr<se::Event>& definition_event) {
-  VLOG(4) << "Output buffers: " << scoped_buffers.ToString();
-
-  profiler::TraceMe trace_me("AllocateOutputTensors", /*level=*/2);
-  // Shapes of the outputs, in TensorShape form.
-  const int64 sub_elements =
-      xla::ShapeUtil::TupleElementCount(scoped_buffers.on_host_shape());
-  if (sub_elements != output_tensor_shape_protos.size()) {
-    return errors::InvalidArgument(
-        "Mismatched numbers of output shapes: ", sub_elements, " vs. ",
-        output_tensor_shape_protos.size());
-  }
-
-  xla::TransferManager* const transfer_manager =
-      node_context->transfer_manager();
-
-  std::vector<TensorShape> output_tensor_shapes;
-  output_tensor_shapes.reserve(sub_elements);
-  for (int64 i = 0; i < sub_elements; ++i) {
-    TF_RETURN_IF_ERROR(
-        TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
-    TensorShape shape(*output_tensor_shape_protos[i]);
-    const xla::Shape& xla_shape =
-        xla::ShapeUtil::GetSubshape(scoped_buffers.on_host_shape(), {i});
-    if (!xla_shape.IsArray() ||
-        xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
-      return errors::InvalidArgument(
-          "Mismatched number of elements in output shape: ",
-          xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
-    }
-    output_tensor_shapes.push_back(shape);
-  }
-
-  // Builds a shaped buffer for the outputs.
-  TF_RET_CHECK(scoped_buffers.on_host_shape().IsTuple());
-  TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(scoped_buffers.on_host_shape()));
-
-  se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
-
-  auto output_buffers =
-      absl::MakeUnique<OutputBuffers>(std::move(scoped_buffers), allocator);
-
-  xla::Shape output_host_shape = output_buffers->buffers.on_host_shape();
-  xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();
-
-  if (!output_host_shape.is_static()) {
-    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
-        stream, &output_buffers->buffers, &output_host_shape,
-        &output_device_shape));
-    for (int64 i = 0; i < sub_elements; ++i) {
-      const xla::Shape& subshape =
-          xla::ShapeUtil::GetSubshape(output_host_shape, {i});
-      TensorShape shape;
-      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
-      output_tensor_shapes[i] = shape;
-    }
-  }
-
-  // Transfers ownership of the buffers that back XLA computation output 'i'
-  // to 'output_tensor'.
-  auto transfer_buffers = [&](int i, Tensor* output_tensor) {
-    const xla::Shape& host_shape =
-        xla::ShapeUtil::GetTupleElementShape(output_host_shape, i);
-    const xla::Shape& device_shape =
-        xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);
-
-    // Transfers ownership of the output buffers to the output Tensor, if
-    // the tensor is backed by an XlaTensor. Tensors of size 0 have no
-    // backing XlaTensor, so we let 'output_buffers' retain ownership of any
-    // buffers in that case.
-    if (output_tensor->NumElements() > 0) {
-      xla::ScopedShapedBuffer shaped_buffer(host_shape, device_shape, allocator,
-                                            device_ordinal);
-      shaped_buffer.buffers().ForEachMutableElement(
-          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
-            xla::ShapeIndex out_index = {i};
-            for (int64 j : index) {
-              out_index.push_back(j);
-            }
-            *buffer = output_buffers->buffers.buffers().element(out_index);
-            *output_buffers->owned_buffers.mutable_element(out_index) = false;
-          });
-
-      XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
-      xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
-      xla_tensor->ResetDefinitionEvent(definition_event, stream);
-    }
-  };
-
-  const int num_updated_variables = variable_updates.output_to_input.size();
-  TF_RET_CHECK(num_updated_variables <= output_tensor_shapes.size())
-      << num_updated_variables << " <= " << output_tensor_shapes.size();
-
-  OpInputList arg_list;
-  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));
-
-  // The TPU program outputs the updated variables including DT_RESOURCE and
-  // non-DT_RESOURCE. The TPUExecuteOp needs to output all non-DT_RESOURCE
-  // variables (updated or not).
-  //
-  //                       updated          not_updated
-  //                 |------------------|------------------|
-  // DT_RESOURCE     | allocate persist |    do nothing    |
-  //                 |------------------|------------------|
-  //                 |     allocate     | forward Op input |
-  // not DT_RESOURCE |      output      |   to Op output   | Op output
-  //                 |------------------|------------------|
-  //                    program output
-
-  // Allocates a fresh tensor for each updated variable. While the variable
-  // inputs may come in any order, the updated variable values are always
-  // added last by the XlaCompiler class, in the same order as the
-  // corresponding input variables.
-  int op_output_index = 0;
-  int compiled_update_index = 0;
-  auto process_non_updated_variable = [&](int input_index) {
-    const int variable_index = input_buffers->variable_index.at(input_index);
-    // If a DT_RESOURCE input is not updated, nothing needs to be done
-    // because there is no corresponding output. If a non-resource input
-    // is not updated, forward the input to the output.
-    if (variable_index < 0) {
-      context->set_output(op_output_index, arg_list[input_index]);
-      ++op_output_index;
-    }
-  };
-  for (int i = 0; i < output_tensor_shapes.size(); ++i) {
-    auto it = variable_updates.output_to_input.find(i);
-    if (it == variable_updates.output_to_input.end()) {
-      // Not a variable update.
-      // Allocates a fresh tensor for each output of the operator. We always
-      // allocate a new host-side tensor, but the on-device buffers that back
-      // that tensor may be aliases of input buffers.
-      Tensor* output_tensor;
-      TF_RETURN_IF_ERROR(context->allocate_output(
-          op_output_index, output_tensor_shapes[i], &output_tensor));
-      transfer_buffers(i, output_tensor);
-      ++op_output_index;
-      continue;
-    }
-    const int input_index = it->second.first;
-    // We must process the compiled updates in order, which includes the
-    // non-updated variables, i.e., those without an XLA output.
-    const bool from_compilation = it->second.second;
-    while (from_compilation &&
-           variable_updates
-                   .input_in_compiled_update_order[compiled_update_index] !=
-               input_index) {
-      process_non_updated_variable(
-          variable_updates
-              .input_in_compiled_update_order[compiled_update_index]);
-      ++compiled_update_index;
-    }
-    ++compiled_update_index;
-    const int variable_index = input_buffers->variable_index.at(input_index);
-    PersistentTensor unused;
-    Tensor* output_tensor;
-    if (variable_index >= 0) {
-      // This output corresponds to a DT_RESOURCE input to the TPUExecute
-      // operator. Update the corresponding variable.
-      VariableInfo& var = input_buffers->variables[variable_index];
-      // TODO(b/35625933): the correct thing to do would be to transfer
-      // ownership of the PersistentTensor into the Var object. However, Var
-      // contains a Tensor so we can't.
-      TF_RETURN_IF_ERROR(context->allocate_persistent(
-          var.var()->tensor()->dtype(), output_tensor_shapes[i], &unused,
-          &output_tensor));
-      *var.var()->tensor() = *output_tensor;
-    } else {
-      // This output corresponds to a non-resource input to the TPUExecute
-      // operator. This case occurs for the distributed TPU rewrite which
-      // adds variable values as inputs and outputs rather than passing the
-      // variables themselves; reading and writing the variable is handled
-      // outside the op.
-      // TODO(phawkins): remove this case when placement of variables on TPU
-      // devices is well supported and we no longer need to place "remote"
-      // variables on CPU devices.
-      TF_RETURN_IF_ERROR(context->allocate_output(
-          op_output_index, output_tensor_shapes[i], &output_tensor));
-      ++op_output_index;
-    }
-    transfer_buffers(i, output_tensor);
-  }
-
-  // Process any remaining non-updated variables.
-  for (; compiled_update_index <
-         variable_updates.input_in_compiled_update_order.size();
-       ++compiled_update_index) {
-    process_non_updated_variable(
-        variable_updates.input_in_compiled_update_order[compiled_update_index]);
-  }
-  return std::move(output_buffers);
-}
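
A standalone simulation of the decision table in the comment above, with hypothetical toy structs standing in for the OpKernelContext machinery:

```cpp
#include <iostream>
#include <vector>

// For each variable input, whether the op produces an output depends on the
// input's dtype and on whether the program updated it.
struct VarInput {
  bool is_resource;  // DT_RESOURCE input?
  bool updated;      // does the program emit an updated value?
};

int main() {
  const std::vector<VarInput> inputs = {
      {true, true},    // DT_RESOURCE, updated: write back into the Var
      {true, false},   // DT_RESOURCE, not updated: nothing to do
      {false, true},   // non-resource, updated: allocate a fresh op output
      {false, false},  // non-resource, not updated: forward input to output
  };
  int op_output_index = 0;
  for (const VarInput& in : inputs) {
    if (in.is_resource) {
      if (in.updated) std::cout << "allocate persistent tensor for Var\n";
      // not updated: no op output in either case
    } else if (in.updated) {
      std::cout << "op output " << op_output_index++ << ": fresh tensor\n";
    } else {
      std::cout << "op output " << op_output_index++ << ": forwarded input\n";
    }
  }
  return 0;
}
```
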
-
-}  // namespace
-
-// TPUExecuteOp
-
-TPUExecuteOp::TPUExecuteOp(OpKernelConstruction* context)
-    : AsyncOpKernel(context, /* is_deferred = */ true) {}
-
-AsyncOpKernel* TPUExecuteOp::AsAsync() {
-  // If TPU launches are asynchronous, we can perform the launch without
-  // blocking the calling thread, and so the executor may treat this kernel as
-  // a regular (synchronous) OpKernel.
-  return nullptr;
-}
-
-void TPUExecuteOp::Compute(OpKernelContext* context) {
-  Status s = DoWork(context);
-  // NOTE: We can't use `OP_REQUIRES_OK()` here because that macro includes
-  // a dynamic check that we are not in an AsyncOpKernel.
-  if (TF_PREDICT_FALSE(!s.ok())) {
-    context->SetStatus(s);
-  }
-}
-
-void TPUExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
-  // If TPU launches are asynchronous, then perform the launch on this
-  // thread to avoid a thread hop, which has an observable latency cost.
-  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
-  done();
-}
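
A simplified sketch of the executor-side contract that `AsAsync()` relies on; the real TensorFlow executor is considerably more involved, and the class names here are illustrative only:

```cpp
#include <functional>
#include <iostream>

// If AsAsync() returns nullptr, the executor drives the kernel synchronously
// via Compute(); otherwise it calls ComputeAsync() and moves on without
// blocking.
struct Kernel {
  virtual ~Kernel() = default;
  virtual Kernel* AsAsync() { return nullptr; }
  virtual void Compute() = 0;
  virtual void ComputeAsync(std::function<void()> done) {
    Compute();
    done();
  }
};

struct SyncLaunch : Kernel {
  void Compute() override { std::cout << "launched on the calling thread\n"; }
};

void RunKernel(Kernel* k) {
  if (k->AsAsync() == nullptr) {
    k->Compute();  // treated as a regular (synchronous) kernel
  } else {
    k->ComputeAsync([] { std::cout << "done callback\n"; });
  }
}

int main() {
  SyncLaunch k;
  RunKernel(&k);
  return 0;
}
```
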
-
-Status TPUExecuteOp::DoWork(OpKernelContext* context) {
-  VLOG(1) << "Cloud TPU: TPUExecuteOp::Compute";
-
-  const XlaDevice::Metadata* metadata;
-  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
-  const int device_ordinal = metadata->device_ordinal();
-
-  // We are guaranteed that the object underlying TpuNodeContext won't be
-  // deleted out from under us while node_context is alive.
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuNodeContext> node_context,
-                      TpuNodeContext::Create(device_ordinal));
-
-  profiler::TraceMe trace_me(
-      [&, device_ordinal] {
-        return absl::StrCat("TpuExecuteOp#device_ordinal=", device_ordinal,
-                            ",id=", context->step_id(),
-                            ",iter_num=", context->frame_iter().iter_id, "#");
-      },
-      /*level=*/2);
-  profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);
-
-  string rendezvous_key_base;
-  std::unique_ptr<CompilationCacheEntryRef> entry;
-  TF_RETURN_IF_ERROR(
-      GetComputationCacheEntry(context, &rendezvous_key_base, &entry));
-
-  // Shapes of the inputs and outputs, in xla::Shape form.
-  const TPUExecutableInfoProto* proto = entry->get().get_executable_info();
-
-  xla::TransferManager* const transfer_manager =
-      node_context->transfer_manager();
-  CHECK(context->op_device_context());
-  se::Stream* stream = context->op_device_context()->stream();
-
-  TF_RET_CHECK(proto->input_shapes_size() == 1);
-
-  xla::Shape host_shape(proto->input_shapes(0));
-
-  TF_ASSIGN_OR_RETURN(
-      auto variable_update_map,
-      BuildVariableUpdateMap(proto->variable_indices(),
-                             fused_device_var_reads_in_computation_inputs_,
-                             fused_device_var_updates_in_computation_outputs_,
-                             proto->output_tensor_shapes().size()));
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<InputBuffers> input_buffers,
-      BuildComputationInputs(context, host_shape, variable_update_map,
-                             node_context.get(), stream));
-
-  // Ideally this should be the host-to-device stream from XlaDeviceContext.
-  // The particular anti-dependency this is avoiding (why we need a separate
-  // transfer stream) is between the executable writing tuple tables and
-  // TPUExecute()'s deregister_stream; if they come from the same stream pool
-  // anti-dependencies will occur. XlaBackend has a different pool of streams
-  // from the stream->GetOrCreateSubStream() that TPUExecute() uses, so these
-  // will never refer to the same stream.
-  //
-  // TODO(jmolloy): Add the necessary plumbing to obtain the proper
-  // host-to-device stream here.
-  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
-                      node_context->BorrowStream(device_ordinal));
-
-  se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
-  auto shaped_buffer =
-      input_buffers->ToShapedBuffer(host_shape, allocator, device_ordinal);
-  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
-                                                     shaped_buffer)) {
-    TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
-        transfer_stream_ptr.get(), shaped_buffer));
-    stream->ThenWaitFor(transfer_stream_ptr.get());
-  } else {
-    TF_RETURN_IF_ERROR(
-        transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
-  }
-  VLOG(4) << "Input buffers: " << shaped_buffer.ToString();
-
-  // Snapshot the inputs, if a snapshot was requested.
-  std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
-  if (proto->has_session_module()) {
-    hlo_snapshot = std::make_shared<xla::HloSnapshot>(proto->session_module());
-    auto literal =
-        std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
-    transfer_manager->TransferLiteralFromDevice(
-        stream, shaped_buffer, literal.get(),
-        [hlo_snapshot, literal](Status status) {
-          if (!status.ok()) {
-            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
-                          "failed: "
-                       << status;
-            return;
-          }
-          *hlo_snapshot->add_arguments() = literal->ToProto();
-        });
-  }
-
-  auto definition_event = std::make_shared<se::Event>(stream->parent());
-  TF_RET_CHECK(definition_event->Init())
-      << "TPU definition event initialization failed";
-
-  trace_me_init.Stop();
-
-  const uint32 rng_seed = GetXLARandomSeed();
-
-  std::unique_ptr<xla::DeviceAssignment> device_assignment;
-  if (proto->has_device_assignment()) {
-    TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
-                                               proto->device_assignment()));
-  }
-
-  VLOG(4) << "Input buffers after alias resolution: "
-          << shaped_buffer.ToString();
-
-  std::vector<xla::ExecutionInput> input;
-  input.emplace_back(
-      xla::ExecutionInput(std::move(input_buffers->buffers), host_shape));
-
-  // The buffers to be freed are in `output` and will be freed automatically
-  // when it goes out of scope. In async mode, this means the buffers may be
-  // freed before anyone calls "BlockHostUntilDone", which indicates that some
-  // of the (input) buffers will be freed while the program is running and
-  // looks scary. However, this turns out not to be a problem: although we
-  // free memory and reassign it to other users while a program is running,
-  // all subsequent writes that could possibly clobber the memory depend on
-  // the program finishing first.
-  const TPUHostTransferInfoProto* host_transfer_info =
-      entry->get().get_host_transfer_info();
-  const xla::HloProto* hlo_metadata = entry->get().get_hlo_metadata();
-  TF_ASSIGN_OR_RETURN(
-      xla::ExecutionOutput output,
-      TPUExecute(*proto, *host_transfer_info, *hlo_metadata, std::move(input),
-                 rendezvous_key_base, rng_seed, node_context.get(),
-                 device_assignment.get(), context->cancellation_manager(),
-                 context, stream, transfer_stream_ptr.get(),
-                 entry->get().get_tpu_program()));
-  stream->ThenRecordEvent(definition_event.get());
-
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<OutputBuffers> output_buffers,
-      AllocateOutputTensors(context, output.ConsumeResult(),
-                            proto->output_tensor_shapes(), variable_update_map,
-                            node_context.get(), stream, device_ordinal,
-                            input_buffers.get(), definition_event));
-
-  // Transfer the outputs and save the snapshot to disk.
-  if (hlo_snapshot) {
-    auto literal =
-        std::make_shared<xla::Literal>(output_buffers->buffers.on_host_shape());
-    transfer_manager->TransferLiteralFromDevice(
-        stream, output_buffers->buffers, literal.get(),
-        [hlo_snapshot, literal](Status status) {
-          if (status.ok()) {
-            *hlo_snapshot->mutable_result() = literal->ToProto();
-          } else {
-            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot "
-                          "outputs failed: "
-                       << status;
-          }
-          DumpHloSnapshotIfEnabled(*hlo_snapshot,
-                                   xla::GetDebugOptionsFromFlags());
-        });
-  }
-  return Status::OK();
-}
-
-TPUExecuteOp::~TPUExecuteOp() = default;
-
-TPUExecuteAndUpdateVariablesOp::TPUExecuteAndUpdateVariablesOp(
-    OpKernelConstruction* context)
-    : TPUExecuteOp(context) {
-  OP_REQUIRES_OK(context, context->GetAttr(
-                              "device_var_reads_indices",
-                              &fused_device_var_reads_in_computation_inputs_));
-  OP_REQUIRES_OK(
-      context,
-      context->GetAttr("device_var_updates_indices",
-                       &fused_device_var_updates_in_computation_outputs_));
-}
-
-REGISTER_KERNEL_BUILDER(
-    Name("TPUExecute").Device(DEVICE_TPU_NODE).HostMemory("key"), TPUExecuteOp);
-
-REGISTER_KERNEL_BUILDER(Name("TPUExecuteAndUpdateVariables")
-                            .Device(DEVICE_TPU_NODE)
-                            .HostMemory("key"),
-                        TPUExecuteAndUpdateVariablesOp);
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.h b/tensorflow/core/tpu/kernels/tpu_execute_op.h
deleted file mode 100644
index c66118a..0000000
--- a/tensorflow/core/tpu/kernels/tpu_execute_op.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_H_
-#define TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_H_
-
-#include <memory>
-#include <vector>
-
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-
-namespace tensorflow {
-
-// Op that executes a precompiled TPU computation.
-class TPUExecuteOp : public AsyncOpKernel {
- public:
-  explicit TPUExecuteOp(OpKernelConstruction* context);
-  ~TPUExecuteOp() override;
-
-  AsyncOpKernel* AsAsync() override;
-
-  void Compute(OpKernelContext* context) override;
-  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;
-
- protected:
-  // Used by TPUExecuteAndUpdateVariablesOp to set the fused variable read and
-  // update indices in the XLA computation. The two vectors must have the same
-  // size, and a pair of read index and write index represents a variable's
-  // input to the program and its updated value from the program. If the
-  // variable is not updated, use -1 as the output index.
-  std::vector<int> fused_device_var_reads_in_computation_inputs_;
-  std::vector<int> fused_device_var_updates_in_computation_outputs_;
-
- private:
-  Status DoWork(OpKernelContext* context);
-
-  TF_DISALLOW_COPY_AND_ASSIGN(TPUExecuteOp);
-};
-
-// A variant of TPUExecuteOp that contains fused device variable reads and
-// updates.
-class TPUExecuteAndUpdateVariablesOp : public TPUExecuteOp {
- public:
-  explicit TPUExecuteAndUpdateVariablesOp(OpKernelConstruction* context);
-  ~TPUExecuteAndUpdateVariablesOp() override = default;
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(TPUExecuteAndUpdateVariablesOp);
-};
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_H_
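
To illustrate how the two attribute vectors pair up, a standalone sketch with toy values (the equal-length requirement is enforced via `TF_RET_CHECK` in `BuildVariableUpdateMap` in the .cc file):

```cpp
#include <iostream>
#include <vector>

int main() {
  // Toy values: device_var_reads_indices[i] is the computation input holding
  // variable i's value; device_var_updates_indices[i] is the computation
  // output holding its updated value, or -1 when it is not updated.
  const std::vector<int> reads = {0, 2, 5};
  const std::vector<int> updates = {4, -1, 6};
  if (reads.size() != updates.size()) return 1;  // must pair up exactly
  for (size_t i = 0; i < reads.size(); ++i) {
    if (updates[i] >= 0)
      std::cout << "input " << reads[i] << " -> output " << updates[i] << "\n";
    else
      std::cout << "input " << reads[i] << " is not updated\n";
  }
  return 0;
}
```
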