Add OutfeedDequeueV2 and OutfeedDequeueTupleV2 ops that take the device ordinal as a runtime input tensor rather than an attribute.
PiperOrigin-RevId: 333388391
Change-Id: I2bbed92c3e98f7180751f07a31cbf507ef6616cb
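
Because the ordinal is now an ordinary int32 tensor instead of a compile-time
attr, a single dequeue op can target different TPU cores at run time. A minimal
usage sketch via tf.raw_ops (hypothetical dtype/shape; assumes a TPU worker
with an active outfeed):

  import tensorflow as tf

  # A constant stands in here for any int32 scalar computed at runtime.
  # Per the op definition below, it must be >= 0 when the op is placed on
  # CPU, and -1 when the op runs on a TPU device.
  ordinal = tf.constant(0, dtype=tf.int32)
  value = tf.raw_ops.OutfeedDequeueV2(
      device_ordinal=ordinal, dtype=tf.float32, shape=[128])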
diff --git a/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueTupleV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueTupleV2.pbtxt
new file mode 100644
index 0000000..c8e044a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueTupleV2.pbtxt
@@ -0,0 +1,38 @@
+op {
+ graph_op_name: "OutfeedDequeueTupleV2"
+ visibility: HIDDEN
+ in_arg {
+ name: "device_ordinal"
+ description: <<END
+An int scalar tensor representing the TPU device to use. This should be -1 when
+the Op is running on a TPU device, and >= 0 when the Op is running on the CPU
+device.
+END
+ }
+ out_arg {
+ name: "outputs"
+ description: <<END
+A list of tensors that will be read from the outfeed.
+END
+ }
+ attr {
+ name: "dtypes"
+ description: <<END
+The element types of each tensor in `outputs`.
+END
+ }
+ attr {
+ name: "shapes"
+ description: <<END
+The shapes of each tensor in `outputs`.
+END
+ }
+ summary: <<END
+Retrieve multiple values from the computation outfeed. The device ordinal is
+an input tensor, allowing dynamic outfeed.
+END
+ description: <<END
+This operation will block indefinitely until data is available. Output `i`
+corresponds to XLA tuple element `i`.
+END
+}
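For the tuple variant, `dtypes` and `shapes` must be parallel lists, and output
`i` corresponds to XLA tuple element `i`. A hedged sketch under the same
assumptions as above (hypothetical dtypes and shapes):

  outputs = tf.raw_ops.OutfeedDequeueTupleV2(
      device_ordinal=ordinal,           # int32 scalar tensor, as above
      dtypes=[tf.float32, tf.int32],    # one dtype per tuple element
      shapes=[[128], []])               # one shape per tuple element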
diff --git a/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueV2.pbtxt
new file mode 100644
index 0000000..fc7a0b5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_OutfeedDequeueV2.pbtxt
@@ -0,0 +1,37 @@
+op {
+ graph_op_name: "OutfeedDequeueV2"
+ visibility: HIDDEN
+ in_arg {
+ name: "device_ordinal"
+ description: <<END
+An int scalar tensor representing the TPU device to use. This should be -1 when
+the Op is running on a TPU device, and >= 0 when the Op is running on the CPU
+device.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A tensor that will be read from the device outfeed.
+END
+ }
+ attr {
+ name: "dtype"
+ description: <<END
+The type of elements in the tensor.
+END
+ }
+ attr {
+ name: "shape"
+ description: <<END
+The shape of the tensor.
+END
+ }
+ summary: <<END
+Retrieve a single tensor from the computation outfeed. The device ordinal is
+an input tensor, allowing dynamic outfeed.
+END
+ description: <<END
+This operation will block indefinitely until data is available.
+END
+}
diff --git a/tensorflow/core/ops/tpu_outfeed_ops.cc b/tensorflow/core/ops/tpu_outfeed_ops.cc
index e170ed0..dce19bc 100644
--- a/tensorflow/core/ops/tpu_outfeed_ops.cc
+++ b/tensorflow/core/ops/tpu_outfeed_ops.cc
@@ -16,6 +16,7 @@
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
@@ -65,4 +66,38 @@
return Status::OK();
});
+REGISTER_OP("OutfeedDequeueV2")
+ .Input("device_ordinal: int32")
+ .Output("output: dtype")
+ .Attr("dtype: type")
+ .Attr("shape: shape")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ExplicitShape);
+
+REGISTER_OP("OutfeedDequeueTupleV2")
+ .Input("device_ordinal: int32")
+ .Output("outputs: dtypes")
+ .Attr("dtypes: list(type)")
+ .Attr("shapes: list(shape)")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ if (c->Rank(c->input(0)) != 0) {
+ return errors::InvalidArgument("device ordinal must be a scalar.");
+ }
+ std::vector<PartialTensorShape> shapes;
+ std::vector<DataType> dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
+ if (shapes.size() != dtypes.size()) {
+ return errors::InvalidArgument(
+ "Incorrect number of output shapes specified");
+ }
+ for (int i = 0; i < shapes.size(); ++i) {
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &out));
+ c->set_output(i, out);
+ }
+ return Status::OK();
+ });
+
} // namespace tensorflow
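The shape function above validates the graph at construction time: the
device_ordinal input must be rank 0, and the shapes and dtypes attrs must have
the same length. A sketch of the mismatch failure mode (hypothetical values;
rejected with an InvalidArgument error before any transfer is attempted):

  # Two dtypes but only one shape: fails with
  # "Incorrect number of output shapes specified".
  tf.raw_ops.OutfeedDequeueTupleV2(
      device_ordinal=tf.constant(0, dtype=tf.int32),
      dtypes=[tf.float32, tf.int32],
      shapes=[[128]])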
diff --git a/tensorflow/core/tpu/kernels/outfeed_ops.cc b/tensorflow/core/tpu/kernels/outfeed_ops.cc
index 51a3a71..bc9a9d1 100644
--- a/tensorflow/core/tpu/kernels/outfeed_ops.cc
+++ b/tensorflow/core/tpu/kernels/outfeed_ops.cc
@@ -30,14 +30,16 @@
namespace tensorflow {
-TpuOutfeedDequeueOp::TpuOutfeedDequeueOp(OpKernelConstruction* ctx)
- : TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
+template <class T>
+TpuOutfeedDequeueOp<T>::TpuOutfeedDequeueOp(OpKernelConstruction* ctx)
+ : T(ctx, "outfeed_dequeue", 1) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_));
}
-Status TpuOutfeedDequeueOp::DoWork(
+template <class T>
+Status TpuOutfeedDequeueOp<T>::DoWork(
OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) {
Tensor* output;
@@ -61,8 +63,9 @@
// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
// device outfeed queue.
-TpuOutfeedDequeueTupleOp::TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx)
- : TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
+template <class T>
+TpuOutfeedDequeueTupleOp<T>::TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx)
+ : T(ctx, "outfeed_dequeue", 1) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
OP_REQUIRES(
@@ -79,7 +82,8 @@
tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_);
}
-Status TpuOutfeedDequeueTupleOp::DoWork(
+template <class T>
+Status TpuOutfeedDequeueTupleOp<T>::DoWork(
OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) {
VLOG(1) << "TransferLiteralFromOutfeed "
@@ -103,14 +107,29 @@
// device_ordinal to indicate which TPU to receive outfeed from.
REGISTER_KERNEL_BUILDER(
Name("OutfeedDequeue").Device(DEVICE_TPU_NODE).HostMemory("output"),
- TpuOutfeedDequeueOp);
+ TpuOutfeedDequeueOp<TpuTransferAsyncOpKernel>);
REGISTER_KERNEL_BUILDER(Name("OutfeedDequeue").Device(DEVICE_CPU),
- TpuOutfeedDequeueOp);
+ TpuOutfeedDequeueOp<TpuTransferAsyncOpKernel>);
REGISTER_KERNEL_BUILDER(
Name("OutfeedDequeueTuple").Device(DEVICE_TPU_NODE).HostMemory("outputs"),
- TpuOutfeedDequeueTupleOp);
+ TpuOutfeedDequeueTupleOp<TpuTransferAsyncOpKernel>);
REGISTER_KERNEL_BUILDER(Name("OutfeedDequeueTuple").Device(DEVICE_CPU),
- TpuOutfeedDequeueTupleOp);
+ TpuOutfeedDequeueTupleOp<TpuTransferAsyncOpKernel>);
+
+// These ops take device_ordinal as an input tensor rather than an attribute.
+REGISTER_KERNEL_BUILDER(
+ Name("OutfeedDequeueV2").Device(DEVICE_TPU_NODE).HostMemory("output"),
+ TpuOutfeedDequeueOp<TpuTransferAsyncDynamicOrdinalOpKernel>);
+REGISTER_KERNEL_BUILDER(
+ Name("OutfeedDequeueV2").Device(DEVICE_CPU),
+ TpuOutfeedDequeueOp<TpuTransferAsyncDynamicOrdinalOpKernel>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("OutfeedDequeueTupleV2").Device(DEVICE_TPU_NODE).HostMemory("outputs"),
+ TpuOutfeedDequeueTupleOp<TpuTransferAsyncDynamicOrdinalOpKernel>);
+REGISTER_KERNEL_BUILDER(
+ Name("OutfeedDequeueTupleV2").Device(DEVICE_CPU),
+ TpuOutfeedDequeueTupleOp<TpuTransferAsyncDynamicOrdinalOpKernel>);
} // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/outfeed_ops.h b/tensorflow/core/tpu/kernels/outfeed_ops.h
index 5e3ed87..3474ff2 100644
--- a/tensorflow/core/tpu/kernels/outfeed_ops.h
+++ b/tensorflow/core/tpu/kernels/outfeed_ops.h
@@ -25,7 +25,8 @@
// The OutfeedDequeue op is used to retrieve a single tensor from the device
// outfeed queue.
-class TpuOutfeedDequeueOp : public TpuTransferAsyncOpKernel {
+template <class T>
+class TpuOutfeedDequeueOp : public T {
public:
explicit TpuOutfeedDequeueOp(OpKernelConstruction* ctx);
@@ -45,7 +46,8 @@
// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
// device outfeed queue.
-class TpuOutfeedDequeueTupleOp : public TpuTransferAsyncOpKernel {
+template <class T>
+class TpuOutfeedDequeueTupleOp : public T {
public:
explicit TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx);
diff --git a/tensorflow/core/tpu/kernels/transfer_ops.cc b/tensorflow/core/tpu/kernels/transfer_ops.cc
index a5cdfd4..16d3ecf 100644
--- a/tensorflow/core/tpu/kernels/transfer_ops.cc
+++ b/tensorflow/core/tpu/kernels/transfer_ops.cc
@@ -27,27 +27,19 @@
namespace tensorflow {
-TpuTransferAsyncOpKernel::TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
- const string& transfer_type,
- int number_of_threads)
+TpuTransferAsyncOpKernelBase::TpuTransferAsyncOpKernelBase(
+ OpKernelConstruction* ctx, const string& transfer_type,
+ int number_of_threads)
: AsyncOpKernel(ctx),
+ transfer_type_(transfer_type),
thread_pool_(new thread::ThreadPool(
ctx->env(),
strings::StrCat(transfer_type, "_thread_",
SanitizeThreadSuffix(def().name())),
- /*num_threads=*/8)) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("device_ordinal", &device_ordinal_));
- if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
- OP_REQUIRES(
- ctx, device_ordinal_ >= 0,
- errors::InvalidArgument(transfer_type,
- " ops must specify a device_ordinal when "
- "placed on CPU."));
- }
-}
+ /*num_threads=*/8)) {}
-void TpuTransferAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
- DoneCallback done) {
+void TpuTransferAsyncOpKernelBase::ComputeAsync(OpKernelContext* ctx,
+ DoneCallback done) {
CancellationToken token =
ctx->cancellation_manager()->get_cancellation_token();
bool already_cancelled;
@@ -68,11 +60,12 @@
});
}
-Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
+Status TpuTransferAsyncOpKernelBase::RunTransferWithOrdinal(
+ OpKernelContext* ctx, int device_ordinal) {
auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
/*initialize_platform=*/false);
- int real_device_ordinal = device_ordinal_;
+ int real_device_ordinal = device_ordinal;
if (real_device_ordinal < 0) {
const XlaDevice::Metadata* metadata;
TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
@@ -91,9 +84,47 @@
stream_executor);
}
-void TpuTransferAsyncOpKernel::Cancel() {
+void TpuTransferAsyncOpKernelBase::Cancel() {
mutex_lock lock(mu_);
TF_CHECK_OK(tpu::TpuNodeContext::CloseTpuHost());
}
+TpuTransferAsyncOpKernel::TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
+ const string& transfer_type,
+ int number_of_threads)
+ : TpuTransferAsyncOpKernelBase(ctx, transfer_type, number_of_threads) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("device_ordinal", &device_ordinal_));
+ if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+ OP_REQUIRES(
+ ctx, device_ordinal_ >= 0,
+ errors::InvalidArgument(transfer_type,
+ " ops must specify a device_ordinal when "
+ "placed on CPU."));
+ }
+}
+
+Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
+ return RunTransferWithOrdinal(ctx, device_ordinal_);
+}
+
+TpuTransferAsyncDynamicOrdinalOpKernel::TpuTransferAsyncDynamicOrdinalOpKernel(
+ OpKernelConstruction* ctx, const string& transfer_type,
+ int number_of_threads)
+ : TpuTransferAsyncOpKernelBase(ctx, transfer_type, number_of_threads) {}
+
+Status TpuTransferAsyncDynamicOrdinalOpKernel::RunTransfer(
+ OpKernelContext* ctx) {
+ const Tensor& device_ordinal_tensor = ctx->input(0);
+ const int device_ordinal = device_ordinal_tensor.scalar<int32>()();
+ XlaDevice* xla_device =
+ dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
+ if (((xla_device == nullptr) || (xla_device->device_type() == DEVICE_CPU)) &&
+ (device_ordinal < 0)) {
+ return errors::InvalidArgument(transfer_type_,
+ " ops must specify a device_ordinal when "
+ "placed on CPU.");
+ }
+ return RunTransferWithOrdinal(ctx, device_ordinal);
+}
+
} // namespace tensorflow
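With the ordinal read from input 0 at run time, the same dequeue op can drain
several cores without rebuilding the graph. A sketch of that pattern
(hypothetical `step` and `num_cores` tensors; the placement check in the
kernel above still applies, i.e. the ordinal must be >= 0 on CPU):

  # Rotate through local cores, one dequeue per step.
  ordinal = tf.cast(tf.math.floormod(step, num_cores), tf.int32)
  value = tf.raw_ops.OutfeedDequeueV2(
      device_ordinal=ordinal, dtype=tf.float32, shape=[128])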
diff --git a/tensorflow/core/tpu/kernels/transfer_ops.h b/tensorflow/core/tpu/kernels/transfer_ops.h
index d98d743..6a3662a 100644
--- a/tensorflow/core/tpu/kernels/transfer_ops.h
+++ b/tensorflow/core/tpu/kernels/transfer_ops.h
@@ -25,11 +25,11 @@
// Base class providing common functionality for async ops that transfer from
// host to TPU.
-class TpuTransferAsyncOpKernel : public AsyncOpKernel {
+class TpuTransferAsyncOpKernelBase : public AsyncOpKernel {
public:
- explicit TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
- const string& transfer_type,
- int number_of_threads);
+ explicit TpuTransferAsyncOpKernelBase(OpKernelConstruction* ctx,
+ const string& transfer_type,
+ int number_of_threads);
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
@@ -38,19 +38,54 @@
xla::TpuTransferManagerInterface* transfer_manager,
stream_executor::StreamExecutor* stream_executor) = 0;
+ Status RunTransferWithOrdinal(OpKernelContext* ctx, int device_ordinal);
+ std::string transfer_type_;
+
private:
- Status RunTransfer(OpKernelContext* ctx);
+ virtual Status RunTransfer(OpKernelContext* ctx) = 0;
void Cancel();
std::unique_ptr<thread::ThreadPool> thread_pool_;
- int device_ordinal_;
mutex mu_;
+ // TpuTransferAsyncOpKernelBase is neither copyable nor movable.
+ TpuTransferAsyncOpKernelBase(const TpuTransferAsyncOpKernelBase&) = delete;
+ TpuTransferAsyncOpKernelBase& operator=(const TpuTransferAsyncOpKernelBase&) =
+ delete;
+};
+
+class TpuTransferAsyncOpKernel : public TpuTransferAsyncOpKernelBase {
+ public:
+ explicit TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
+ const string& transfer_type,
+ int number_of_threads);
+
+ private:
+ Status RunTransfer(OpKernelContext* ctx) override;
+ int device_ordinal_;
+
// TpuTransferAsyncOpKernel is neither copyable nor movable.
TpuTransferAsyncOpKernel(const TpuTransferAsyncOpKernel&) = delete;
TpuTransferAsyncOpKernel& operator=(const TpuTransferAsyncOpKernel&) = delete;
};
+class TpuTransferAsyncDynamicOrdinalOpKernel
+ : public TpuTransferAsyncOpKernelBase {
+ public:
+ explicit TpuTransferAsyncDynamicOrdinalOpKernel(OpKernelConstruction* ctx,
+ const string& transfer_type,
+ int number_of_threads);
+
+ private:
+ Status RunTransfer(OpKernelContext* ctx) override;
+
+  // TpuTransferAsyncDynamicOrdinalOpKernel is neither copyable nor movable.
+ TpuTransferAsyncDynamicOrdinalOpKernel(
+ const TpuTransferAsyncDynamicOrdinalOpKernel&) = delete;
+ TpuTransferAsyncDynamicOrdinalOpKernel& operator=(
+ const TpuTransferAsyncDynamicOrdinalOpKernel&) = delete;
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index acfb75c..2a2c310 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -2733,6 +2733,14 @@
argspec: "args=[\'dtypes\', \'shapes\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
}
member_method {
+ name: "OutfeedDequeueTupleV2"
+ argspec: "args=[\'device_ordinal\', \'dtypes\', \'shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "OutfeedDequeueV2"
+ argspec: "args=[\'device_ordinal\', \'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "OutfeedEnqueue"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index acfb75c..2a2c310 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -2733,6 +2733,14 @@
argspec: "args=[\'dtypes\', \'shapes\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
}
member_method {
+ name: "OutfeedDequeueTupleV2"
+ argspec: "args=[\'device_ordinal\', \'dtypes\', \'shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "OutfeedDequeueV2"
+ argspec: "args=[\'device_ordinal\', \'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "OutfeedEnqueue"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}