Introduce additional XLA TPU Ops to open source
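
Adds TPU-specific XLA kernels for GetItem, XlaHostCompute, XlaSendToHost,
XlaRecvFromHost, ArgMax, InfeedDequeue(Tuple), InplaceUpdate, InplaceAdd,
OutfeedEnqueue(Tuple), UnsortedSegmentSum, and Where, plus a new
//tensorflow/core/tpu/kernels/xla:xla_ops target that the existing TPU
kernels library now depends on.
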
PiperOrigin-RevId: 326343558
Change-Id: I47da1dc0c96cdf8223ccebef012e2a5088a857a4
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 6d33690..0c7ac66 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -38,6 +38,7 @@
":tpu_execute_op",
":tpu_handle_to_key_op",
":transfer_ops",
+ "//tensorflow/core/tpu/kernels/xla:xla_ops",
],
)
diff --git a/tensorflow/core/tpu/kernels/xla/BUILD b/tensorflow/core/tpu/kernels/xla/BUILD
new file mode 100644
index 0000000..f55583a
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/BUILD
@@ -0,0 +1,52 @@
+# XLA Ops for TPUs
+
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "xla_ops",
+ srcs = [
+ "get_item_op.cc",
+ "host_compute_ops.cc",
+ "index_ops.cc",
+ "infeed_op.cc",
+ "inplace_ops.cc",
+ "outfeed_ops.cc",
+ "segment_reduction_ops.cc",
+ "where_op.cc",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:sharding_util",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla:xla_context",
+ "//tensorflow/compiler/tf2xla:xla_helpers",
+ "//tensorflow/compiler/tf2xla:xla_op_registry",
+ "//tensorflow/compiler/tf2xla/kernels:if_op",
+ "//tensorflow/compiler/tf2xla/kernels:while_op",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/tf2xla/lib:scatter",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto_cc",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:comparators",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/tpu:tpu_api",
+ "//tensorflow/core/tpu:tpu_defs",
+ "//tensorflow/core/tpu/kernels:cross_replica_ops",
+ "//tensorflow/stream_executor/tpu:c_api_conversions",
+ "//tensorflow/stream_executor/tpu:c_api_decl",
+ "@com_google_absl//absl/strings",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/core/tpu/kernels/xla/get_item_op.cc b/tensorflow/core/tpu/kernels/xla/get_item_op.cc
new file mode 100644
index 0000000..094c6b8
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/get_item_op.cc
@@ -0,0 +1,75 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_util.h"
+
+namespace tensorflow {
+namespace {
+
+// The XLA kernel that builds up the computation for get_item(data, index).
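+// Given data of shape [d0, d1, ..., dn] and a size-1 index vector {i}, the
+// result is data[i] of shape [d1, ..., dn]: a DynamicSlice of the first
+// dimension followed by a Reshape that collapses the leading size-1 dimension.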
+class GetItemXlaOp : public XlaOpKernel {
+ public:
+ explicit GetItemXlaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape& data_shape = ctx->InputShape(0);
+ const TensorShape& index_shape = ctx->InputShape(1);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVectorOrHigher(data_shape),
+ errors::InvalidArgument("data must be at least 1 dimensional."));
+ OP_REQUIRES(ctx, index_shape.dims() == 1 && index_shape.dim_size(0) == 1,
+ errors::InvalidArgument("index must be a vector of size 1."));
+
+ // NOTE(pbar) Use Concat to extend the indices to match cl/142279605.
+ // This isn't the simplest way to emit the indices, but the code for
+ // dynamic slice needs to be able to see that minor dims are const zero.
+ auto const_zero = xla::ConstantR0(ctx->builder(), 0);
+ std::vector<xla::XlaOp> operands;
+ operands.push_back(xla::Reshape(ctx->Input(1), {}));
+ for (int i = 1; i < data_shape.dims(); i++) {
+ operands.push_back(const_zero);
+ }
+
+ std::vector<int64> dims = {0};
+ std::vector<int64> slice_sizes = {1};
+ std::vector<int64> out_sizes = {};
+ for (int i = 1; i < data_shape.dims(); i++) {
+ dims.push_back(i);
+ auto size = data_shape.dim_size(i);
+ slice_sizes.push_back(size);
+ out_sizes.push_back(size);
+ }
+ // NOTE: DynamicSlice here doesn't raise an error or wrap the index
+ // if it's out of range.
+ auto slice = xla::DynamicSlice(ctx->Input(0), operands, slice_sizes);
+ // In-order collapse to remove the 1st dim.
+ auto reshape = xla::Reshape(slice, dims, out_sizes);
+ ctx->SetOutput(0, reshape);
+ }
+};
+
+REGISTER_XLA_OP(Name("GetItem"), GetItemXlaOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc b/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc
new file mode 100644
index 0000000..be3ee1c
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc
@@ -0,0 +1,498 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_constructor.h"
+#include "tensorflow/core/common_runtime/lower_function_call_op.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+
+namespace {
+
+// TODO(phawkins) add a canonical copy of these operator names and refactor
+// everything to use it.
+static const char* const kSendFromHostOp = "_XlaSendFromHost";
+static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
+
+Status MakeXlaShapes(gtl::ArraySlice<TensorShape> shapes,
+ gtl::ArraySlice<DataType> dtypes,
+ std::vector<xla::Shape>* xla_shapes,
+ xla::Shape* xla_shape) {
+ for (int i = 0; i < shapes.size(); i++) {
+ xla::Shape single_xla_shape;
+ TF_RETURN_IF_ERROR(
+ TensorShapeToXLAShape(dtypes[i], shapes[i], &single_xla_shape));
+ VLOG(2) << "Shape " << single_xla_shape.DebugString();
+ xla_shapes->push_back(single_xla_shape);
+ }
+ // Temporarily add a dummy output to the shape array before making the tuple:
+ // this output is used for control dependencies between host compute ops.
+ xla_shapes->push_back(xla::ShapeUtil::MakeShape(xla::PRED, {}));
+ *xla_shape = xla::ShapeUtil::MakeTupleShape(*xla_shapes);
+ // Remove the dummy output from the vector that will be used to copy real
+ // outputs from host to device.
+ xla_shapes->pop_back();
+ return Status::OK();
+}
+
+// This TensorFlow pseudo-op is used to record host-side computation.
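+// At compile time it emits a device-to-host transfer (SendToHost) for every
+// input and a host-to-device transfer (RecvFromHost) for every output,
+// registers the transfer shapes and dtypes with the XlaCompiler under `key_`,
+// and threads XLA tokens through the transfers to preserve their ordering.
+// Output shapes either come from the `shapes` attr or are inferred at compile
+// time by running shape inference over `shape_inference_graph`.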
+class HostComputeOp : public XlaOpKernel {
+ public:
+ explicit HostComputeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("cost_estimate_ns", &cost_estimate_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("tpu_core", &tpu_core_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_dtypes_));
+ OP_REQUIRES(ctx, ctx->num_inputs() == input_dtypes_.size(),
+ errors::InvalidArgument("Tinputs size=", input_dtypes_.size(),
+ " but expected ", ctx->num_inputs(),
+ " inputs."));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_dtypes_));
+ OP_REQUIRES(ctx, ctx->num_outputs() == output_dtypes_.size(),
+ errors::InvalidArgument("Toutputs size=", output_dtypes_.size(),
+ " but expected ", ctx->num_outputs(),
+ " outputs."));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("ancestors", &ancestors_));
+ NameAttrList shape_inference_graph;
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr("shape_inference_graph", &shape_inference_graph));
+ if (shape_inference_graph.name().empty()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &static_output_shapes_));
+ OP_REQUIRES(ctx, static_output_shapes_.size() == output_dtypes_.size(),
+ errors::InvalidArgument(
+ "shapes attr list size ", static_output_shapes_.size(),
+ " differs from dtypes size ", output_dtypes_.size()));
+ OP_REQUIRES_OK(ctx, MakeXlaShapes(static_output_shapes_, output_dtypes_,
+ &static_xla_output_shapes_,
+ &static_xla_output_shape_));
+ VLOG(2) << "Output Shape: " << static_xla_output_shape_.DebugString();
+ } else {
+ FunctionLibraryRuntime* flib_runtime = ctx->function_library();
+ OP_REQUIRES(ctx, flib_runtime != nullptr,
+ errors::Internal(
+ "No function library runtime at kernel construction"));
+ const FunctionLibraryDefinition* library =
+ flib_runtime->GetFunctionLibraryDefinition();
+ const FunctionDef* fdef = library->Find(shape_inference_graph.name());
+ OP_REQUIRES(ctx, fdef != nullptr,
+ errors::Internal("Failed to find function ",
+ shape_inference_graph.name(),
+ " in function library."));
+ OP_REQUIRES_OK(ctx, FunctionDefToBodyHelper(
+ *fdef, AttrSlice(&shape_inference_graph.attr()),
+ library, &shape_inference_graph_function_));
+ VLOG(2) << "Output Shape to be inferred at compile time";
+ }
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
+ OP_REQUIRES(ctx, !token_input_nodes_.empty(),
+ errors::InvalidArgument("XlaHostCompute node does not have ",
+ kXlaTokenInputNodesAttrName, " attr"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
+ &original_node_name_));
+ }
+
+ ~HostComputeOp() override {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ XlaCompiler* compiler = ctx->compiler();
+
+ std::vector<xla::XlaOp> input_handles;
+ std::vector<TensorShape> input_shapes;
+ OP_REQUIRES_OK(ctx,
+ ctx->InputList("inputs", &input_handles, &input_shapes));
+ const auto device_sharding = xla::sharding_builder::AssignDevice(tpu_core_);
+ xla::XlaScopedShardingAssignment assign_sharding(b, device_sharding);
+
+ std::vector<xla::XlaOp> input_tokens;
+ for (auto& token_input_node : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(token_input_node);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ input_tokens.push_back(token_or.ValueOrDie());
+ }
+ xla::XlaOp token = xla::AfterAll(b, input_tokens);
+
+ // Send values to the host.
+ std::vector<xla::XlaOp> send_to_host_tokens;
+ for (int i = 0; i < input_handles.size(); ++i) {
+ const string channel_name = absl::StrCat(key_, "_dtoh_", i);
+ xla::Shape xla_shape;
+ OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(input_dtypes_[i],
+ input_shapes[i], &xla_shape));
+ // Specify frontend attributes.
+ xla::FrontendAttributes attrs;
+ (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
+ (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
+ xla::primitive_util::LowercasePrimitiveTypeName(
+ xla_shape.element_type());
+ b->SetFrontendAttributes(attrs);
+ xla::ChannelHandle channel;
+ OP_REQUIRES_OK(
+ ctx, compiler->GetDeviceToHostChannelHandle(channel_name, &channel));
+ send_to_host_tokens.push_back(
+ xla::SendToHost(input_handles[i], token, xla_shape, channel));
+ b->ClearOpMetadata();
+ }
+ xla::XlaOp recv_from_host_token_input =
+ send_to_host_tokens.empty() ? token
+ : xla::AfterAll(b, send_to_host_tokens);
+ if (!input_handles.empty()) {
+ // Register the shapes used in this transfer.
+ OP_REQUIRES_OK(ctx, ctx->compiler()->SetDeviceToHostMetadata(
+ key_, input_dtypes_, input_shapes));
+ }
+ // Compute the shapes of the values to copy to the device, if necessary.
+ std::vector<TensorShape>* output_shapes;
+ std::vector<xla::Shape>* xla_output_shapes;
+ xla::Shape* xla_output_shape;
+ std::vector<TensorShape> inferred_output_shapes;
+ std::vector<xla::Shape> inferred_xla_output_shapes;
+ xla::Shape inferred_xla_output_shape;
+ if (shape_inference_graph_function_) {
+ OP_REQUIRES_OK(
+ ctx, InferOutputShapes(
+ ctx, ctx->function_library()->GetFunctionLibraryDefinition(),
+ &inferred_output_shapes));
+ OP_REQUIRES_OK(ctx, MakeXlaShapes(inferred_output_shapes, output_dtypes_,
+ &inferred_xla_output_shapes,
+ &inferred_xla_output_shape));
+ output_shapes = &inferred_output_shapes;
+ xla_output_shapes = &inferred_xla_output_shapes;
+ xla_output_shape = &inferred_xla_output_shape;
+ } else {
+ output_shapes = &static_output_shapes_;
+ xla_output_shapes = &static_xla_output_shapes_;
+ xla_output_shape = &static_xla_output_shape_;
+ }
+ OP_REQUIRES(
+ ctx, output_shapes->size() == ctx->num_outputs(),
+ errors::InvalidArgument("Op has ", ctx->num_outputs(), " outputs ",
+ " but output shape vector of size ",
+ output_shapes->size()));
+ if (ctx->num_outputs() > 0) {
+ // Register the shapes used in this transfer.
+ OP_REQUIRES_OK(ctx, ctx->compiler()->SetHostToDeviceMetadata(
+ key_, output_dtypes_, *output_shapes));
+ }
+ // Copy results to the device.
+ std::vector<xla::XlaOp> recv_from_host_tokens;
+ for (int i = 0; i < output_shapes->size(); ++i) {
+ const string channel_name = absl::StrCat(key_, "_htod_", i);
+ // Specify frontend attributes.
+ xla::FrontendAttributes attrs;
+ (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = channel_name;
+ (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
+ xla::primitive_util::LowercasePrimitiveTypeName(
+ xla_output_shapes->at(i).element_type());
+ b->SetFrontendAttributes(attrs);
+ xla::ChannelHandle channel;
+ OP_REQUIRES_OK(
+ ctx, compiler->GetHostToDeviceChannelHandle(channel_name, &channel));
+
+ const auto result_token_tuple = xla::RecvFromHost(
+ recv_from_host_token_input, xla_output_shapes->at(i), channel);
+ b->ClearOpMetadata();
+ recv_from_host_tokens.push_back(
+ xla::GetTupleElement(result_token_tuple, /*index=*/1));
+ ctx->SetOutput(i, xla::GetTupleElement(result_token_tuple, 0));
+ }
+
+ // Set token output.
+ xla::XlaOp token_output = recv_from_host_tokens.empty()
+ ? recv_from_host_token_input
+ : xla::AfterAll(b, recv_from_host_tokens);
+ OP_REQUIRES_OK(
+ ctx, ctx->compiler()->SetNodeToken(original_node_name_, token_output));
+ }
+
+ private:
+ Status LowerFunctionalOps(Graph* g,
+ const FunctionLibraryDefinition& flib_def) {
+ bool modified;
+ do {
+ modified = false;
+
+ // Lower "If" nodes first. Their body functions will be expanded as
+ // function call nodes, which we will lower later.
+ // We do not need to lower "While" nodes because shape inference can
+ // handle them correctly (output shapes are input shapes).
+ std::vector<Node*> if_nodes;
+ for (Node* n : g->op_nodes()) {
+ if (n->type_string() == "If") {
+ if_nodes.push_back(n);
+ }
+ }
+ for (Node* if_node : if_nodes) {
+ TF_RETURN_IF_ERROR(
+ RewriteIfNode(if_node, g, /*keep_node_fetchable=*/false));
+ }
+ if (!if_nodes.empty()) {
+ modified = true;
+ }
+
+ // Lower function call nodes.
+ std::vector<Node*> call_nodes;
+ for (Node* n : g->op_nodes()) {
+ if (IsFunctionCall(flib_def, *n)) {
+ call_nodes.push_back(n);
+ }
+ }
+ for (Node* call_node : call_nodes) {
+ TF_RETURN_IF_ERROR(RewriteFunctionCallNode(
+ call_node, g, flib_def, /*keep_caller_fetchable=*/false));
+ }
+ if (!call_nodes.empty()) {
+ modified = true;
+ }
+ } while (modified);
+
+ return Status::OK();
+ }
+
+ Status InferOutputShapes(XlaOpKernelContext* ctx,
+ const FunctionLibraryDefinition* flib_def,
+ std::vector<TensorShape>* output_shapes) {
+ // Use the graph of the shape inference function instantiated at kernel
+ // construction time; no shape inference has been run on it yet.
+ Graph* graph = shape_inference_graph_function_->graph;
+
+ // Lower functional ops, because they are not friendly to shape inference.
+ TF_RETURN_IF_ERROR(LowerFunctionalOps(graph, *flib_def));
+
+ // Now run shape inference, filling in the shapes of _XlaRecvAtHost nodes.
+ bool got_output_shapes = false;
+ ShapeRefiner shape_refiner{graph->versions().producer(),
+ graph->op_registry()};
+ std::vector<Node*> nodes;
+ GetReversePostOrder(*graph, &nodes);
+ for (auto node : nodes) {
+ TF_RETURN_IF_ERROR(shape_refiner.AddNode(node));
+ if (node->type_string() == kRecvAtHostOp) {
+ const AttrValue* key_attr = node->attrs().Find("key");
+ if (key_attr == nullptr) {
+ return errors::InvalidArgument("Node ", node->name(),
+ " has no key attribute");
+ }
+ std::vector<TensorShape> dtoh_shapes;
+ if (!ctx->compiler()
+ ->GetDeviceToHostShapes(key_attr->s(), &dtoh_shapes)
+ .ok()) {
+ return errors::InvalidArgument(
+ "Shape inference for HostCompute ", ctx->op_kernel().name(),
+ " failed: host recv node ", node->name(), " with key '",
+ key_attr->s(), "' has unknown shapes.");
+ }
+ if (dtoh_shapes.size() != node->num_outputs()) {
+ return errors::InvalidArgument(
+ "Shape inference for HostCompute ", ctx->op_kernel().name(),
+ " failed: host recv node ", node->name(), " with key '",
+ key_attr->s(), "' has ", node->num_outputs(),
+ " outputs but inferred shapes expect ", dtoh_shapes.size());
+ }
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ shape_inference::InferenceContext* shape_ctx =
+ shape_refiner.GetContext(node);
+ shape_inference::ShapeHandle handle;
+ TF_RETURN_IF_ERROR(
+ shape_ctx->MakeShapeFromTensorShape(dtoh_shapes.at(i), &handle));
+ shape_ctx->set_output(i, handle);
+ }
+ } else if (node->type_string() == kSendFromHostOp) {
+ if (got_output_shapes) {
+ return errors::InvalidArgument(
+ "Shape inference for HostCompute ", ctx->op_kernel().name(),
+ " failed: inference graph has multiple send from host nodes");
+ } else {
+ got_output_shapes = true;
+ // The last input is the dynamic key so don't record its shape.
+ output_shapes->resize(node->num_inputs() - 1);
+ shape_inference::InferenceContext* shape_ctx =
+ shape_refiner.GetContext(node);
+ for (int i = 0; i < node->num_inputs() - 1; ++i) {
+ shape_inference::ShapeHandle handle = shape_ctx->input(i);
+ if (!shape_ctx->FullyDefined(handle)) {
+ return errors::InvalidArgument(
+ "Shape inference for HostCompute ", ctx->op_kernel().name(),
+ " failed: send from host node ", node->name(),
+ " has non-fully defined shape of input index ", i);
+ }
+ TensorShapeProto shape_proto;
+ shape_ctx->ShapeHandleToProto(handle, &shape_proto);
+ (*output_shapes)[i] = TensorShape(shape_proto);
+ VLOG(2) << "Inferred shape " << shape_proto.DebugString();
+ }
+ }
+ }
+ }
+ if (!got_output_shapes) {
+ return errors::InvalidArgument(
+ "Shape inference for HostCompute ", ctx->op_kernel().name(),
+ " failed: inference graph has no send from host node");
+ }
+ return Status::OK();
+ }
+
+ DataTypeVector input_dtypes_;
+ DataTypeVector output_dtypes_;
+ std::vector<string> ancestors_;
+ std::vector<TensorShape> static_output_shapes_;
+ std::vector<xla::Shape> static_xla_output_shapes_;
+ string original_node_name_;
+ // Tuple of static_xla_output_shapes_ plus the trailing dummy PRED element
+ // that MakeXlaShapes adds for control dependencies between host compute ops.
+ xla::Shape static_xla_output_shape_;
+ string key_;
+ // If shape inference is performed at runtime, the graph needed to perform
+ // shape inference is stored in this function.
+ std::unique_ptr<FunctionBody> shape_inference_graph_function_;
+ int64 cost_estimate_;
+ int64 tpu_core_;
+ std::vector<string> token_input_nodes_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(HostComputeOp);
+};
+
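+// Compiles XlaSendToHost: sends a single tensor from device to host over the
+// channel named by the `key` attr, threading the XLA token chain so host
+// transfers keep their ordering.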
+class SendToHostOp : public XlaOpKernel {
+ public:
+ explicit SendToHostOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinput", &input_dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
+ OP_REQUIRES(ctx, !token_input_nodes_.empty(),
+ errors::InvalidArgument("XlaSendToHost node does not have ",
+ kXlaTokenInputNodesAttrName, " attr"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
+ &original_node_name_));
+ }
+
+ ~SendToHostOp() override {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+
+ XlaCompiler* compiler = ctx->compiler();
+ xla::XlaOp operand = ctx->Input(0);
+ std::vector<xla::XlaOp> input_tokens;
+ for (auto& token_input_node : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(token_input_node);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ input_tokens.push_back(token_or.ValueOrDie());
+ }
+ xla::XlaOp token = xla::AfterAll(b, input_tokens);
+ xla::Shape xla_shape;
+ OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(input_dtype_, ctx->InputShape(0),
+ &xla_shape));
+ // Specify frontend attributes.
+ xla::FrontendAttributes attrs;
+ (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
+ (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
+ xla::primitive_util::LowercasePrimitiveTypeName(
+ xla_shape.element_type());
+ b->SetFrontendAttributes(attrs);
+ xla::ChannelHandle channel;
+ OP_REQUIRES_OK(ctx, compiler->GetDeviceToHostChannelHandle(key_, &channel));
+ xla::XlaOp output_token =
+ xla::SendToHost(operand, token, xla_shape, channel);
+ OP_REQUIRES_OK(ctx,
+ compiler->SetNodeToken(original_node_name_, output_token));
+ }
+
+ private:
+ DataType input_dtype_;
+ string key_;
+ std::vector<string> token_input_nodes_;
+ string original_node_name_;
+ TF_DISALLOW_COPY_AND_ASSIGN(SendToHostOp);
+};
+
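+// Compiles XlaRecvFromHost: receives a tensor of the statically known `shape`
+// attr from the host over the channel named by the `key` attr, again
+// threading the XLA token chain for ordering.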
+class RecvFromHostOp : public XlaOpKernel {
+ public:
+ explicit RecvFromHostOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutput", &output_dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &output_shape_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_));
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_));
+ OP_REQUIRES(ctx, !token_input_nodes_.empty(),
+ errors::InvalidArgument("XlaRecvFromHost node does not have ",
+ kXlaTokenInputNodesAttrName, " attr"));
+ }
+
+ ~RecvFromHostOp() override {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+
+ XlaCompiler* compiler = ctx->compiler();
+ std::vector<xla::XlaOp> input_tokens;
+ for (auto& token_input_node : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(token_input_node);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ input_tokens.push_back(token_or.ValueOrDie());
+ }
+ xla::XlaOp token = xla::AfterAll(b, input_tokens);
+ xla::Shape xla_shape;
+ OP_REQUIRES_OK(
+ ctx, TensorShapeToXLAShape(output_dtype_, output_shape_, &xla_shape));
+ // Specify frontend attributes.
+ xla::FrontendAttributes attrs;
+ (*attrs.mutable_map())[kXlaHostTransferRendezvousNameAttr] = key_;
+ (*attrs.mutable_map())[kXlaHostTransferOriginalTypeAttr] =
+ xla::primitive_util::LowercasePrimitiveTypeName(
+ xla_shape.element_type());
+ b->SetFrontendAttributes(attrs);
+ xla::ChannelHandle channel;
+ OP_REQUIRES_OK(ctx, compiler->GetHostToDeviceChannelHandle(key_, &channel));
+ xla::XlaOp result = xla::RecvFromHost(token, xla_shape, channel);
+ // xla::RecvFromHost returns a tuple of (received data, token).
+ ctx->SetOutput(0, xla::GetTupleElement(result, 0));
+ OP_REQUIRES_OK(
+ ctx, compiler->SetNodeToken(name(), xla::GetTupleElement(result, 1)));
+ }
+
+ private:
+ DataType output_dtype_;
+ TensorShape output_shape_;
+ string key_;
+ std::vector<string> token_input_nodes_;
+ TF_DISALLOW_COPY_AND_ASSIGN(RecvFromHostOp);
+};
+
+REGISTER_XLA_OP(Name("XlaHostCompute"), HostComputeOp);
+REGISTER_XLA_OP(Name("XlaSendToHost"), SendToHostOp);
+REGISTER_XLA_OP(Name("XlaRecvFromHost"), RecvFromHostOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/index_ops.cc b/tensorflow/core/tpu/kernels/xla/index_ops.cc
new file mode 100644
index 0000000..40148f3
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/index_ops.cc
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/kernels/index_ops.h"
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+namespace {
+
+// This registration is needed here because the ArgMax Op is defined in
+// third_party where DEVICE_TPU_XLA_JIT is not visible. Most Ops don't need a
+// specific TPU whitelist, but ArgMax does because it has a separate CustomCall
+// implementation on CPU.
+REGISTER_XLA_OP(Name("ArgMax")
+ .Device(DEVICE_TPU_XLA_JIT)
+ .CompileTimeConstantInput("dimension"),
+ XlaArgMaxOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/infeed_op.cc b/tensorflow/core/tpu/kernels/xla/infeed_op.cc
new file mode 100644
index 0000000..941a543
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/infeed_op.cc
@@ -0,0 +1,162 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/c_api_decl.h"
+
+namespace tensorflow {
+
+namespace {
+
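+// Returns `shape` with the layout that the TPU transfer manager expects for
+// infeed, queried through the TPU C API.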
+xla::Shape GetTPUInfeedLayout(const xla::Shape& shape) {
+ XLA_Shape c_shape;
+ XLA_Shape c_infeed_shape;
+
+ ApiConverter::ToC(shape, &c_shape);
+
+ tpu::ExecutorApiFn()->TpuTransferManager_GetInfeedLayoutFn(&c_shape,
+ &c_infeed_shape);
+ xla::Shape infeed_shape = ApiConverter::FromC(&c_infeed_shape);
+ ApiConverter::Free(&c_shape);
+ ApiConverter::Free(&c_infeed_shape);
+ return infeed_shape;
+}
+
+// Updates the layout of the given infeed shape, optionally considering the
+// sharding of the op. If the op has tile sharding, assign the layout based on
+// the shard shape.
+Status UpdateInfeedLayout(xla::Shape* shape,
+ absl::optional<xla::OpSharding> sharding) {
+ if (sharding && sharding->type() == xla::OpSharding::OTHER) {
+ TF_ASSIGN_OR_RETURN(auto hlo_sharding,
+ xla::HloSharding::FromProto(*sharding));
+ for (int64 i = 0; i < sharding->tile_assignment_devices_size(); ++i) {
+ auto device = sharding->tile_assignment_devices(i);
+ auto shard_shape =
+ GetTPUInfeedLayout(hlo_sharding.TileShape(*shape, device));
+ if (i == 0) {
+ *shape->mutable_layout() = shard_shape.layout();
+ }
+ if (xla::ShapeUtil::ElementsIn(shard_shape) == 0) {
+ // Shapes with 0 elements may be assigned a different layout, but it
+ // doesn't matter since we're not sending any data.
+ continue;
+ }
+ if (!xla::LayoutUtil::Equal(shard_shape.layout(), shape->layout())) {
+ return xla::Unimplemented(
+ "Sharded infeed with non-uniform layouts is not supported. Try "
+ "turning off the infeed layout optimization "
+ "(--transpose_tpu_infeed=false) and report to XLA team.");
+ }
+ }
+ return Status::OK();
+ }
+ *shape = GetTPUInfeedLayout(*shape);
+ return Status::OK();
+}
+
+// TODO(pbar) Work out if we need to Infeed Tuples - if so then
+// this op will need a way to provide a list of shapes
+// since they can't be provided by the runtime JIT mechanism.
+// (InfeedDequeue has no inputs!)
+// Compare this op to tf.Queue operations which operate on N tensors.
+
+// This TensorFlow op supports the XLA Infeed primitive.
+class InfeedDequeueOp : public XlaOpKernel {
+ public:
+ explicit InfeedDequeueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ OP_REQUIRES_OK(ctx, UpdateInfeedLayout(&xla_shape_, b->sharding()));
+ ctx->SetOutput(0, xla::Infeed(b, xla_shape_));
+ }
+
+ private:
+ TensorShape shape_;
+ DataType dtype_;
+ xla::Shape xla_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(InfeedDequeueOp);
+};
+
+REGISTER_XLA_OP(Name("InfeedDequeue"), InfeedDequeueOp);
+
+// This TensorFlow op supports the XLA Infeed primitive for tuple types.
+class InfeedDequeueTupleOp : public XlaOpKernel {
+ public:
+ explicit InfeedDequeueTupleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
+ for (int i = 0; i < shapes_.size(); i++) {
+ xla::Shape xla_shape;
+ OP_REQUIRES_OK(ctx,
+ TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
+ xla_shapes_.push_back(xla_shape);
+ }
+ }
+
+ ~InfeedDequeueTupleOp() override {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ for (int64 i = 0; i < xla_shapes_.size(); ++i) {
+ absl::optional<xla::OpSharding> sharding;
+ if (b->sharding()) {
+ sharding = b->sharding()->type() == xla::OpSharding::TUPLE
+ ? b->sharding()->tuple_shardings(i)
+ : b->sharding();
+ }
+ OP_REQUIRES_OK(ctx, UpdateInfeedLayout(&xla_shapes_[i], sharding));
+ }
+ tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_);
+ auto tuple = xla::Infeed(b, tuple_shape_);
+
+ // Don't apply the infeed tuple sharding to the get-tuple-elements. They
+ // need non-tuple shardings.
+ xla::XlaScopedShardingAssignment clear_sharding(b, absl::nullopt);
+ for (int i = 0; i < shapes_.size(); ++i) {
+ ctx->SetOutput(i, xla::GetTupleElement(tuple, i));
+ }
+ }
+
+ private:
+ std::vector<TensorShape> shapes_;
+ DataTypeVector dtypes_;
+ std::vector<xla::Shape> xla_shapes_;
+ xla::Shape tuple_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(InfeedDequeueTupleOp);
+};
+
+REGISTER_XLA_OP(Name("InfeedDequeueTuple"), InfeedDequeueTupleOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/inplace_ops.cc b/tensorflow/core/tpu/kernels/xla/inplace_ops.cc
new file mode 100644
index 0000000..9baffd6
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/inplace_ops.cc
@@ -0,0 +1,142 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include "tensorflow/compiler/tf2xla/lib/scatter.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+namespace {
+
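+// Compiles InplaceUpdate(x, i, v): the rows of x selected by the indices in i
+// are overwritten with the corresponding rows of v, lowered as one
+// DynamicUpdateSlice per index (see the scatter TODO below).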
+class InplaceUpdateOp : public XlaOpKernel {
+ public:
+ explicit InplaceUpdateOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ VLOG(3) << "InplaceUpdateOp::Compile";
+
+ DataType index_type = input_type(1);
+ OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
+ errors::InvalidArgument("index must be int32 or int64"));
+
+ // TF Args are X, I, V
+ const TensorShape x_shape = ctx->InputShape(0);
+ const TensorShape i_shape = ctx->InputShape(1);
+ const TensorShape v_shape = ctx->InputShape(2);
+
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsScalar(i_shape) ||
+ TensorShapeUtils::IsVector(i_shape),
+ errors::InvalidArgument("index must be Rank 0 or 1"));
+ OP_REQUIRES(ctx, (x_shape.dims() == v_shape.dims()),
+ errors::InvalidArgument("X and V must have the same Rank,"
+ " X.shape=",
+ x_shape.DebugString(),
+ " V.shape=", v_shape.DebugString()));
+
+ auto* builder = ctx->builder();
+ // Use a zero of the index type so all DynamicUpdateSlice start indices
+ // share one element type.
+ auto const_zero = XlaHelpers::Zero(builder, index_type);
+ auto current = ctx->Input(0);
+
+ for (int64 i = 0; i < i_shape.num_elements(); i++) {
+ std::vector<xla::XlaOp> update_indices;
+ update_indices.push_back(
+ xla::Reshape(xla::SliceInDim(ctx->Input(1), i, i + 1, 1, 0), {}));
+ for (int xi = 1; xi < x_shape.dims(); xi++) {
+ update_indices.push_back(const_zero);
+ }
+ current = xla::DynamicUpdateSlice(
+ current, xla::SliceInDim(ctx->Input(2), i, i + 1, 1, 0),
+ update_indices);
+ }
+ ctx->SetOutput(0, current);
+
+ // TODO(b/118122460): Uncomment+format this code to use XLA Scatter.
+ // auto* builder = ctx->builder();
+ // const auto initial = ctx->Input(0);
+ // const auto indices = ctx->Input(1);
+ // const auto updates = ctx->Input(2);
+ //
+ // auto result = XlaScatter(
+ // initial, updates, indices, /*indices_are_vectors=*/false,
+ // [](xla::XlaOp, xla::XlaOp second, xla::XlaBuilder*) { return
+ // second; }, builder);
+ // OP_REQUIRES_OK(ctx, result.status());
+ // ctx->SetOutput(0, result.ValueOrDie());
+ }
+};
+
+REGISTER_XLA_OP(Name("InplaceUpdate"), InplaceUpdateOp);
+
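+// Compiles InplaceAdd(x, i, v): adds v to the row of x selected by the single
+// index i, lowered as DynamicSlice + Add + DynamicUpdateSlice.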
+class InplaceAddOp : public XlaOpKernel {
+ public:
+ explicit InplaceAddOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ VLOG(3) << "InplaceAddOp::Compile";
+
+ DataType index_type = input_type(1);
+ OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
+ errors::InvalidArgument("index must be int32 or int64"));
+
+ // TF Args are X, I, V
+ const TensorShape x_shape = ctx->InputShape(0);
+ const TensorShape i_shape = ctx->InputShape(1);
+ const TensorShape v_shape = ctx->InputShape(2);
+ OP_REQUIRES(ctx,
+ (TensorShapeUtils::IsScalar(i_shape) ||
+ ((i_shape.dims() == 1) && (i_shape.num_elements() == 1))),
+ errors::InvalidArgument("index must be Rank 1 and size 1"));
+ OP_REQUIRES(ctx, (x_shape.dims() == v_shape.dims()),
+ errors::InvalidArgument("X and V must have the same Rank,"
+ " X.shape=",
+ x_shape.DebugString(),
+ " V.shape=", v_shape.DebugString()));
+ // Pad the indices out to match the rank of params.
+ auto* builder = ctx->builder();
+ std::vector<xla::XlaOp> padded_indices;
+ padded_indices.push_back(xla::Reshape(ctx->Input(1), {}));
+ for (int i = 0; i < x_shape.dims() - 1; ++i) {
+ padded_indices.push_back(XlaHelpers::Zero(builder, index_type));
+ }
+
+ std::vector<int64> sizes;
+ sizes.push_back(1);
+ for (int i = 1; i < x_shape.dims(); i++) {
+ sizes.push_back(x_shape.dim_size(i));
+ }
+
+ auto prev = xla::DynamicSlice(ctx->Input(0), padded_indices, sizes);
+ auto updated = xla::Add(prev, ctx->Input(2));
+ auto result =
+ xla::DynamicUpdateSlice(ctx->Input(0), updated, padded_indices);
+ ctx->SetOutput(0, result);
+ }
+};
+
+REGISTER_XLA_OP(Name("InplaceAdd"), InplaceAddOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc b/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc
new file mode 100644
index 0000000..8abdd3d
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc
@@ -0,0 +1,91 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+
+namespace {
+
+// This TensorFlow op implements the XLA Outfeed primitive.
+class OutfeedEnqueueOp : public XlaOpKernel {
+ public:
+ explicit OutfeedEnqueueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::Shape xla_shape;
+ OP_REQUIRES_OK(
+ ctx, TensorShapeToXLAShape(dtype_, ctx->InputShape(0), &xla_shape));
+ // Outfeed configuration is only needed for embedding outfeed.
+ const string outfeed_config;
+ xla::Outfeed(ctx->Input(0), xla_shape, outfeed_config);
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OutfeedEnqueueOp);
+};
+
+REGISTER_XLA_OP(Name("OutfeedEnqueue"), OutfeedEnqueueOp);
+
+// This TensorFlow op implements the XLA Outfeed primitive for tuple types.
+class OutfeedEnqueueTupleOp : public XlaOpKernel {
+ public:
+ explicit OutfeedEnqueueTupleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector<xla::XlaOp> handles;
+ std::vector<TensorShape> shapes;
+ OP_REQUIRES_OK(ctx, ctx->InputList("inputs", &handles, &shapes));
+
+ std::vector<xla::Shape> xla_shapes;
+ for (int i = 0; i < shapes.size(); ++i) {
+ xla::Shape xla_shape;
+ OP_REQUIRES_OK(ctx,
+ TensorShapeToXLAShape(dtypes_[i], shapes[i], &xla_shape));
+ xla_shapes.push_back(xla_shape);
+ }
+ xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(xla_shapes);
+ VLOG(1) << "OutfeedEnqueueTuple: "
+ << xla::ShapeUtil::HumanStringWithLayout(tuple_shape);
+ auto b = ctx->builder();
+ auto tuple = xla::Tuple(b, handles);
+ // Outfeed configuration is only needed for embedding outfeed.
+ const string outfeed_config;
+ xla::Outfeed(tuple, tuple_shape, outfeed_config);
+ }
+
+ private:
+ DataTypeVector dtypes_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OutfeedEnqueueTupleOp);
+};
+
+REGISTER_XLA_OP(Name("OutfeedEnqueueTuple"), OutfeedEnqueueTupleOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
new file mode 100644
index 0000000..f7c33e5
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
@@ -0,0 +1,145 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/scatter.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+namespace {
+// TODO(b/32945756): Add a scatter op in XLA and move this to a HLO optimization
+// pass. Optimization for UnsortedSegmentSum on TPU: use k-hot matmul. This
+// optimization requires:
+// 1. data has dtype supported by TPU matmul and has rank of 1 or 2.
+// 2. indices has rank of 1.
+// 3. matmul op count is less than 800 billion.
+//
+// Example of calculating UnsortedSegmentSum by k-hot matmul:
+// data shape [A, B]
+// indices shape [A]
+// num_segment N
+// output shape [N, B]
+// matmul op count N * A * B
+// Step 1: create k-hot matrix
+// k-hot matrix has shape of [A, N], where row i is responsible for
+// collecting the sum of the i-th segment, concretely
+// k-hot[i][j] = 1 if indices[i] = j
+// Step 2: perform matmul
+// the final result is obtained by multiplying k-hot matrix with data
+// matrix, namely
+// k-hot * data => result
+// shape: [N, A] * [A, B] => [N, B]
+xla::XlaOp KHotMatmul(XlaOpKernelContext* ctx, xla::XlaBuilder* builder,
+ const xla::XlaOp data, const xla::XlaOp indices,
+ int64 num_segments) {
+ DataType data_dtype = ctx->input_type(0);
+ xla::PrimitiveType indices_type = ctx->input_xla_type(1);
+ TensorShape data_shape = ctx->InputShape(0);
+ TensorShape indices_shape = ctx->InputShape(1);
+ xla::XlaOp linspace = xla::Iota(builder, indices_type, num_segments);
+ xla::XlaOp linspace_col = xla::Reshape(linspace, {num_segments, 1});
+ TensorShape indices_row_shape = indices_shape;
+ indices_row_shape.InsertDim(0, 1);
+ xla::XlaOp indices_row = xla::Reshape(indices, indices_row_shape.dim_sizes());
+ xla::XlaOp k_hot = xla::Eq(indices_row, linspace_col);
+ xla::XlaOp k_hot_with_data_dtype =
+ XlaHelpers::ConvertElementType(k_hot, data_dtype);
+ // When the data is F32, request HIGHEST precision: the F32 values are split
+ // into three BF16 pieces and the partial matmul results are summed.
+ // Note that this still doesn't fully retain f32 precision; in particular,
+ // values smaller than 2^-111 may see loss of precision.
+ xla::PrecisionConfig precision_config;
+ if (data_dtype == DT_FLOAT) {
+ precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST);
+ } else {
+ CHECK_EQ(data_dtype, DT_BFLOAT16);
+ precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT);
+ }
+ precision_config.add_operand_precision(xla::PrecisionConfig::DEFAULT);
+ return xla::Dot(k_hot_with_data_dtype, data, &precision_config);
+}
+
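+// Compiles UnsortedSegmentSum by scattering `data` into a zero-initialized
+// buffer of shape [num_segments, ...] using XlaScatter with an add combiner.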
+class UnsortedSegmentSum : public XlaOpKernel {
+ public:
+ explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ // output = unsorted_segment_sum(data, indices, num_segments)
+ // Compute a tensor such that:
+ // output[i] = sum over {j where indices[j] == i} of data[j]
+ // output[i] == 0 if i does not appear in indices
+ //
+ // Contrast with segment_sum(), which assumes indices are sorted and that
+ // max(indices)+1 is the desired size of the output.
+ //
+ // The returned output tensor has the same type as data, and the same shape
+ // as data with the first indices.rank dimensions replaced by a single
+ // dimension of size num_segments.
+ xla::XlaOp data = ctx->Input(0);
+ TensorShape data_shape = ctx->InputShape(0);
+
+ xla::XlaOp indices = ctx->Input(1);
+ TensorShape indices_shape = ctx->InputShape(1);
+
+ int64 num_segments;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
+
+ OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
+ errors::InvalidArgument(
+ "UnsortedSegmentSum requires that indices' rank be"
+ " less than or equal to data's rank."));
+ // Validate that indices.shape is a prefix of data.shape.
+ for (int d = 0; d < indices_shape.dims(); ++d) {
+ OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
+ errors::InvalidArgument(
+ "UnsortedSegmentSum requires indices shape to be prefix"
+ " of data_shape, but dimension ",
+ d, " differs ", data_shape.dim_size(d), " vs. ",
+ indices_shape.dim_size(d)));
+ }
+ xla::XlaBuilder* builder = ctx->builder();
+ TensorShape buffer_shape = data_shape;
+ buffer_shape.RemoveDimRange(0, indices_shape.dims());
+ buffer_shape.InsertDim(0, num_segments);
+ auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype_),
+ buffer_shape.dim_sizes());
+
+ auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
+ return a + b;
+ };
+
+ auto result = XlaScatter(buffer, /*updates=*/data, indices,
+ /*indices_are_vectors=*/false, combiner, builder);
+ OP_REQUIRES_OK(ctx, result.status());
+ ctx->SetOutput(0, result.ValueOrDie());
+ }
+
+ private:
+ DataType dtype_;
+};
+
+REGISTER_XLA_OP(Name("UnsortedSegmentSum")
+ .Device(DEVICE_TPU_XLA_JIT)
+ .CompileTimeConstantInput("num_segments"),
+ UnsortedSegmentSum);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/xla/where_op.cc b/tensorflow/core/tpu/kernels/xla/where_op.cc
new file mode 100644
index 0000000..420d5bc
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/xla/where_op.cc
@@ -0,0 +1,91 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/comparators.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/ops_util.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+namespace {
+
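+// Compiles tf.Where for TPU: returns the coordinates of all nonzero elements
+// of the input. The condition is flattened and converted to 0/1 values, a
+// coordinate iota is built for each dimension, and everything is stably
+// sorted so the nonzero entries come first. The number of nonzero elements is
+// then attached as the dynamic size of the output's first dimension.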
+class WhereOp : public XlaOpKernel {
+ public:
+ explicit WhereOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaOp condition = ctx->Input(0);
+ xla::StatusOr<xla::Shape> input_shape = ctx->builder()->GetShape(condition);
+ OP_REQUIRES_OK(ctx, input_shape.status());
+ // Use S32 as indices first, then convert to S64 in the end if needed.
+ auto iota_shape = input_shape.ValueOrDie();
+ iota_shape.set_element_type(xla::S32);
+
+ int64 flattened_size = xla::Product(iota_shape.dimensions());
+ xla::XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size});
+ xla::XlaOp zeros = xla::ZerosLike(reshaped_condition);
+ xla::XlaOp zeros_int = xla::ConvertElementType(zeros, xla::S32);
+ xla::XlaOp reshaped_condition_int =
+ xla::ConvertElementType(reshaped_condition, xla::S32);
+ xla::XlaOp compared = xla::ConvertElementType(
+ xla::Gt(reshaped_condition_int, zeros_int), xla::S32);
+ xla::XlaOp length = xla::ReduceAll(
+ compared, xla::Zero(ctx->builder(), xla::S32),
+ xla::CreateScalarAddComputation(xla::S32, ctx->builder()));
+
+ std::vector<xla::XlaOp> to_sort = {reshaped_condition_int};
+ std::vector<xla::PrimitiveType> types_to_sort = {xla::S32};
+ // Generate an iota for each dimension; combined, they form the
+ // coordinates of each element.
+ for (int64 axis = 0; axis < iota_shape.rank(); ++axis) {
+ xla::XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis);
+ xla::XlaOp reshaped = xla::Reshape(iota, {flattened_size});
+ to_sort.push_back(reshaped);
+ types_to_sort.push_back(xla::S32);
+ }
+
+ xla::XlaOp sorted = xla::Sort(
+ to_sort, xla::CreateScalarGtComputation(types_to_sort, ctx->builder()),
+ /*dimension=*/0,
+ /*is_stable=*/true);
+ std::vector<xla::XlaOp> to_concat;
+ for (int64 i = 0; i < iota_shape.rank(); ++i) {
+ xla::XlaOp index_single_dim = xla::GetTupleElement(sorted, i + 1);
+ to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1}));
+ }
+
+ xla::XlaOp result = xla::ConcatInDim(ctx->builder(), to_concat, 1);
+ result = xla::ConvertElementType(result, ctx->output_xla_type(0));
+ // Dynamic padder will handle the dynamic dimension.
+ xla::XlaOp result_padded = xla::SetDimensionSize(result, length, 0);
+ ctx->SetOutput(0, result_padded);
+ }
+};
+
+REGISTER_XLA_OP(Name("Where").Device(DEVICE_TPU_XLA_JIT), WhereOp);
+
+} // namespace
+} // namespace tensorflow