blob: 7a538d77d1be4418ca6ddac317298722a96d9ca9 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <deque>
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace data {
namespace {
// Op name used for anonymous resources, followed by the names of the attrs the
// kernels in this file read at construction time.
constexpr char kAnonymousMultiDeviceIterator[] = "AnonymousMultiDeviceIterator";
constexpr char kDevices[] = "devices";
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";
// One result handed from the host-side iterator to a shard consumer.
//
// `status` carries either the status returned by the host iterator's
// `GetNext()` call or an error synthesized by the buffer (cancellation, bad
// incarnation id). `end_of_sequence` and `value` are the outputs of that
// `GetNext()` call; consumers in this file only look at them when `status` is
// OK.
struct HostBufferElement {
  Status status;
  // Explicitly initialized: elements are default-constructed on several error
  // paths and then copied/moved into callbacks, which must never propagate an
  // indeterminate flag.
  bool end_of_sequence = false;
  std::vector<Tensor> value;
};
// Callback that receives the element (or the error status) produced for a
// `GetNextFromShard()` request.
using MultiDeviceIteratorCallback =
std::function<void(const HostBufferElement&)>;
// Resource that owns a host-side iterator and hands its elements out to a
// fixed number of shards (one per entry in `devices`).
//
// `Init()` installs a new host iterator and starts a new "incarnation";
// `GetNextFromShard()` asynchronously produces the next element for one shard.
// Elements are fetched by a background thread in round-robin shard order and
// kept in a bounded per-shard host buffer (see `MultiDeviceBuffer`).
class MultiDeviceIterator : public ResourceBase {
 public:
  // `flr` is not owned; it must be non-null (checked in debug builds). The
  // remaining pointer arguments are owned by the resource.
  MultiDeviceIterator(
      Env* env, const DataTypeVector& output_types,
      const std::vector<PartialTensorShape>& output_shapes,
      const std::vector<string>& devices,
      std::unique_ptr<FunctionLibraryDefinition> flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
      FunctionLibraryRuntime* flr,
      std::unique_ptr<FunctionHandleCache> function_handle_cache)
      : unbounded_thread_pool_(env, "tf_data_multi_device_iterator_resource"),
        output_types_(output_types),
        output_shapes_(output_shapes),
        devices_(devices),
        flib_def_(std::move(flib_def)),
        flr_(flr),
        pflr_(std::move(pflr)),
        function_handle_cache_(std::move(function_handle_cache)) {
    DCHECK(flr_ != nullptr);
  }

  string DebugString() const override {
    return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
                           " devices");
  }

  // Installs `iterator` as the host iterator of a fresh incarnation.
  //
  // A non-null `iterator` must match this resource's output types and shapes,
  // otherwise the verification error is returned and nothing changes. Any
  // previous buffer is reset (its background thread is stopped and its pending
  // callbacks are run) before being replaced. On success `*incarnation_id`
  // receives the id that subsequent `GetNextFromShard()` calls must pass.
  Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size,
              int64* incarnation_id) {
    if (iterator) {
      TF_RETURN_IF_ERROR(
          VerifyTypesMatch(output_types_, iterator->output_dtypes()));
      TF_RETURN_IF_ERROR(
          VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
    }

    mutex_lock l(mu_);
    if (multi_device_buffer_) {
      multi_device_buffer_->Reset();
    }
    ++incarnation_id_;
    *incarnation_id = incarnation_id_;
    multi_device_buffer_ = absl::make_unique<MultiDeviceBuffer>(
        devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator),
        this);
    return Status::OK();
  }

  // Requests the next element for shard `shard_num` of incarnation
  // `incarnation_id`.
  //
  // The result is reported through `ctx` (outputs on success, OutOfRange at
  // end of sequence, the element's error status otherwise) and `done` is
  // invoked exactly once, possibly from another thread. Fails with
  // FailedPrecondition if `Init()` has not been called yet.
  void GetNextFromShard(OpKernelContext* ctx, int shard_num,
                        int64 incarnation_id, std::function<void()> done) {
    tf_shared_lock l(mu_);
    // The buffer only exists after `Init()`; without this guard a request on
    // an uninitialized resource would dereference a null pointer below.
    OP_REQUIRES_ASYNC(
        ctx, multi_device_buffer_ != nullptr,
        errors::FailedPrecondition(
            "MultiDeviceIterator has not been initialized. Run "
            "MultiDeviceIteratorInit before requesting elements."),
        done);

    IteratorContext::Params params(ctx);
    params.flr = flr_;
    params.function_handle_cache = function_handle_cache_.get();
    params.resource_mgr = &resource_mgr_;
    params.thread_factory = unbounded_thread_pool_.get_thread_factory();
    params.thread_pool = &unbounded_thread_pool_;
    params.cancellation_manager = &cancellation_manager_;
    std::function<void()> deregister_fn;
    OP_REQUIRES_OK_ASYNC(ctx,
                         ConnectCancellationManagers(
                             ctx->cancellation_manager(),
                             params.cancellation_manager, &deregister_fn),
                         done);
    IteratorContext iter_ctx(std::move(params));

    // `done` and `deregister_fn` are bound into the callback so that they
    // travel with it if the request has to wait for the background thread.
    MultiDeviceIteratorCallback callback = std::bind(
        [ctx](const HostBufferElement& elem, const std::function<void()>& done,
              const std::function<void()>& deregister_fn) {
          const Status& s = elem.status;
          if (!s.ok()) {
            ctx->SetStatus(s);
          } else if (elem.end_of_sequence) {
            ctx->SetStatus(errors::OutOfRange("End of sequence"));
          } else {
            const int num_outputs = static_cast<int>(elem.value.size());
            for (int i = 0; i < num_outputs; ++i) {
              ctx->set_output(i, elem.value[i]);
            }
          }
          deregister_fn();
          done();
        },
        std::placeholders::_1, std::move(done), std::move(deregister_fn));

    multi_device_buffer_->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
                                           std::move(callback));
  }

  const DataTypeVector& output_types() const { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const {
    return output_shapes_;
  }

  FunctionLibraryRuntime* const flr() {
    tf_shared_lock l(mu_);
    return flr_;
  }

  FunctionHandleCache* function_handle_cache() {
    return function_handle_cache_.get();
  }

  ResourceMgr* resource_mgr() { return &resource_mgr_; }

  CancellationManager* cancellation_manager() { return &cancellation_manager_; }

 private:
  // A private class that uses a background thread to keep a per device buffer
  // full.
  //
  // The background thread is started lazily by the first `GetNextFromShard()`
  // call. It fetches elements from the host iterator and assigns them to
  // shards in round-robin order, blocking while the buffer of the shard it is
  // about to fill is at capacity.
  class MultiDeviceBuffer {
   public:
    // `size` is the number of shards. `parent` is not owned and must outlive
    // this object.
    MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
                      std::unique_ptr<IteratorBase> host_iterator,
                      MultiDeviceIterator* parent)
        : buffer_(size),
          size_(size),
          max_buffer_size_(max_buffer_size),
          incarnation_id_(incarnation_id),
          host_iterator_(std::move(host_iterator)),
          parent_(parent) {}

    // Always goes through `Reset()`, which decides what to do based on whether
    // a thread object exists. Keying the shutdown on a flag that the thread
    // sets itself would leave a window (thread created but not yet running) in
    // which pending callbacks are dropped and the thread later touches members
    // that have already been destroyed. `Reset()` is a no-op when no thread
    // was ever created or when it has already finished.
    ~MultiDeviceBuffer() { Reset(); }

    // Cancels the background thread (if one is running), waits for it to
    // finish, and then completes every pending callback.
    void Reset() LOCKS_EXCLUDED(mu_) {
      {
        mutex_lock l(mu_);
        if (background_thread_ && !background_thread_finished_) {
          cancelled_ = true;
          // Wake up the background thread, whichever shard it is blocked on.
          for (HostBuffer& shard : buffer_) {
            shard.cond_var.notify_all();
          }
          // Make sure background thread has finished first.
          while (!background_thread_finished_) {
            shutdown_cond_var_.wait(l);
          }
        }
      }
      RunPendingCallbacks();
    }

    // Delivers the next element of shard `shard_num` to `callback`.
    //
    // The callback runs synchronously when the request is invalid (wrong
    // incarnation id, shard out of range), the buffer was cancelled, an
    // element is already buffered, or the iterator is exhausted. Otherwise it
    // is queued and completed later by the background thread or by
    // `RunPendingCallbacks()`.
    void GetNextFromShard(IteratorContext* ctx, int shard_num,
                          int64 incarnation_id,
                          MultiDeviceIteratorCallback callback) {
      HostBufferElement elem;
      if (incarnation_id_ != incarnation_id) {
        elem.status = errors::InvalidArgument(
            "Invalid incarnation id. Provided: ", incarnation_id,
            "; Expected: ", incarnation_id_);
        callback(elem);
        return;
      }
      // `shard_num` comes straight from an op input and is used as an index
      // into `buffer_` below, so reject anything outside [0, size_).
      if (shard_num < 0 || static_cast<size_t>(shard_num) >= size_) {
        elem.status = errors::InvalidArgument(
            "Invalid shard number. Provided: ", shard_num,
            "; Expected a value in [0, ", size_, ")");
        callback(elem);
        return;
      }

      bool produced_output = false;
      {
        mutex_lock l(mu_);
        if (cancelled_) {
          elem.status = errors::Cancelled("Cancelled Multidevice iterator");
          callback(elem);
          return;
        }

        EnsureBackgroundThreadStarted(ctx);

        HostBuffer& shard = buffer_[shard_num];
        if (!shard.data.empty()) {
          produced_output = true;
          std::swap(elem, shard.data.front());
          shard.data.pop_front();
          // Wake up background thread if it is blocked on this element.
          if (shard.data.size() == max_buffer_size_ - 1) {
            shard.cond_var.notify_all();
          }
        } else if (end_of_iterator_) {
          produced_output = true;
          elem.end_of_sequence = true;
        } else {
          // Nothing available yet: park the callback for the background
          // thread.
          shard.callbacks.push_back(std::move(callback));
          callback = nullptr;
        }
      }

      // Invoke the callback outside of the lock.
      if (produced_output) {
        callback(elem);
      }
    }

   private:
    // Lazily starts the background thread. The thread receives its own copy of
    // `*ctx` because the caller's context does not outlive the request.
    void EnsureBackgroundThreadStarted(IteratorContext* ctx)
        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (!background_thread_) {
        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
        background_thread_ =
            parent_->unbounded_thread_pool_.get_thread_factory()->StartThread(
                "tf_data_multi_device_iterator",
                std::bind(
                    &MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
                    this, std::move(ctx_copy)));
      }
    }

    // Completes every queued callback: with a buffered element if the shard
    // has one, with end-of-sequence if the iterator is exhausted, and with a
    // Cancelled error otherwise. Callbacks are collected under the lock and
    // invoked after it has been released.
    void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) {
      std::vector<MultiDeviceIteratorCallback> cancellation_callbacks;
      std::vector<HostBufferElement> cancellation_elements;
      {
        mutex_lock l(mu_);
        for (HostBuffer& shard : buffer_) {
          while (!shard.callbacks.empty()) {
            if (shard.data.empty()) {
              HostBufferElement elem;
              if (end_of_iterator_) {
                elem.end_of_sequence = true;
              } else {
                elem.status =
                    errors::Cancelled("Cancelled and buffer not filled.");
              }
              cancellation_elements.push_back(std::move(elem));
            } else {
              cancellation_elements.push_back(std::move(shard.data.front()));
              shard.data.pop_front();
            }
            cancellation_callbacks.push_back(
                std::move(shard.callbacks.front()));
            shard.callbacks.pop_front();
          }
        }
      }
      for (size_t i = 0; i < cancellation_callbacks.size(); ++i) {
        cancellation_callbacks[i](cancellation_elements[i]);
      }
    }

    // Body of the background thread: repeatedly waits for room in the current
    // shard's buffer, fetches one element from the host iterator, and either
    // hands it to a waiting callback or stores it in the buffer. Exits on
    // cancellation or once the host iterator reports end of sequence.
    void BackgroundThread(std::shared_ptr<IteratorContext> ctx) {
      int shard_to_fetch = 0;
      while (true) {
        HostBufferElement elem;
        MultiDeviceIteratorCallback callback = nullptr;
        bool end_of_iterator = false;

        {
          mutex_lock l(mu_);
          while (!cancelled_ &&
                 buffer_[shard_to_fetch].data.size() >= max_buffer_size_) {
            buffer_[shard_to_fetch].cond_var.wait(l);
          }
          if (cancelled_) {
            background_thread_finished_ = true;
            shutdown_cond_var_.notify_all();
            return;
          }
        }

        // Fetch without holding the lock; this may take arbitrarily long.
        elem.status = host_iterator_->GetNext(ctx.get(), &elem.value,
                                              &elem.end_of_sequence);
        if (elem.status.ok() && elem.end_of_sequence) {
          end_of_iterator = true;
        }

        {
          mutex_lock l(mu_);
          HostBuffer& shard = buffer_[shard_to_fetch];
          // Try to find a callback, else just push stuff into buffer.
          if (!shard.callbacks.empty()) {
            callback = std::move(shard.callbacks.front());
            shard.callbacks.pop_front();
          } else {
            shard.data.push_back(std::move(elem));
            elem = HostBufferElement();
          }
        }

        if (callback) {
          // Run the callback on the context's runner rather than on this
          // thread.
          (*ctx->runner())(std::bind(std::move(callback), std::move(elem)));
        }

        // Finish off the thread if we reach the end of the iterator. Runs
        // pending callbacks.
        if (end_of_iterator) {
          {
            mutex_lock l(mu_);
            background_thread_finished_ = true;
            end_of_iterator_ = true;
            shutdown_cond_var_.notify_all();
          }
          RunPendingCallbacks();
          return;
        }
        shard_to_fetch = (shard_to_fetch + 1) % size_;
      }
    }

    // Per-shard state: buffered elements, callbacks waiting for an element,
    // and the condition variable the background thread blocks on while the
    // shard's buffer is full.
    struct HostBuffer {
      condition_variable cond_var;
      std::deque<HostBufferElement> data;
      std::deque<MultiDeviceIteratorCallback> callbacks;
    };

    mutex mu_;
    std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
    bool background_thread_finished_ GUARDED_BY(mu_) = false;
    bool end_of_iterator_ GUARDED_BY(mu_) = false;
    bool cancelled_ GUARDED_BY(mu_) = false;
    condition_variable shutdown_cond_var_ GUARDED_BY(mu_);

    std::vector<HostBuffer> buffer_;

    const size_t size_;
    const int64 max_buffer_size_;
    const int64 incarnation_id_;
    const std::unique_ptr<IteratorBase> host_iterator_;
    MultiDeviceIterator* const parent_;  // Not owned.
  };

  UnboundedThreadPool unbounded_thread_pool_;
  mutex mu_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const std::vector<string> devices_;
  const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  FunctionLibraryRuntime* const flr_ = nullptr;  // not owned.
  const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  const std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  ResourceMgr resource_mgr_;
  CancellationManager cancellation_manager_;

  int64 incarnation_id_ GUARDED_BY(mu_) = 0;
  std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
};
// Used to generate unique names for anonymous multi device iterators. Each
// anonymous resource created by `MultiDeviceIteratorHandleOp::Compute()`
// atomically takes the current value as its name suffix and increments it.
static std::atomic<int64> current_id_;
// Just creates a MultiDeviceIterator and returns it.
class MultiDeviceIteratorHandleOp : public OpKernel {
public:
explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
: OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kDevices, &devices_));
}
// The resource is deleted from the resource manager only when it is private
// to kernel.
~MultiDeviceIteratorHandleOp() override {
if (resource_ != nullptr) {
resource_->Unref();
if (cinfo_.resource_is_private_to_kernel()) {
if (!cinfo_.resource_manager()
->template Delete<MultiDeviceIterator>(cinfo_.container(),
cinfo_.name())
.ok()) {
// Do nothing; the resource can have been deleted by session resets.
}
}
}
}
void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
string unique_name = cinfo_.name();
string container_name = cinfo_.container();
{
mutex_lock l(mu_);
if (resource_ == nullptr) {
FunctionLibraryRuntime* flr;
std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
OP_REQUIRES_OK(context, context->function_library()->Clone(
&flib_def, &pflr, &flr));
std::unique_ptr<FunctionHandleCache> function_handle_cache =
absl::make_unique<FunctionHandleCache>(flr);
ResourceMgr* mgr = context->resource_manager();
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
MultiDeviceIterator* resource;
if (name_ == ResourceHandle::ANONYMOUS_NAME) {
unique_name = strings::StrCat("_AnonymousMultiDeviceIterator",
current_id_.fetch_add(1));
container_name = kAnonymousMultiDeviceIterator;
resource = new MultiDeviceIterator(
context->env(), output_types_, output_shapes_, devices_,
std::move(flib_def), std::move(pflr), flr,
std::move(function_handle_cache));
// NOTE: `mgr->Create()` transfers the one reference on `resource` to
// `mgr`.
OP_REQUIRES_OK(context, mgr->Create<MultiDeviceIterator>(
container_name, unique_name, resource));
} else {
unique_name = cinfo_.name();
container_name = cinfo_.container();
OP_REQUIRES_OK(context,
mgr->LookupOrCreate<MultiDeviceIterator>(
container_name, unique_name, &resource,
[this, context, flr, &flib_def, &pflr,
&function_handle_cache](MultiDeviceIterator** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new MultiDeviceIterator(
context->env(), output_types_,
output_shapes_, devices_,
std::move(flib_def), std::move(pflr),
flr, std::move(function_handle_cache));
return Status::OK();
}));
Status s = VerifyResource(resource);
if (TF_PREDICT_FALSE(!s.ok())) {
resource->Unref();
context->SetStatus(s);
return;
}
resource_ = resource;
}
}
}
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
context, 0, container_name, unique_name,
MakeTypeIndex<MultiDeviceIterator>()));
}
private:
// During the first Compute(), resource is either created or looked up using
// shared_name. In the latter case, the resource found should be verified if
// it is compatible with this op's configuration. The verification may fail in
// cases such as two graphs asking queues of the same shared name to have
// inconsistent capacities.
Status VerifyResource(MultiDeviceIterator* resource) {
TF_RETURN_IF_ERROR(
VerifyTypesMatch(output_types_, resource->output_types()));
TF_RETURN_IF_ERROR(
VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
return Status::OK();
}
mutex mu_;
ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
const int graph_def_version_;
string name_;
string container_;
std::vector<string> devices_;
};
REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
MultiDeviceIteratorHandleOp);
// Kernel for the anonymous variant of the resource: the base class drives
// resource creation and this class only supplies the attrs, the resource name
// prefix, and the factory for the MultiDeviceIterator itself.
class AnonymousMultiDeviceIteratorOp
    : public AnonymousResourceOp<MultiDeviceIterator> {
 public:
  explicit AnonymousMultiDeviceIteratorOp(OpKernelConstruction* ctx)
      : AnonymousResourceOp<MultiDeviceIterator>(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDevices, &devices_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

 private:
  string name() override { return kAnonymousMultiDeviceIterator; }

  // Builds a new MultiDeviceIterator from the attrs read at construction time
  // and the function library pieces supplied by the caller. Always succeeds.
  Status CreateResource(OpKernelContext* ctx,
                        std::unique_ptr<FunctionLibraryDefinition> flib_def,
                        std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                        FunctionLibraryRuntime* lib,
                        MultiDeviceIterator** resource) override {
    // The handle cache is bound to the same runtime the iterator will use.
    std::unique_ptr<FunctionHandleCache> handle_cache =
        absl::make_unique<FunctionHandleCache>(lib);
    MultiDeviceIterator* created = new MultiDeviceIterator(
        ctx->env(), output_dtypes_, output_shapes_, devices_,
        std::move(flib_def), std::move(pflr), lib, std::move(handle_cache));
    *resource = created;
    return Status::OK();
  }

  std::vector<string> devices_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
};

REGISTER_KERNEL_BUILDER(Name(kAnonymousMultiDeviceIterator).Device(DEVICE_CPU),
                        AnonymousMultiDeviceIteratorOp);
// Calls init on the MultiDeviceIterator.
class MultiDeviceIteratorInitOp : public OpKernel {
public:
explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
const Tensor* tensor_max_buffer_size;
OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()();
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
core::RefCountPtr<MultiDeviceIterator> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
std::unique_ptr<IteratorBase> iterator;
IteratorContext::Params params(ctx);
params.flr = resource->flr();
params.function_handle_cache = resource->function_handle_cache();
params.resource_mgr = resource->resource_mgr();
params.cancellation_manager = resource->cancellation_manager();
std::function<void()> deregister_fn;
OP_REQUIRES_OK(ctx, ConnectCancellationManagers(ctx->cancellation_manager(),
params.cancellation_manager,
&deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
IteratorContext iter_ctx(std::move(params));
OP_REQUIRES_OK(
ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
int64 incarnation_id;
OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
&incarnation_id));
Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
tensor_incarnation_id.scalar<int64>()() = incarnation_id;
OP_REQUIRES_OK(ctx,
ctx->set_output("incarnation_id", tensor_incarnation_id));
}
};
REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU),
MultiDeviceIteratorInitOp);
// Calls GetNextFromShard(shard) and returns a vector of Tensors as output.
// TODO(rohanj): Implement using BackgroundWorker that Derek built?
class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
public:
explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx) {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
const Tensor* tensor_shard_num;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done);
int32 shard_num = tensor_shard_num->scalar<int32>()();
const Tensor* tensor_incarnation_id;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
core::RefCountPtr<MultiDeviceIterator> iterator;
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
iterator->GetNextFromShard(ctx, shard_num, incarnation_id, std::move(done));
}
};
REGISTER_KERNEL_BUILDER(
Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU),
MultiDeviceIteratorGetNextFromShardOp);
class MultiDeviceIteratorToStringHandleOp : public OpKernel {
public:
explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
const Tensor& resource_handle_t = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));
// Validate that the handle corresponds to a real resource, and
// that it is an MultiDeviceIterator.
core::RefCountPtr<MultiDeviceIterator> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
Tensor* string_handle_t;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, TensorShape({}), &string_handle_t));
string_handle_t->scalar<tstring>()() =
resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
}
};
REGISTER_KERNEL_BUILDER(
Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU),
MultiDeviceIteratorToStringHandleOp);
class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
public:
explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
OP_REQUIRES(
ctx,
output_types_.empty() || output_shapes_.empty() ||
output_types_.size() == output_shapes_.size(),
errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
"are set, they must have the same length."));
}
void Compute(OpKernelContext* ctx) override {
const Tensor& string_handle_t = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
errors::InvalidArgument("string_handle must be a scalar"));
ResourceHandle resource_handle;
OP_REQUIRES(
ctx,
resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
errors::InvalidArgument(
"Could not parse string_handle as a valid ResourceHandle"));
OP_REQUIRES(
ctx, resource_handle.device() == ctx->device()->attributes().name(),
errors::InvalidArgument("Attempted create an iterator on device \"",
ctx->device()->attributes().name(),
"\" from handle defined on device \"",
resource_handle.device(), "\""));
// Validate that the handle corresponds to a real resource, and
// that it is an MultiDeviceIterator.
core::RefCountPtr<MultiDeviceIterator> resource;
OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
if (!output_types_.empty()) {
OP_REQUIRES_OK(ctx,
VerifyTypesMatch(output_types_, resource->output_types()));
}
if (!output_shapes_.empty()) {
OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_,
resource->output_shapes()));
}
Tensor* resource_handle_t;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}
private:
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(
Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
MultiDeviceIteratorFromStringHandleOp);
class DeleteMultiDeviceIteratorOp : public OpKernel {
public:
explicit DeleteMultiDeviceIteratorOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
// The iterator resource is guaranteed to exist because the variant tensor
// wrapping the deleter is provided as an unused input to this op, which
// guarantees that it has not run yet.
Status s = ctx->resource_manager()->Delete(handle);
if (errors::IsNotFound(s)) {
// TODO(b/135948230): Investigate why is the above statement not true and
// then get rid of the special case.
ctx->SetStatus(Status::OK());
return;
}
ctx->SetStatus(s);
}
};
REGISTER_KERNEL_BUILDER(Name("DeleteMultiDeviceIterator").Device(DEVICE_CPU),
DeleteMultiDeviceIteratorOp);
// Since this op takes in Iterator handles as (unused) inputs, we don't want
// to constrain the iterator location to CPU only. Therefore, we exempt the
// colocation restriction for this op allowing the iterators to be placed on
// other devices.
REGISTER_INPUT_COLOCATION_EXEMPTION("DeleteMultiDeviceIterator");
} // namespace
} // namespace data
} // namespace tensorflow