Introduce additional XLA TPU Ops to open source

PiperOrigin-RevId: 326343558
Change-Id: I47da1dc0c96cdf8223ccebef012e2a5088a857a4
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 6d33690..0c7ac66 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -38,6 +38,7 @@
         ":tpu_execute_op",
         ":tpu_handle_to_key_op",
         ":transfer_ops",
+        "//tensorflow/core/tpu/kernels/xla:xla_ops",
     ],
 )
 
diff --git a/tensorflow/core/tpu/kernels/xla/BUILD b/tensorflow/core/tpu/kernels/xla/BUILD
new file mode 100644
index 0000000..f55583a
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/BUILD
@@ -0,0 +1,52 @@
+# XLA Ops for TPUs
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "xla_ops",
+    srcs = [
+        "get_item_op.cc",
+        "host_compute_ops.cc",
+        "index_ops.cc",
+        "infeed_op.cc",
+        "inplace_ops.cc",
+        "outfeed_ops.cc",
+        "segment_reduction_ops.cc",
+        "where_op.cc",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:sharding_util",
+        "//tensorflow/compiler/tf2xla:side_effect_util",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_context",
+        "//tensorflow/compiler/tf2xla:xla_helpers",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
+        "//tensorflow/compiler/tf2xla/kernels:if_op",
+        "//tensorflow/compiler/tf2xla/kernels:while_op",
+        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+        "//tensorflow/compiler/tf2xla/lib:scatter",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client/lib:arithmetic",
+        "//tensorflow/compiler/xla/client/lib:comparators",
+        "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/tpu:tpu_api",
+        "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/core/tpu/kernels:cross_replica_ops",
+        "//tensorflow/stream_executor/tpu:c_api_conversions",
+        "//tensorflow/stream_executor/tpu:c_api_decl",
+        "@com_google_absl//absl/strings",
+    ],
+    alwayslink = 1,
+)
diff --git a/tensorflow/core/tpu/kernels/xla/get_item_op.cc b/tensorflow/core/tpu/kernels/xla/get_item_op.cc
new file mode 100644
index 0000000..094c6b8
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/get_item_op.cc
@@ -0,0 +1,75 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_util.h"
+
+namespace tensorflow {
+namespace {
+
+// The XLA kernel that builds up the computation for get_item(data, index).
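+//
+// For example (illustrative): with data of shape [4, 3] and index = [1], the
+// kernel below emits DynamicSlice(data, {1, 0}, {1, 3}) followed by a Reshape
+// to shape [3], i.e. it returns data[1, :].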
+class GetItemXlaOp : public XlaOpKernel {
+ public:
+  explicit GetItemXlaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    const TensorShape& data_shape = ctx->InputShape(0);
+    const TensorShape& index_shape = ctx->InputShape(1);
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsVectorOrHigher(data_shape),
+        errors::InvalidArgument("data must be at least 1 dimensional."));
+    OP_REQUIRES(ctx, index_shape.dims() == 1 && index_shape.dim_size(0) == 1,
+                errors::InvalidArgument("index must be a vector of size 1."));
+
+    // NOTE(pbar) Use Concat to extend the indices to match cl/142279605.
+    // This isn't the simplest way to emit the indices, but the code for
+    // dynamic slice needs to be able to see that minor dims are const zero.
+    auto const_zero = xla::ConstantR0(ctx->builder(), 0);
+    std::vector<xla::XlaOp> operands;
+    operands.push_back(xla::Reshape(ctx->Input(1), {}));
+    for (int i = 1; i < data_shape.dims(); i++) {
+      operands.push_back(const_zero);
+    }
+
+    std::vector<int64> dims = {0};
+    std::vector<int64> slice_sizes = {1};
+    std::vector<int64> out_sizes = {};
+    for (int i = 1; i < data_shape.dims(); i++) {
+      dims.push_back(i);
+      auto size = data_shape.dim_size(i);
+      slice_sizes.push_back(size);
+      out_sizes.push_back(size);
+    }
+    // NOTE: DynamicSlice here doesn't raise an error or wrap the index
+    // if it's out of range.
+    auto slice = xla::DynamicSlice(ctx->Input(0), operands, slice_sizes);
+    // In-order collapse to remove the 1st dim.
+    auto reshape = xla::Reshape(slice, dims, out_sizes);
+    ctx->SetOutput(0, reshape);
+  }
+};
+
+REGISTER_XLA_OP(Name("GetItem"), GetItemXlaOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc b/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc
new file mode 100644
index 0000000..be3ee1c
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc
@@ -0,0 +1,498 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_constructor.h"
+#include "tensorflow/core/common_runtime/lower_function_call_op.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+
+namespace {
+
+// TODO(phawkins) add a canonical copy of these operator names and refactor
+// everything to use it.
+static const char* const kSendFromHostOp = "_XlaSendFromHost";
+static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
+
+Status MakeXlaShapes(gtl::ArraySlice<TensorShape> shapes,
+                     gtl::ArraySlice<DataType> dtypes,
+                     std::vector<xla::Shape>* xla_shapes,
+                     xla::Shape* xla_shape) {
+  for (int i = 0; i < shapes.size(); i++) {
+    xla::Shape single_xla_shape;
+    TF_RETURN_IF_ERROR(
+        TensorShapeToXLAShape(dtypes[i], shapes[i], &single_xla_shape));
+    VLOG(2) << "Shape " << single_xla_shape.DebugString();
+    xla_shapes->push_back(single_xla_shape);
+  }
+  // Temporarily add a dummy output to the shape array before making the tuple:
+  // this output is used for control dependencies between host compute ops.
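+  // For example (illustrative): shapes = {[2,3]}, dtypes = {DT_FLOAT} yields
+  // *xla_shape = (f32[2,3], pred[]), while *xla_shapes ends up holding only
+  // {f32[2,3]}.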
+  xla_shapes->push_back(xla::ShapeUtil::MakeShape(xla::PRED, {}));
+  *xla_shape = xla::ShapeUtil::MakeTupleShape(*xla_shapes);
+  // Remove the dummy output from the vector that will be used to copy real
+  // outputs from host to device.
+  xla_shapes->pop_back();
+  return Status::OK();
+}
+
+// This TensorFlow pseudo-op is used to record host-side computation.
+class HostComputeOp : public XlaOpKernel {
+ public:
+  explicit HostComputeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("cost_estimate_ns", &cost_estimate_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("tpu_core", &tpu_core_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_dtypes_));
+    OP_REQUIRES(ctx, ctx->num_inputs() == input_dtypes_.size(),
+                errors::InvalidArgument("Tinputs size=", input_dtypes_.size(),
+                                        " but expected ", ctx->num_inputs(),
+                                        " inputs."));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_dtypes_));
+    OP_REQUIRES(ctx, ctx->num_outputs() == output_dtypes_.size(),
+                errors::InvalidArgument("Toutputs size=", output_dtypes_.size(),
+                                        " but expected ", ctx->num_outputs(),
+                                        " outputs."));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("ancestors", &ancestors_));
+    NameAttrList shape_inference_graph;
+    OP_REQUIRES_OK(
+        ctx, ctx->GetAttr("shape_inference_graph", &shape_inference_graph));
+    if (shape_inference_graph.name().empty()) {
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &static_output_shapes_));
+      OP_REQUIRES(ctx, static_output_shapes_.size() == output_dtypes_.size(),
+                  errors::InvalidArgument(
+                      "shapes attr list size ", static_output_shapes_.size(),
+                      " differs from dtypes size ", output_dtypes_.size()));
+      OP_REQUIRES_OK(ctx, MakeXlaShapes(static_output_shapes_, output_dtypes_,
+                                        &static_xla_output_shapes_,
+                                        &static_xla_output_shape_));
+      VLOG(2) << "Output Shape: " << static_xla_output_shape_.DebugString();
+    } else {
+      FunctionLibraryRuntime* flib_runtime = ctx->function_library();
+      OP_REQUIRES(ctx, flib_runtime != nullptr,
+                  errors::Internal(
+                      "No function library runtime at kernel construction"));
+      const FunctionLibraryDefinition* library =
+          flib_runtime->GetFunctionLibraryDefinition();
+      const FunctionDef* fdef = library->Find(shape_inference_graph.name());
+      OP_REQUIRES(ctx, fdef != nullptr,
+                  errors::Internal("Failed to find function ",
+                                   shape_inference_graph.name(),
+                                   " in function library."));
+      OP_REQUIRES_OK(ctx, FunctionDefToBodyHelper(
+                              *fdef, AttrSlice(&shape_inference_graph.attr()),
+                              library, &shape_inference_graph_function_));
+      VLOG(2) << "Output Shape to be inferred at compile time";
+    }
+    OP_REQUIRES_OK(
+        ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
+    OP_REQUIRES(ctx, !token_input_nodes_.empty(),
+                errors::InvalidArgument("XlaHostCompute node does not have ",
+                                        kXlaTokenInputNodesAttrName, " attr"));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
+                                     &original_node_name_));
+  }
+
+  ~HostComputeOp() override {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::XlaBuilder* b = ctx->builder();
+    XlaCompiler* compiler = ctx->compiler();
+
+    std::vector<xla::XlaOp> input_handles;
+    std::vector<TensorShape> input_shapes;
+    OP_REQUIRES_OK(
+        ctx, ctx->InputList("inputs", &input_handles, &input_shapes));
+    const auto device_sharding = xla::sharding_builder::AssignDevice(tpu_core_);
+    xla::XlaScopedShardingAssignment assign_sharding(b, device_sharding);
+
+    std::vector<xla::XlaOp> input_tokens;
+    for (auto& token_input_node : token_input_nodes_) {
+      auto token_or = compiler->GetNodeToken(token_input_node);
+      OP_REQUIRES_OK(ctx, token_or.status());
+      input_tokens.push_back(token_or.ValueOrDie());
+    }
+    xla::XlaOp token = xla::AfterAll(b, input_tokens);
+
+    // Send values to the host.
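+    // Each input i goes over a device-to-host channel named "<key>_dtoh_<i>",
+    // and each output i later comes back over a host-to-device channel named
+    // "<key>_htod_<i>" (see the channel names built below).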
+    std::vector<xla::XlaOp> send_to_host_tokens;
+    for (int i = 0; i < input_handles.size(); ++i) {
+      const string channel_name = absl::StrCat(key_, "_dtoh_", i);
+      xla::Shape xla_shape;
+      OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(input_dtypes_[i],
+                                                input_shapes[i], &xla_shape));
+      // Specify frontend attributes.
+      xla::FrontendAttributes attrs;
+      (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
+      (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
+          xla::primitive_util::LowercasePrimitiveTypeName(
+              xla_shape.element_type());
+      b->SetFrontendAttributes(attrs);
+      xla::ChannelHandle channel;
+      OP_REQUIRES_OK(
+          ctx, compiler->GetDeviceToHostChannelHandle(channel_name, &channel));
+      send_to_host_tokens.push_back(
+          xla::SendToHost(input_handles[i], token, xla_shape, channel));
+      b->ClearOpMetadata();
+    }
+    xla::XlaOp recv_from_host_token_input =
+        send_to_host_tokens.empty() ? token
+                                    : xla::AfterAll(b, send_to_host_tokens);
+    if (!input_handles.empty()) {
+      // Register the shapes used in this transfer.
+      OP_REQUIRES_OK(ctx, ctx->compiler()->SetDeviceToHostMetadata(
+                              key_, input_dtypes_, input_shapes));
+    }
+    // Compute the shapes of the values to copy to the device, if necessary.
+    std::vector<TensorShape>* output_shapes;
+    std::vector<xla::Shape>* xla_output_shapes;
+    xla::Shape* xla_output_shape;
+    std::vector<TensorShape> inferred_output_shapes;
+    std::vector<xla::Shape> inferred_xla_output_shapes;
+    xla::Shape inferred_xla_output_shape;
+    if (shape_inference_graph_function_) {
+      OP_REQUIRES_OK(
+          ctx, InferOutputShapes(
+                   ctx, ctx->function_library()->GetFunctionLibraryDefinition(),
+                   &inferred_output_shapes));
+      OP_REQUIRES_OK(ctx, MakeXlaShapes(inferred_output_shapes, output_dtypes_,
+                                        &inferred_xla_output_shapes,
+                                        &inferred_xla_output_shape));
+      output_shapes = &inferred_output_shapes;
+      xla_output_shapes = &inferred_xla_output_shapes;
+      xla_output_shape = &inferred_xla_output_shape;
+    } else {
+      output_shapes = &static_output_shapes_;
+      xla_output_shapes = &static_xla_output_shapes_;
+      xla_output_shape = &static_xla_output_shape_;
+    }
+    OP_REQUIRES(
+        ctx, output_shapes->size() == ctx->num_outputs(),
+        errors::InvalidArgument("Op has ", ctx->num_outputs(),
+                                " outputs but output shape vector has size ",
+                                output_shapes->size()));
+    if (ctx->num_outputs() > 0) {
+      // Register the shapes used in this transfer.
+      OP_REQUIRES_OK(ctx, ctx->compiler()->SetHostToDeviceMetadata(
+                              key_, output_dtypes_, *output_shapes));
+    }
+    // Copy results to the device.
+    std::vector<xla::XlaOp> recv_from_host_tokens;
+    for (int i = 0; i < output_shapes->size(); ++i) {
+      const string channel_name = absl::StrCat(key_, "_htod_", i);
+      // Specify frontend attributes.
+      xla::FrontendAttributes attrs;
+      (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
+      (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
+          xla::primitive_util::LowercasePrimitiveTypeName(
+              xla_output_shapes->at(i).element_type());
+      b->SetFrontendAttributes(attrs);
+      xla::ChannelHandle channel;
+      OP_REQUIRES_OK(
+          ctx, compiler->GetHostToDeviceChannelHandle(channel_name, &channel));
+
+      const auto result_token_tuple = xla::RecvFromHost(
+          recv_from_host_token_input, xla_output_shapes->at(i), channel);
+      b->ClearOpMetadata();
+      recv_from_host_tokens.push_back(
+          xla::GetTupleElement(result_token_tuple, /*index=*/1));
+      ctx->SetOutput(i, xla::GetTupleElement(result_token_tuple, 0));
+    }
+
+    // Set token output.
+    xla::XlaOp token_output = recv_from_host_tokens.empty()
+                                  ? recv_from_host_token_input
+                                  : xla::AfterAll(b, recv_from_host_tokens);
+    OP_REQUIRES_OK(
+        ctx, ctx->compiler()->SetNodeToken(original_node_name_, token_output));
+  }
+
+ private:
+  Status LowerFunctionalOps(Graph* g,
+                            const FunctionLibraryDefinition& flib_def) {
+    bool modified;
+    do {
+      modified = false;
+
+      // Lower "If" nodes first. Their body functions will be expanded as
+      // function call nodes, which we will lower later.
+      // We do not need to lower "While" nodes because shape inference can
+      // handle them correctly (output shapes are input shapes).
+      std::vector<Node*> if_nodes;
+      for (Node* n : g->op_nodes()) {
+        if (n->type_string() == "If") {
+          if_nodes.push_back(n);
+        }
+      }
+      for (Node* if_node : if_nodes) {
+        TF_RETURN_IF_ERROR(
+            RewriteIfNode(if_node, g, /*keep_node_fetchable=*/false));
+      }
+      if (!if_nodes.empty()) {
+        modified = true;
+      }
+
+      // Lower function call nodes.
+      std::vector<Node*> call_nodes;
+      for (Node* n : g->op_nodes()) {
+        if (IsFunctionCall(flib_def, *n)) {
+          call_nodes.push_back(n);
+        }
+      }
+      for (Node* call_node : call_nodes) {
+        TF_RETURN_IF_ERROR(RewriteFunctionCallNode(
+            call_node, g, flib_def, /*keep_caller_fetchable=*/false));
+      }
+      if (!call_nodes.empty()) {
+        modified = true;
+      }
+    } while (modified);
+
+    return Status::OK();
+  }
+
+  Status InferOutputShapes(XlaOpKernelContext* ctx,
+                           const FunctionLibraryDefinition* flib_def,
+                           std::vector<TensorShape>* output_shapes) {
+    // The shape inference graph was already built from the
+    // shape_inference_graph function in the constructor; no shape inference
+    // has been run on it yet.
+    Graph* graph = shape_inference_graph_function_->graph;
+
+    // Lower functional ops, because they are not friendly to shape inference.
+    TF_RETURN_IF_ERROR(LowerFunctionalOps(graph, *flib_def));
+
+    // Now run shape inference, filling in the shapes of _XlaRecvAtHost nodes.
+    bool got_output_shapes = false;
+    ShapeRefiner shape_refiner{graph->versions().producer(),
+                               graph->op_registry()};
+    std::vector<Node*> nodes;
+    GetReversePostOrder(*graph, &nodes);
+    for (auto node : nodes) {
+      TF_RETURN_IF_ERROR(shape_refiner.AddNode(node));
+      if (node->type_string() == kRecvAtHostOp) {
+        const AttrValue* key_attr = node->attrs().Find("key");
+        if (key_attr == nullptr) {
+          return errors::InvalidArgument("Node ", node->name(),
+                                         " has no key attribute");
+        }
+        std::vector<TensorShape> dtoh_shapes;
+        if (!ctx->compiler()
+                 ->GetDeviceToHostShapes(key_attr->s(), &dtoh_shapes)
+                 .ok()) {
+          return errors::InvalidArgument(
+              "Shape inference for HostCompute ", ctx->op_kernel().name(),
+              " failed: host recv node ", node->name(), " with key '",
+              key_attr->s(), "' has unknown shapes.");
+        }
+        if (dtoh_shapes.size() != node->num_outputs()) {
+          return errors::InvalidArgument(
+              "Shape inference for HostCompute ", ctx->op_kernel().name(),
+              " failed: host recv node ", node->name(), " with key '",
+              key_attr->s(), "' has ", node->num_outputs(),
+              " outputs but inferred shapes expect ", dtoh_shapes.size());
+        }
+        for (int i = 0; i < node->num_outputs(); ++i) {
+          shape_inference::InferenceContext* shape_ctx =
+              shape_refiner.GetContext(node);
+          shape_inference::ShapeHandle handle;
+          TF_RETURN_IF_ERROR(
+              shape_ctx->MakeShapeFromTensorShape(dtoh_shapes.at(i), &handle));
+          shape_ctx->set_output(i, handle);
+        }
+      } else if (node->type_string() == kSendFromHostOp) {
+        if (got_output_shapes) {
+          return errors::InvalidArgument(
+              "Shape inference for HostCompute ", ctx->op_kernel().name(),
+              " failed: inference graph has multiple send from host nodes");
+        } else {
+          got_output_shapes = true;
+          // The last input is the dynamic key so don't record its shape.
+          output_shapes->resize(node->num_inputs() - 1);
+          shape_inference::InferenceContext* shape_ctx =
+              shape_refiner.GetContext(node);
+          for (int i = 0; i < node->num_inputs() - 1; ++i) {
+            shape_inference::ShapeHandle handle = shape_ctx->input(i);
+            if (!shape_ctx->FullyDefined(handle)) {
+              return errors::InvalidArgument(
+                  "Shape inference for HostCompute ", ctx->op_kernel().name(),
+                  " failed: send from host node ", node->name(),
+                  " has non-fully defined shape of input index ", i);
+            }
+            TensorShapeProto shape_proto;
+            shape_ctx->ShapeHandleToProto(handle, &shape_proto);
+            (*output_shapes)[i] = TensorShape(shape_proto);
+            VLOG(2) << "Inferred shape " << shape_proto.DebugString();
+          }
+        }
+      }
+    }
+    if (!got_output_shapes) {
+      return errors::InvalidArgument(
+          "Shape inference for HostCompute ", ctx->op_kernel().name(),
+          " failed: inference graph has no send from host node");
+    }
+    return Status::OK();
+  }
+
+  DataTypeVector input_dtypes_;
+  DataTypeVector output_dtypes_;
+  std::vector<string> ancestors_;
+  std::vector<TensorShape> static_output_shapes_;
+  std::vector<xla::Shape> static_xla_output_shapes_;
+  string original_node_name_;
+  // static_xla_output_shape_ is the tuple shape built by MakeXlaShapes from
+  // static_xla_output_shapes_ plus a dummy element used for control
+  // dependencies between host compute ops.
+  xla::Shape static_xla_output_shape_;
+  string key_;
+  // If output shapes are not given statically via the "shapes" attr, the graph
+  // needed to infer them at compile time is stored in this function.
+  std::unique_ptr<FunctionBody> shape_inference_graph_function_;
+  int64 cost_estimate_;
+  int64 tpu_core_;
+  std::vector<string> token_input_nodes_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(HostComputeOp);
+};
+
+class SendToHostOp : public XlaOpKernel {
+ public:
+  explicit SendToHostOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinput", &input_dtype_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
+    OP_REQUIRES_OK(
+        ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
+    OP_REQUIRES(ctx, !token_input_nodes_.empty(),
+                errors::InvalidArgument("XlaSendToHost node does not have ",
+                                        kXlaTokenInputNodesAttrName, " attr"));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
+                                     &original_node_name_));
+  }
+
+  ~SendToHostOp() override {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::XlaBuilder* b = ctx->builder();
+
+    XlaCompiler* compiler = ctx->compiler();
+    xla::XlaOp operand = ctx->Input(0);
+    std::vector<xla::XlaOp> input_tokens;
+    for (auto& token_input_node : token_input_nodes_) {
+      auto token_or = compiler->GetNodeToken(token_input_node);
+      OP_REQUIRES_OK(ctx, token_or.status());
+      input_tokens.push_back(token_or.ValueOrDie());
+    }
+    xla::XlaOp token = xla::AfterAll(b, input_tokens);
+    xla::Shape xla_shape;
+    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(input_dtype_, ctx->InputShape(0),
+                                              &xla_shape));
+    // Specify frontend attributes.
+    xla::FrontendAttributes attrs;
+    (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
+    (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
+        xla::primitive_util::LowercasePrimitiveTypeName(
+            xla_shape.element_type());
+    b->SetFrontendAttributes(attrs);
+    xla::ChannelHandle channel;
+    OP_REQUIRES_OK(ctx, compiler->GetDeviceToHostChannelHandle(key_, &channel));
+    xla::XlaOp output_token =
+        xla::SendToHost(operand, token, xla_shape, channel);
+    OP_REQUIRES_OK(ctx,
+                   compiler->SetNodeToken(original_node_name_, output_token));
+  }
+
+ private:
+  DataType input_dtype_;
+  string key_;
+  std::vector<string> token_input_nodes_;
+  string original_node_name_;
+  TF_DISALLOW_COPY_AND_ASSIGN(SendToHostOp);
+};
+
+class RecvFromHostOp : public XlaOpKernel {
+ public:
+  explicit RecvFromHostOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutput", &output_dtype_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &output_shape_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
+    OP_REQUIRES_OK(
+        ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
+    OP_REQUIRES(ctx, !token_input_nodes_.empty(),
+                errors::InvalidArgument("XlaRecvFromHost node does not have ",
+                                        kXlaTokenInputNodesAttrName, " attr"));
+  }
+
+  ~RecvFromHostOp() override {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::XlaBuilder* b = ctx->builder();
+
+    XlaCompiler* compiler = ctx->compiler();
+    std::vector<xla::XlaOp> input_tokens;
+    for (auto& token_input_node : token_input_nodes_) {
+      auto token_or = compiler->GetNodeToken(token_input_node);
+      OP_REQUIRES_OK(ctx, token_or.status());
+      input_tokens.push_back(token_or.ValueOrDie());
+    }
+    xla::XlaOp token = xla::AfterAll(b, input_tokens);
+    xla::Shape xla_shape;
+    OP_REQUIRES_OK(
+        ctx, TensorShapeToXLAShape(output_dtype_, output_shape_, &xla_shape));
+    // Specify frontend attributes.
+    xla::FrontendAttributes attrs;
+    (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
+    (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
+        xla::primitive_util::LowercasePrimitiveTypeName(
+            xla_shape.element_type());
+    b->SetFrontendAttributes(attrs);
+    xla::ChannelHandle channel;
+    OP_REQUIRES_OK(ctx, compiler->GetHostToDeviceChannelHandle(key_, &channel));
+    xla::XlaOp result = xla::RecvFromHost(token, xla_shape, channel);
+    // xla::RecvFromHost returns a tuple of (received data, token).
+    ctx->SetOutput(0, xla::GetTupleElement(result, 0));
+    OP_REQUIRES_OK(
+        ctx, compiler->SetNodeToken(name(), xla::GetTupleElement(result, 1)));
+  }
+
+ private:
+  DataType output_dtype_;
+  TensorShape output_shape_;
+  string key_;
+  std::vector<string> token_input_nodes_;
+  TF_DISALLOW_COPY_AND_ASSIGN(RecvFromHostOp);
+};
+
+REGISTER_XLA_OP(Name("XlaHostCompute"), HostComputeOp);
+REGISTER_XLA_OP(Name("XlaSendToHost"), SendToHostOp);
+REGISTER_XLA_OP(Name("XlaRecvFromHost"), RecvFromHostOp);
+
+}  // anonymous namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/index_ops.cc b/tensorflow/core/tpu/kernels/xla/index_ops.cc
new file mode 100644
index 0000000..40148f3
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/index_ops.cc
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/kernels/index_ops.h"
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+namespace {
+
+// This registration is needed here because the ArgMax Op is defined in
+// third_party where DEVICE_TPU_XLA_JIT is not visible. Most Ops don't need a
+// specific TPU whitelist, but ArgMax does because it has a separate CustomCall
+// implementation on CPU.
+REGISTER_XLA_OP(Name("ArgMax")
+                    .Device(DEVICE_TPU_XLA_JIT)
+                    .CompileTimeConstantInput("dimension"),
+                XlaArgMaxOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/infeed_op.cc b/tensorflow/core/tpu/kernels/xla/infeed_op.cc
new file mode 100644
index 0000000..941a543
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/infeed_op.cc
@@ -0,0 +1,162 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/c_api_decl.h"
+
+namespace tensorflow {
+
+namespace {
+
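+// Asks the TPU transfer manager (through the C API) for the layout the infeed
+// expects for `shape`, and returns `shape` with that layout applied.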
+xla::Shape GetTPUInfeedLayout(const xla::Shape& shape) {
+  XLA_Shape c_shape;
+  XLA_Shape c_infeed_shape;
+
+  ApiConverter::ToC(shape, &c_shape);
+
+  tpu::ExecutorApiFn()->TpuTransferManager_GetInfeedLayoutFn(&c_shape,
+                                                             &c_infeed_shape);
+  xla::Shape infeed_shape = ApiConverter::FromC(&c_infeed_shape);
+  ApiConverter::Free(&c_shape);
+  ApiConverter::Free(&c_infeed_shape);
+  return infeed_shape;
+}
+
+// Updates the layout of the given infeed shape, optionally considering the
+// sharding of the op. If the op has tile sharding, assign the layout based on
+// the shard shape.
+Status UpdateInfeedLayout(xla::Shape* shape,
+                          absl::optional<xla::OpSharding> sharding) {
+  if (sharding && sharding->type() == xla::OpSharding::OTHER) {
+    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
+                        xla::HloSharding::FromProto(*sharding));
+    for (int64 i = 0; i < sharding->tile_assignment_devices_size(); ++i) {
+      auto device = sharding->tile_assignment_devices(i);
+      auto shard_shape =
+          GetTPUInfeedLayout(hlo_sharding.TileShape(*shape, device));
+      if (i == 0) {
+        *shape->mutable_layout() = shard_shape.layout();
+      }
+      if (xla::ShapeUtil::ElementsIn(shard_shape) == 0) {
+        // Shards with zero elements may be assigned a different layout, but
+        // it doesn't matter since we're not sending any data.
+        continue;
+      }
+      if (!xla::LayoutUtil::Equal(shard_shape.layout(), shape->layout())) {
+        return xla::Unimplemented(
+            "Sharded infeed with non-uniform layouts is not supported. Try "
+            "turning off the infeed layout optimization "
+            "(--transpose_tpu_infeed=false) and report to XLA team.");
+      }
+    }
+    return Status::OK();
+  }
+  *shape = GetTPUInfeedLayout(*shape);
+  return Status::OK();
+}
+
+// TODO(pbar) Work out if we need to Infeed Tuples - if so then
+// this op will need a way to provide a list of shapes
+// since they can't be provided by the runtime JIT mechanism.
+// (InfeedDequeue has no inputs!)
+// Compare this op to tf.Queue operations which operate on N tensors.
+
+// This TensorFlow op supports the XLA Infeed primitive.
+class InfeedDequeueOp : public XlaOpKernel {
+ public:
+  explicit InfeedDequeueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::XlaBuilder* b = ctx->builder();
+    OP_REQUIRES_OK(ctx, UpdateInfeedLayout(&xla_shape_, b->sharding()));
+    ctx->SetOutput(0, xla::Infeed(b, xla_shape_));
+  }
+
+ private:
+  TensorShape shape_;
+  DataType dtype_;
+  xla::Shape xla_shape_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(InfeedDequeueOp);
+};
+
+REGISTER_XLA_OP(Name("InfeedDequeue"), InfeedDequeueOp);
+
+// This TensorFlow op supports the XLA Infeed primitive for tuple types.
+class InfeedDequeueTupleOp : public XlaOpKernel {
+ public:
+  explicit InfeedDequeueTupleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
+    for (int i = 0; i < shapes_.size(); i++) {
+      xla::Shape xla_shape;
+      OP_REQUIRES_OK(ctx,
+                     TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
+      xla_shapes_.push_back(xla_shape);
+    }
+  }
+
+  ~InfeedDequeueTupleOp() override {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::XlaBuilder* b = ctx->builder();
+    for (int64 i = 0; i < xla_shapes_.size(); ++i) {
+      absl::optional<xla::OpSharding> sharding;
+      if (b->sharding()) {
+        sharding = b->sharding()->type() == xla::OpSharding::TUPLE
+                       ? b->sharding()->tuple_shardings(i)
+                       : b->sharding();
+      }
+      OP_REQUIRES_OK(ctx, UpdateInfeedLayout(&xla_shapes_[i], sharding));
+    }
+    tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_);
+    auto tuple = xla::Infeed(b, tuple_shape_);
+
+    // Don't apply the infeed tuple sharding to the get-tuple-elements. They
+    // need non-tuple shardings.
+    xla::XlaScopedShardingAssignment clear_sharding(b, absl::nullopt);
+    for (int i = 0; i < shapes_.size(); ++i) {
+      ctx->SetOutput(i, xla::GetTupleElement(tuple, i));
+    }
+  }
+
+ private:
+  std::vector<TensorShape> shapes_;
+  DataTypeVector dtypes_;
+  std::vector<xla::Shape> xla_shapes_;
+  xla::Shape tuple_shape_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(InfeedDequeueTupleOp);
+};
+
+REGISTER_XLA_OP(Name("InfeedDequeueTuple"), InfeedDequeueTupleOp);
+
+}  // anonymous namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/inplace_ops.cc b/tensorflow/core/tpu/kernels/xla/inplace_ops.cc
new file mode 100644
index 0000000..9baffd6
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/inplace_ops.cc
@@ -0,0 +1,142 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include "tensorflow/compiler/tf2xla/lib/scatter.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+namespace {
+
+class InplaceUpdateOp : public XlaOpKernel {
+ public:
+  explicit InplaceUpdateOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    VLOG(3) << "InplaceUpdateOp::Compile";
+
+    DataType index_type = input_type(1);
+    OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
+                errors::InvalidArgument("index must be int32 or int64"));
+
+    // TF Args are X, I, V
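+    // Semantics (illustrative): for each k, the row X[I[k], ...] is
+    // overwritten with V[k, ...].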
+    const TensorShape x_shape = ctx->InputShape(0);
+    const TensorShape i_shape = ctx->InputShape(1);
+    const TensorShape v_shape = ctx->InputShape(2);
+
+    OP_REQUIRES(ctx,
+                TensorShapeUtils::IsScalar(i_shape) ||
+                    TensorShapeUtils::IsVector(i_shape),
+                errors::InvalidArgument("index must be Rank 0 or 1"));
+    OP_REQUIRES(ctx, (x_shape.dims() == v_shape.dims()),
+                errors::InvalidArgument("X and V must have the same Rank,"
+                                        " X.shape=",
+                                        x_shape.DebugString(),
+                                        " V.shape=", v_shape.DebugString()));
+
+    auto* builder = ctx->builder();
+    auto const_zero = xla::ConstantR0(builder, 0);
+    auto current = ctx->Input(0);
+
+    for (int64 i = 0; i < i_shape.num_elements(); i++) {
+      std::vector<xla::XlaOp> update_indices;
+      update_indices.push_back(
+          xla::Reshape(xla::SliceInDim(ctx->Input(1), i, i + 1, 1, 0), {}));
+      for (int xi = 1; xi < x_shape.dims(); xi++) {
+        update_indices.push_back(const_zero);
+      }
+      current = xla::DynamicUpdateSlice(
+          current, xla::SliceInDim(ctx->Input(2), i, i + 1, 1, 0),
+          update_indices);
+    }
+    ctx->SetOutput(0, current);
+
+    // TODO(b/118122460): Uncomment+format this code to use XLA Scatter.
+    //     auto* builder = ctx->builder();
+    //     const auto initial = ctx->Input(0);
+    //     const auto indices = ctx->Input(1);
+    //     const auto updates = ctx->Input(2);
+    //
+    //     auto result = XlaScatter(
+    //         initial, updates, indices, /*indices_are_vectors=*/false,
+    //         [](xla::XlaOp, xla::XlaOp second, xla::XlaBuilder*) { return
+    //         second; }, builder);
+    //     OP_REQUIRES_OK(ctx, result.status());
+    //     ctx->SetOutput(0, result.ValueOrDie());
+  }
+};
+
+REGISTER_XLA_OP(Name("InplaceUpdate"), InplaceUpdateOp);
+
+class InplaceAddOp : public XlaOpKernel {
+ public:
+  explicit InplaceAddOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    VLOG(3) << "InplaceAddOp::Compile";
+
+    DataType index_type = input_type(1);
+    OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
+                errors::InvalidArgument("index must be int32 or int64"));
+
+    // TF Args are X, I, V
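+    // Semantics (illustrative): the slice X[i:i+1, ...] selected by the single
+    // index i is replaced with X[i:i+1, ...] + V.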
+    const TensorShape x_shape = ctx->InputShape(0);
+    const TensorShape i_shape = ctx->InputShape(1);
+    const TensorShape v_shape = ctx->InputShape(2);
+    OP_REQUIRES(ctx,
+                (TensorShapeUtils::IsScalar(i_shape) ||
+                 ((i_shape.dims() == 1) && (i_shape.num_elements() == 1))),
+                errors::InvalidArgument("index must be Rank 1 and size 1"));
+    OP_REQUIRES(ctx, (x_shape.dims() == v_shape.dims()),
+                errors::InvalidArgument("X and V must have the same Rank,"
+                                        " X.shape=",
+                                        x_shape.DebugString(),
+                                        " V.shape=", v_shape.DebugString()));
+    // Pad the indices out to match the rank of params.
+    auto* builder = ctx->builder();
+    std::vector<xla::XlaOp> padded_indices;
+    padded_indices.push_back(xla::Reshape(ctx->Input(1), {}));
+    for (int i = 0; i < x_shape.dims() - 1; ++i) {
+      padded_indices.push_back(XlaHelpers::Zero(builder, index_type));
+    }
+
+    std::vector<int64> sizes;
+    sizes.push_back(1);
+    for (int i = 1; i < x_shape.dims(); i++) {
+      sizes.push_back(x_shape.dim_size(i));
+    }
+
+    auto prev = xla::DynamicSlice(ctx->Input(0), padded_indices, sizes);
+    auto updated = xla::Add(prev, ctx->Input(2));
+    auto result =
+        xla::DynamicUpdateSlice(ctx->Input(0), updated, padded_indices);
+    ctx->SetOutput(0, result);
+  }
+};
+
+REGISTER_XLA_OP(Name("InplaceAdd"), InplaceAddOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc b/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc
new file mode 100644
index 0000000..8abdd3d
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc
@@ -0,0 +1,91 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+
+namespace {
+
+// This TensorFlow op implements the XLA Outfeed primitive.
+class OutfeedEnqueueOp : public XlaOpKernel {
+ public:
+  explicit OutfeedEnqueueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::Shape xla_shape;
+    OP_REQUIRES_OK(
+        ctx, TensorShapeToXLAShape(dtype_, ctx->InputShape(0), &xla_shape));
+    // Outfeed configuration is only needed for embedding outfeed.
+    const string outfeed_config;
+    xla::Outfeed(ctx->Input(0), xla_shape, outfeed_config);
+  }
+
+ private:
+  DataType dtype_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(OutfeedEnqueueOp);
+};
+
+REGISTER_XLA_OP(Name("OutfeedEnqueue"), OutfeedEnqueueOp);
+
+// This TensorFlow op implements the XLA Outfeed primitive for tuple types.
+class OutfeedEnqueueTupleOp : public XlaOpKernel {
+ public:
+  explicit OutfeedEnqueueTupleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    std::vector<xla::XlaOp> handles;
+    std::vector<TensorShape> shapes;
+    OP_REQUIRES_OK(ctx, ctx->InputList("inputs", &handles, &shapes));
+
+    std::vector<xla::Shape> xla_shapes;
+    for (int i = 0; i < shapes.size(); ++i) {
+      xla::Shape xla_shape;
+      OP_REQUIRES_OK(ctx,
+                     TensorShapeToXLAShape(dtypes_[i], shapes[i], &xla_shape));
+      xla_shapes.push_back(xla_shape);
+    }
+    xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(xla_shapes);
+    VLOG(1) << "OutfeedEnqueueTuple: "
+            << xla::ShapeUtil::HumanStringWithLayout(tuple_shape);
+    auto b = ctx->builder();
+    auto tuple = xla::Tuple(b, handles);
+    // Outfeed configuration is only needed for embedding outfeed.
+    const string outfeed_config;
+    xla::Outfeed(tuple, tuple_shape, outfeed_config);
+  }
+
+ private:
+  DataTypeVector dtypes_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(OutfeedEnqueueTupleOp);
+};
+
+REGISTER_XLA_OP(Name("OutfeedEnqueueTuple"), OutfeedEnqueueTupleOp);
+
+}  // anonymous namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
new file mode 100644
index 0000000..f7c33e5
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
@@ -0,0 +1,145 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/scatter.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+namespace {
+// TODO(b/32945756): Add a scatter op in XLA and move this to an HLO
+// optimization pass. Optimization for UnsortedSegmentSum on TPU: use k-hot
+// matmul. This optimization requires:
+//     1. data has dtype supported by TPU matmul and has rank of 1 or 2.
+//     2. indices has rank of 1.
+//     3. matmul op count is less than 800 billion.
+//
+// Example of calculating UnsortedSegmentSum by k-hot matmul:
+//     data shape        [A, B]
+//     indices shape     [A]
+//     num_segment        N
+//     output shape      [N, B]
+//     matmul op count    N * A * B
+// Step 1: create k-hot matrix
+//     k-hot matrix has shape [N, A], where row i is responsible for
+//     collecting the sum of the i-th segment, concretely
+//            k-hot[i][j] = 1 if indices[j] == i
+// Step 2: perform matmul
+//     the final result is obtained by multiplying k-hot matrix with data
+//     matrix, namely
+//             k-hot  *  data   => result
+// shape:      [N, A] *  [A, B] => [N, B]
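+//
+// Numerical example (illustrative): indices = [0, 2, 0], num_segments = 3
+// gives the k-hot matrix [[1, 0, 1], [0, 0, 0], [0, 1, 0]], so
+// k-hot * data sums rows 0 and 2 of data into output row 0 and copies row 1
+// of data into output row 2.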
+xla::XlaOp KHotMatmul(XlaOpKernelContext* ctx, xla::XlaBuilder* builder,
+                      const xla::XlaOp data, const xla::XlaOp indices,
+                      int64 num_segments) {
+  DataType data_dtype = ctx->input_type(0);
+  xla::PrimitiveType indices_type = ctx->input_xla_type(1);
+  TensorShape data_shape = ctx->InputShape(0);
+  TensorShape indices_shape = ctx->InputShape(1);
+  xla::XlaOp linspace = xla::Iota(builder, indices_type, num_segments);
+  xla::XlaOp linspace_col = xla::Reshape(linspace, {num_segments, 1});
+  TensorShape indices_row_shape = indices_shape;
+  indices_row_shape.InsertDim(0, 1);
+  xla::XlaOp indices_row = xla::Reshape(indices, indices_row_shape.dim_sizes());
+  xla::XlaOp k_hot = xla::Eq(indices_row, linspace_col);
+  xla::XlaOp k_hot_with_data_dtype =
+      XlaHelpers::ConvertElementType(k_hot, data_dtype);
+  // F32 version of the KHotMatmul. It splits the F32 data into three
+  // BF16 partial values and runs the matmul for each of them. The final
+  // result is the sum of the three BF16 results.
+  // Note that this still doesn't fully retain f32 precision.
+  // In particular, values smaller than 2^-111 may see loss of precision.
+  xla::PrecisionConfig precision_config;
+  if (data_dtype == DT_FLOAT) {
+    precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST);
+  } else {
+    CHECK_EQ(data_dtype, DT_BFLOAT16);
+    precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT);
+  }
+  precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT);
+  return xla::Dot(k_hot_with_data_dtype, data, &precision_config);
+}
+
+class UnsortedSegmentSum : public XlaOpKernel {
+ public:
+  explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    // output = unsorted_segment_sum(data, indices, num_segments)
+    // Compute a tensor such that:
+    //    output[i] = sum over {j where indices[j] == i} of data[j]
+    //    output[i] == 0 if i does not appear in indices
+    //
+    // Contrast with segment_sum(), which assumes indices are sorted and that
+    // max(indices)+1 is the desired size of the output.
+    //
+    // The returned output tensor has the same type as data, and the same shape
+    // as data, except that the first indices.rank dimensions are replaced
+    // by a single dimension with size num_segments.
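+    //
+    // For example (illustrative): data = [[1, 2], [3, 4], [5, 6]],
+    // indices = [0, 2, 0], num_segments = 3 gives
+    // output = [[6, 8], [0, 0], [3, 4]].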
+    xla::XlaOp data = ctx->Input(0);
+    TensorShape data_shape = ctx->InputShape(0);
+
+    xla::XlaOp indices = ctx->Input(1);
+    TensorShape indices_shape = ctx->InputShape(1);
+
+    int64 num_segments;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
+
+    OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
+                errors::InvalidArgument(
+                    "UnsortedSegmentSum requires that indices' rank be"
+                    " less than or equal to data's rank."));
+    // Validate that indices.shape is a prefix of data.shape.
+    for (int d = 0; d < indices_shape.dims(); ++d) {
+      OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
+                  errors::InvalidArgument(
+                      "UnsortedSegmentSum requires indices shape to be prefix"
+                      " of data_shape, but dimension ",
+                      d, " differs ", data_shape.dim_size(d), " vs. ",
+                      indices_shape.dim_size(d)));
+    }
+    xla::XlaBuilder* builder = ctx->builder();
+    TensorShape buffer_shape = data_shape;
+    buffer_shape.RemoveDimRange(0, indices_shape.dims());
+    buffer_shape.InsertDim(0, num_segments);
+    auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype_),
+                                 buffer_shape.dim_sizes());
+
+    auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
+      return a + b;
+    };
+
+    auto result = XlaScatter(buffer, /*updates=*/data, indices,
+                             /*indices_are_vectors=*/false, combiner, builder);
+    OP_REQUIRES_OK(ctx, result.status());
+    ctx->SetOutput(0, result.ValueOrDie());
+  }
+
+ private:
+  DataType dtype_;
+};
+
+REGISTER_XLA_OP(Name("UnsortedSegmentSum")
+                    .Device(DEVICE_TPU_XLA_JIT)
+                    .CompileTimeConstantInput("num_segments"),
+                UnsortedSegmentSum);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/where_op.cc b/tensorflow/core/tpu/kernels/xla/where_op.cc
new file mode 100644
index 0000000..420d5bc
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/where_op.cc
@@ -0,0 +1,91 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/comparators.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/ops_util.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+namespace {
+
+class WhereOp : public XlaOpKernel {
+ public:
+  explicit WhereOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
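+    // Overall approach: flatten the condition, stably sort the flattened
+    // element coordinates so that nonzero entries come first, and mark the
+    // number of nonzero entries as the dynamic size of dimension 0 of the
+    // [num_elements, rank] output of coordinates.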
+    xla::XlaOp condition = ctx->Input(0);
+    xla::StatusOr<xla::Shape> input_shape = ctx->builder()->GetShape(condition);
+    OP_REQUIRES_OK(ctx, input_shape.status());
+    // Use S32 as indices first, then convert to S64 in the end if needed.
+    auto iota_shape = input_shape.ValueOrDie();
+    iota_shape.set_element_type(xla::S32);
+
+    int64 flattened_size = xla::Product(iota_shape.dimensions());
+    xla::XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size});
+    xla::XlaOp zeros = xla::ZerosLike(reshaped_condition);
+    xla::XlaOp zeros_int = xla::ConvertElementType(zeros, xla::S32);
+    xla::XlaOp reshaped_condition_int =
+        xla::ConvertElementType(reshaped_condition, xla::S32);
+    xla::XlaOp compared = xla::ConvertElementType(
+        xla::Gt(reshaped_condition_int, zeros_int), xla::S32);
+    xla::XlaOp length = xla::ReduceAll(
+        compared, xla::Zero(ctx->builder(), xla::S32),
+        xla::CreateScalarAddComputation(xla::S32, ctx->builder()));
+
+    std::vector<xla::XlaOp> to_sort = {reshaped_condition_int};
+    std::vector<xla::PrimitiveType> types_to_sort = {xla::S32};
+    // Generate an iota for each dimension; concatenated together they form
+    // the coordinates of each element.
+    for (int64 axis = 0; axis < iota_shape.rank(); ++axis) {
+      xla::XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis);
+      xla::XlaOp reshaped = xla::Reshape(iota, {flattened_size});
+      to_sort.push_back(reshaped);
+      types_to_sort.push_back(xla::S32);
+    }
+
+    xla::XlaOp sorted = xla::Sort(
+        to_sort, xla::CreateScalarGtComputation(types_to_sort, ctx->builder()),
+        /*dimension=*/0,
+        /*is_stable=*/true);
+    std::vector<xla::XlaOp> to_concat;
+    for (int64 i = 0; i < iota_shape.rank(); ++i) {
+      xla::XlaOp index_single_dim = xla::GetTupleElement(sorted, i + 1);
+      to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1}));
+    }
+
+    xla::XlaOp result = xla::ConcatInDim(ctx->builder(), to_concat, 1);
+    result = xla::ConvertElementType(result, ctx->output_xla_type(0));
+    // Dynamic padder will handle the dynamic dimension.
+    xla::XlaOp result_padded = xla::SetDimensionSize(result, length, 0);
+    ctx->SetOutput(0, result_padded);
+  }
+};
+
+REGISTER_XLA_OP(Name("Where").Device(DEVICE_TPU_XLA_JIT), WhereOp);
+
+}  // namespace
+}  // namespace tensorflow