Add OutfeedDequeue ops that allow for a dynamic device ordinal to be used.

PiperOrigin-RevId: 333388391
Change-Id: I2bbed92c3e98f7180751f07a31cbf507ef6616cb
diff --git a/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueTupleV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueTupleV2.pbtxt
new file mode 100644
index 0000000..c8e044a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueTupleV2.pbtxt
@@ -0,0 +1,38 @@
+op {
+  graph_op_name: "OutfeedDequeueTupleV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "device_ordinal"
+    description: <<END
+An int scalar tensor, representing the TPU device to use. This should be -1 when
+the Op is running on a TPU device, and >= 0 when the Op is running on the CPU
+device.
+END
+  }
+  out_arg {
+    name: "outputs"
+    description: <<END
+A list of tensors that will be read from the outfeed.
+END
+  }
+  attr {
+    name: "dtypes"
+    description: <<END
+The element types of each element in `outputs`.
+END
+  }
+  attr {
+    name: "shapes"
+    description: <<END
+The shapes of each tensor in `outputs`.
+END
+  }
+  summary: <<END
+Retrieve multiple values from the computation outfeed. Device ordinal is a
+tensor allowing dynamic outfeed.
+END
+  description: <<END
+This operation will block indefinitely until data is available. Output `i`
+corresponds to XLA tuple element `i`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueV2.pbtxt
new file mode 100644
index 0000000..fc7a0b5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueV2.pbtxt
@@ -0,0 +1,37 @@
+op {
+  graph_op_name: "OutfeedDequeueV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "device_ordinal"
+    description: <<END
+An int scalar tensor, representing the TPU device to use. This should be -1 when
+the Op is running on a TPU device, and >= 0 when the Op is running on the CPU
+device.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+A tensor that will be read from the device outfeed.
+END
+  }
+  attr {
+    name: "dtype"
+    description: <<END
+The type of elements in the tensor.
+END
+  }
+  attr {
+    name: "shape"
+    description: <<END
+The shape of the tensor.
+END
+  }
+  summary: <<END
+Retrieves a single tensor from the computation outfeed. Device ordinal is a
+tensor allowing dynamic outfeed.
+END
+  description: <<END
+This operation will block indefinitely until data is available.
+END
+}
diff --git a/tensorflow/core/ops/tpu_outfeed_ops.cc b/tensorflow/core/ops/tpu_outfeed_ops.cc
index e170ed0..dce19bc 100644
--- a/tensorflow/core/ops/tpu_outfeed_ops.cc
+++ b/tensorflow/core/ops/tpu_outfeed_ops.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 
@@ -65,4 +66,38 @@
       return Status::OK();
     });
 
+REGISTER_OP("OutfeedDequeueV2")
+    .Input("device_ordinal: int32")
+    .Output("output: dtype")
+    .Attr("dtype: type")
+    .Attr("shape: shape")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ExplicitShape);
+
+REGISTER_OP("OutfeedDequeueTupleV2")
+    .Input("device_ordinal: int32")
+    .Output("outputs: dtypes")
+    .Attr("dtypes: list(type)")
+    .Attr("shapes: list(shape)")
+    .SetIsStateful()
+    .SetShapeFn([](InferenceContext* c) {
+      if (c->Rank(c->input(0)) != 0) {
+        return errors::InvalidArgument("device ordinal must be a scalar.");
+      }
+      std::vector<PartialTensorShape> shapes;
+      std::vector<DataType> dtypes;
+      TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
+      TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
+      if (shapes.size() != dtypes.size()) {
+        return errors::InvalidArgument(
+            "Incorrect number of output shapes specified");
+      }
+      for (int i = 0; i < shapes.size(); ++i) {
+        ShapeHandle out;
+        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &out));
+        c->set_output(i, out);
+      }
+      return Status::OK();
+    });
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/outfeed_ops.cc b/tensorflow/core/tpu/kernels/outfeed_ops.cc
index 51a3a71..bc9a9d1 100644
--- a/tensorflow/core/tpu/kernels/outfeed_ops.cc
+++ b/tensorflow/core/tpu/kernels/outfeed_ops.cc
@@ -30,14 +30,16 @@
 
 namespace tensorflow {
 
-TpuOutfeedDequeueOp::TpuOutfeedDequeueOp(OpKernelConstruction* ctx)
-    : TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
+template <class T>
+TpuOutfeedDequeueOp<T>::TpuOutfeedDequeueOp(OpKernelConstruction* ctx)
+    : T(ctx, "outfeed_dequeue", 1) {
   OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
   OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
   OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_));
 }
 
-Status TpuOutfeedDequeueOp::DoWork(
+template <class T>
+Status TpuOutfeedDequeueOp<T>::DoWork(
     OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
     stream_executor::StreamExecutor* stream_executor) {
   Tensor* output;
@@ -61,8 +63,9 @@
 
 // The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
 // device outfeed queue.
-TpuOutfeedDequeueTupleOp::TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx)
-    : TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
+template <class T>
+TpuOutfeedDequeueTupleOp<T>::TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx)
+    : T(ctx, "outfeed_dequeue", 1) {
   OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
   OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
   OP_REQUIRES(
@@ -79,7 +82,8 @@
   tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_);
 }
 
-Status TpuOutfeedDequeueTupleOp::DoWork(
+template <class T>
+Status TpuOutfeedDequeueTupleOp<T>::DoWork(
     OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
     stream_executor::StreamExecutor* stream_executor) {
   VLOG(1) << "TransferLiteralFromOutfeed "
@@ -103,14 +107,29 @@
 // device_ordinal to indicate which TPU to receive outfeed from.
 REGISTER_KERNEL_BUILDER(
     Name("OutfeedDequeue").Device(DEVICE_TPU_NODE).HostMemory("output"),
-    TpuOutfeedDequeueOp);
+    TpuOutfeedDequeueOp<TpuTransferAsyncOpKernel>);
 REGISTER_KERNEL_BUILDER(Name("OutfeedDequeue").Device(DEVICE_CPU),
-                        TpuOutfeedDequeueOp);
+                        TpuOutfeedDequeueOp<TpuTransferAsyncOpKernel>);
 
 REGISTER_KERNEL_BUILDER(
     Name("OutfeedDequeueTuple").Device(DEVICE_TPU_NODE).HostMemory("outputs"),
-    TpuOutfeedDequeueTupleOp);
+    TpuOutfeedDequeueTupleOp<TpuTransferAsyncOpKernel>);
 REGISTER_KERNEL_BUILDER(Name("OutfeedDequeueTuple").Device(DEVICE_CPU),
-                        TpuOutfeedDequeueTupleOp);
+                        TpuOutfeedDequeueTupleOp<TpuTransferAsyncOpKernel>);
+
+// The ops below take device_ordinal as an input tensor rather than an
+REGISTER_KERNEL_BUILDER(
+    Name("OutfeedDequeueV2").Device(DEVICE_TPU_NODE).HostMemory("output"),
+    TpuOutfeedDequeueOp<TpuTransferAsyncDynamicOrdinalOpKernel>);
+REGISTER_KERNEL_BUILDER(
+    Name("OutfeedDequeueV2").Device(DEVICE_CPU),
+    TpuOutfeedDequeueOp<TpuTransferAsyncDynamicOrdinalOpKernel>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("OutfeedDequeueTupleV2").Device(DEVICE_TPU_NODE).HostMemory("outputs"),
+    TpuOutfeedDequeueTupleOp<TpuTransferAsyncDynamicOrdinalOpKernel>);
+REGISTER_KERNEL_BUILDER(
+    Name("OutfeedDequeueTupleV2").Device(DEVICE_CPU),
+    TpuOutfeedDequeueTupleOp<TpuTransferAsyncDynamicOrdinalOpKernel>);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/outfeed_ops.h b/tensorflow/core/tpu/kernels/outfeed_ops.h
index 5e3ed87..3474ff2 100644
--- a/tensorflow/core/tpu/kernels/outfeed_ops.h
+++ b/tensorflow/core/tpu/kernels/outfeed_ops.h
@@ -25,7 +25,8 @@
 
 // The OutfeedDequeue op is used to retrieve a single tensor from the device
 // outfeed queue.
-class TpuOutfeedDequeueOp : public TpuTransferAsyncOpKernel {
+template <class T>
+class TpuOutfeedDequeueOp : public T {
  public:
   explicit TpuOutfeedDequeueOp(OpKernelConstruction* ctx);
 
@@ -45,7 +46,8 @@
 
 // The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
 // device outfeed queue.
-class TpuOutfeedDequeueTupleOp : public TpuTransferAsyncOpKernel {
+template <class T>
+class TpuOutfeedDequeueTupleOp : public T {
  public:
   explicit TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx);
 
diff --git a/tensorflow/core/tpu/kernels/transfer_ops.cc b/tensorflow/core/tpu/kernels/transfer_ops.cc
index a5cdfd4..16d3ecf 100644
--- a/tensorflow/core/tpu/kernels/transfer_ops.cc
+++ b/tensorflow/core/tpu/kernels/transfer_ops.cc
@@ -27,27 +27,19 @@
 
 namespace tensorflow {
 
-TpuTransferAsyncOpKernel::TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
-                                                   const string& transfer_type,
-                                                   int number_of_threads)
+TpuTransferAsyncOpKernelBase::TpuTransferAsyncOpKernelBase(
+    OpKernelConstruction* ctx, const string& transfer_type,
+    int number_of_threads)
     : AsyncOpKernel(ctx),
+      transfer_type_(transfer_type),
       thread_pool_(new thread::ThreadPool(
           ctx->env(),
           strings::StrCat(transfer_type, "_thread_",
                           SanitizeThreadSuffix(def().name())),
-          /*num_threads=*/8)) {
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("device_ordinal", &device_ordinal_));
-  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
-    OP_REQUIRES(
-        ctx, device_ordinal_ >= 0,
-        errors::InvalidArgument(transfer_type,
-                                " ops must specify a device_ordinal when "
-                                "placed on CPU."));
-  }
-}
+          /*num_threads=*/8)) {}
 
-void TpuTransferAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
-                                            DoneCallback done) {
+void TpuTransferAsyncOpKernelBase::ComputeAsync(OpKernelContext* ctx,
+                                                DoneCallback done) {
   CancellationToken token =
       ctx->cancellation_manager()->get_cancellation_token();
   bool already_cancelled;
@@ -68,11 +60,12 @@
   });
 }
 
-Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
+Status TpuTransferAsyncOpKernelBase::RunTransferWithOrdinal(
+    OpKernelContext* ctx, int device_ordinal) {
   auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
       /*initialize_platform=*/false);
 
-  int real_device_ordinal = device_ordinal_;
+  int real_device_ordinal = device_ordinal;
   if (real_device_ordinal < 0) {
     const XlaDevice::Metadata* metadata;
     TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
@@ -91,9 +84,47 @@
       stream_executor);
 }
 
-void TpuTransferAsyncOpKernel::Cancel() {
+void TpuTransferAsyncOpKernelBase::Cancel() {
   mutex_lock lock(mu_);
   TF_CHECK_OK(tpu::TpuNodeContext::CloseTpuHost());
 }
 
+TpuTransferAsyncOpKernel::TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
+                                                   const string& transfer_type,
+                                                   int number_of_threads)
+    : TpuTransferAsyncOpKernelBase(ctx, transfer_type, number_of_threads) {
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("device_ordinal", &device_ordinal_));
+  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+    OP_REQUIRES(
+        ctx, device_ordinal_ >= 0,
+        errors::InvalidArgument(transfer_type,
+                                " ops must specify a device_ordinal when "
+                                "placed on CPU."));
+  }
+}
+
+Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
+  return RunTransferWithOrdinal(ctx, device_ordinal_);
+}
+
+TpuTransferAsyncDynamicOrdinalOpKernel::TpuTransferAsyncDynamicOrdinalOpKernel(
+    OpKernelConstruction* ctx, const string& transfer_type,
+    int number_of_threads)
+    : TpuTransferAsyncOpKernelBase(ctx, transfer_type, number_of_threads) {}
+
+Status TpuTransferAsyncDynamicOrdinalOpKernel::RunTransfer(
+    OpKernelContext* ctx) {
+  const Tensor& device_ordinal_tensor = ctx->input(0);
+  const int device_ordinal = device_ordinal_tensor.scalar<int32>()();
+  XlaDevice* xla_device =
+      dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
+  if (((xla_device == nullptr) || (xla_device->device_type() == DEVICE_CPU)) &&
+      (device_ordinal < 0)) {
+    return errors::InvalidArgument(transfer_type_,
+                                   " ops must specify a device_ordinal when "
+                                   "placed on CPU.");
+  }
+  return RunTransferWithOrdinal(ctx, device_ordinal);
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/transfer_ops.h b/tensorflow/core/tpu/kernels/transfer_ops.h
index d98d743..6a3662a 100644
--- a/tensorflow/core/tpu/kernels/transfer_ops.h
+++ b/tensorflow/core/tpu/kernels/transfer_ops.h
@@ -25,11 +25,11 @@
 
 // Base class providing common functionality for async ops that transfer from
 // host to TPU.
-class TpuTransferAsyncOpKernel : public AsyncOpKernel {
+class TpuTransferAsyncOpKernelBase : public AsyncOpKernel {
  public:
-  explicit TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
-                                    const string& transfer_type,
-                                    int number_of_threads);
+  explicit TpuTransferAsyncOpKernelBase(OpKernelConstruction* ctx,
+                                        const string& transfer_type,
+                                        int number_of_threads);
 
   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
 
@@ -38,19 +38,54 @@
                         xla::TpuTransferManagerInterface* transfer_manager,
                         stream_executor::StreamExecutor* stream_executor) = 0;
 
+  Status RunTransferWithOrdinal(OpKernelContext* ctx, int device_ordinal);
+  std::string transfer_type_;
+
  private:
-  Status RunTransfer(OpKernelContext* ctx);
+  virtual Status RunTransfer(OpKernelContext* ctx) = 0;
   void Cancel();
 
   std::unique_ptr<thread::ThreadPool> thread_pool_;
-  int device_ordinal_;
   mutex mu_;
 
+  // TpuTransferAsyncOpKernelBase is neither copyable nor movable.
+  TpuTransferAsyncOpKernelBase(const TpuTransferAsyncOpKernelBase&) = delete;
+  TpuTransferAsyncOpKernelBase& operator=(const TpuTransferAsyncOpKernelBase&) =
+      delete;
+};
+
+class TpuTransferAsyncOpKernel : public TpuTransferAsyncOpKernelBase {
+ public:
+  explicit TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
+                                    const string& transfer_type,
+                                    int number_of_threads);
+
+ private:
+  Status RunTransfer(OpKernelContext* ctx) override;
+  int device_ordinal_;
+
   // TpuTransferAsyncOpKernel is neither copyable nor movable.
   TpuTransferAsyncOpKernel(const TpuTransferAsyncOpKernel&) = delete;
   TpuTransferAsyncOpKernel& operator=(const TpuTransferAsyncOpKernel&) = delete;
 };
 
+class TpuTransferAsyncDynamicOrdinalOpKernel
+    : public TpuTransferAsyncOpKernelBase {
+ public:
+  explicit TpuTransferAsyncDynamicOrdinalOpKernel(OpKernelConstruction* ctx,
+                                                  const string& transfer_type,
+                                                  int number_of_threads);
+
+ private:
+  Status RunTransfer(OpKernelContext* ctx) override;
+
+  // TpuTransferAsyncDynamicOrdinalOpKernel is neither copyable nor movable.
+  TpuTransferAsyncDynamicOrdinalOpKernel(
+      const TpuTransferAsyncDynamicOrdinalOpKernel&) = delete;
+  TpuTransferAsyncDynamicOrdinalOpKernel& operator=(
+      const TpuTransferAsyncDynamicOrdinalOpKernel&) = delete;
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index acfb75c..2a2c310 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -2733,6 +2733,14 @@
     argspec: "args=[\'dtypes\', \'shapes\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
   member_method {
+    name: "OutfeedDequeueTupleV2"
+    argspec: "args=[\'device_ordinal\', \'dtypes\', \'shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "OutfeedDequeueV2"
+    argspec: "args=[\'device_ordinal\', \'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "OutfeedEnqueue"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index acfb75c..2a2c310 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -2733,6 +2733,14 @@
     argspec: "args=[\'dtypes\', \'shapes\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
   member_method {
+    name: "OutfeedDequeueTupleV2"
+    argspec: "args=[\'device_ordinal\', \'dtypes\', \'shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "OutfeedDequeueV2"
+    argspec: "args=[\'device_ordinal\', \'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "OutfeedEnqueue"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }