| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/core/kernels/data/iterator_ops.h" |
| |
| #include <memory> |
| |
| #include "absl/memory/memory.h" |
| #include "tensorflow/core/common_runtime/graph_runner.h" |
| #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" |
| #include "tensorflow/core/common_runtime/renamed_device.h" |
| #include "tensorflow/core/common_runtime/threadpool_device.h" |
| #include "tensorflow/core/framework/cancellation.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/partial_tensor_shape.h" |
| #include "tensorflow/core/framework/resource_op_kernel.h" |
| #include "tensorflow/core/framework/stats_aggregator.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/variant_op_registry.h" |
| #include "tensorflow/core/graph/graph_constructor.h" |
| #include "tensorflow/core/kernels/data/captured_function.h" |
| #include "tensorflow/core/kernels/data/dataset_utils.h" |
| #include "tensorflow/core/kernels/data/optional_ops.h" |
| #include "tensorflow/core/kernels/ops_util.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/refcount.h" |
| #include "tensorflow/core/lib/core/threadpool.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/lib/random/random.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/lib/strings/stringprintf.h" |
| #include "tensorflow/core/platform/env.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/public/session_options.h" |
| |
| namespace tensorflow { |
| namespace data { |
| namespace { |
| |
| // See documentation in ../../ops/dataset_ops.cc for a high-level |
| // description of the following ops. |
| |
| const char kAnonymousIterator[] = "AnonymousIterator"; |
| const char kAnonymousIteratorV2[] = "AnonymousIteratorV2"; |
| const char kIteratorVariantTypeName[] = "tensorflow::Iterator"; |
| const char kOutputShapes[] = "output_shapes"; |
| const char kOutputTypes[] = "output_types"; |
| |
| } // namespace |
| |
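| // NOTE: `GetNext()`, `Save()`, and `Restore()` take a `shared_ptr` snapshot |
| // of `iterator_state_` under a shared lock and release the lock before |
| // delegating to the iterator, so that a potentially long-running iterator |
| // call does not block a concurrent `SetIteratorFromDataset()`. |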
| Status IteratorResource::GetNext(OpKernelContext* ctx, |
| std::vector<Tensor>* out_tensors, |
| bool* end_of_sequence) { |
| std::shared_ptr<State> captured_state; |
| { |
| tf_shared_lock l(mu_); |
| captured_state = iterator_state_; |
| } |
| if (captured_state->iterator) { |
| IteratorContext::Params params(ctx); |
| params.flr = captured_state->flr; |
| params.function_handle_cache = captured_state->function_handle_cache.get(); |
| params.resource_mgr = &captured_state->resource_mgr; |
| params.thread_factory = unbounded_thread_pool_.get_thread_factory(); |
| params.thread_pool = &unbounded_thread_pool_; |
| params.cancellation_manager = &captured_state->cancellation_manager; |
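| // Propagate cancellation from the op's cancellation manager to the |
| // iterator's, so that cancelling the op also cancels a blocked GetNext() |
| // call; the `cleanup` below deregisters the callback when this call |
| // returns. |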
| std::function<void()> deregister_fn; |
| TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(), |
| params.cancellation_manager, |
| &deregister_fn)); |
| auto cleanup = gtl::MakeCleanup(std::move(deregister_fn)); |
| return captured_state->iterator->GetNext(IteratorContext(std::move(params)), |
| out_tensors, end_of_sequence); |
| } |
| return errors::FailedPrecondition( |
| "GetNext() failed because the iterator has not been initialized. Ensure " |
| "that you have run the initializer operation for this iterator before " |
| "getting the next element."); |
| } |
| |
| Status IteratorResource::Save(SerializationContext* ctx, |
| IteratorStateWriter* writer) { |
| std::shared_ptr<State> captured_state; |
| { |
| tf_shared_lock l(mu_); |
| captured_state = iterator_state_; |
| } |
| if (captured_state->iterator) { |
| return captured_state->iterator->Save(ctx, writer); |
| } |
| return errors::FailedPrecondition( |
| "Save() failed because the iterator has not been initialized. Ensure " |
| "that you have run the initializer operation for this iterator before " |
| "saving it."); |
| } |
| |
| Status IteratorResource::Restore(OpKernelContext* ctx, |
| IteratorStateReader* reader) { |
| std::shared_ptr<State> captured_state; |
| { |
| tf_shared_lock l(mu_); |
| captured_state = iterator_state_; |
| } |
| if (captured_state->iterator) { |
| IteratorContext::Params params(ctx); |
| params.flr = captured_state->flr; |
| params.function_handle_cache = captured_state->function_handle_cache.get(); |
| params.resource_mgr = &captured_state->resource_mgr; |
| params.thread_factory = unbounded_thread_pool_.get_thread_factory(); |
| params.thread_pool = &unbounded_thread_pool_; |
| params.cancellation_manager = &captured_state->cancellation_manager; |
| std::function<void()> deregister_fn; |
| TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(), |
| params.cancellation_manager, |
| &deregister_fn)); |
| auto cleanup = gtl::MakeCleanup(std::move(deregister_fn)); |
| IteratorContext iter_ctx(std::move(params)); |
| return captured_state->iterator->Restore(&iter_ctx, reader); |
| } |
| return errors::FailedPrecondition( |
| "Restore() failed because the iterator has not been initialized. Ensure " |
| "that you have run the initializer operation for this iterator before " |
| "restoring it."); |
| } |
| |
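| // Builds a new `State` that shares the function library state of the |
| // current one, creates the iterator for `dataset` outside the lock, and |
| // only swaps the fully-initialized state in under an exclusive lock. |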
| Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx, |
| DatasetBase* dataset) { |
| std::shared_ptr<State> new_state; |
| { |
| tf_shared_lock l(mu_); |
| new_state = std::make_shared<State>( |
| iterator_state_->flib_def, iterator_state_->pflr, iterator_state_->flr, |
| /*iterator=*/nullptr); |
| } |
| // Create new iterator. |
| std::unique_ptr<IteratorBase> iterator; |
| IteratorContext::Params params(ctx); |
| params.flr = new_state->flr; |
| params.function_handle_cache = new_state->function_handle_cache.get(); |
| params.resource_mgr = &new_state->resource_mgr; |
| params.thread_factory = unbounded_thread_pool_.get_thread_factory(); |
| params.thread_pool = &unbounded_thread_pool_; |
| params.cancellation_manager = &new_state->cancellation_manager; |
| std::function<void()> deregister_fn; |
| TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(), |
| params.cancellation_manager, |
| &deregister_fn)); |
| { |
| auto cleanup = gtl::MakeCleanup(std::move(deregister_fn)); |
| TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)), |
| "Iterator", &iterator)); |
| TF_RETURN_IF_ERROR( |
| VerifyTypesMatch(output_dtypes_, iterator->output_dtypes())); |
| TF_RETURN_IF_ERROR( |
| VerifyShapesCompatible(output_shapes_, iterator->output_shapes())); |
| std::swap(new_state->iterator, iterator); |
| } |
| |
| mutex_lock l(mu_); |
| std::swap(iterator_state_, new_state); |
| return Status::OK(); |
| } |
| |
| namespace { |
| |
| // Wrapper for encoding/decoding the iterator state stored in a Variant tensor. |
| // The get() method returns an IteratorStateReader which can be used |
| // to restore iterator state. |
| // |
| // Usage example: |
| // |
| // Encoding: |
| // |
| // Tensor t(DT_VARIANT, TensorShape({})); |
| // IteratorStateVariant v; |
| // TF_RETURN_IF_ERROR(v.InitializeFromIterator(ctx, iterator_resource)); |
| // t.scalar<Variant>()() = v; |
| // |
| // Encode() sets the type_name of the VariantTensorData object to |
| // IteratorStateVariant::TypeName(). |
| // |
| // Decoding: |
| // |
| // Variant v = <VariantTensorDataProto object>; |
| // DecodeUnaryVariant(&v); |
| // IteratorStateVariant* wrapper = v.get<IteratorStateVariant>(); |
| // iterator_resource->Restore(ctx, wrapper->get()); |
| // |
| // The type_name of the VariantTensorData object to be decoded must |
| // match IteratorStateVariant::TypeName(). |
| class IteratorStateVariant { |
| public: |
| IteratorStateVariant() : data_(nullptr) {} |
| IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) { |
| if (other.data_) { |
| Decode(*other.data_); |
| } |
| } |
| IteratorStateVariant& operator=(IteratorStateVariant&& other) = default; |
| IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete; |
| |
| // Initializes this object with the current state of the iterator so |
| // that it can be written on the next call to Encode(). |
| Status InitializeFromIterator(OpKernelContext* ctx, |
| IteratorResource* iterator_resource) { |
| SerializationContext serialization_ctx({}); |
| data_ = absl::make_unique<VariantTensorData>(); |
| data_->set_type_name(TypeName()); |
| VariantTensorDataWriter writer(data_.get()); |
| TF_RETURN_IF_ERROR(iterator_resource->Save(&serialization_ctx, &writer)); |
| TF_RETURN_IF_ERROR(writer.Flush()); |
| return Status::OK(); |
| } |
| string TypeName() const { return kIteratorVariantTypeName; } |
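| // Encode() copies the wrapped state into `data`. It requires that |
| // InitializeFromIterator() or Decode() was called first, so that `data_` |
| // is non-null. |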
| void Encode(VariantTensorData* data) const { *data = *data_; } |
| bool Decode(VariantTensorData data) { |
| if (data.type_name() != TypeName()) { |
| return false; |
| } |
| std::unique_ptr<VariantTensorData> tensor_data = |
| absl::make_unique<VariantTensorData>(); |
| std::swap(*tensor_data, data); |
| std::unique_ptr<VariantTensorDataReader> reader = |
| absl::make_unique<VariantTensorDataReader>(tensor_data.get()); |
| data_ = std::move(tensor_data); |
| reader_ = std::move(reader); |
| return true; |
| } |
| IteratorStateReader* get() { return reader_.get(); } |
| string DebugString() const { |
| if (data_) { |
| return strings::StrCat("IteratorStateVariant<", data_->DebugString(), |
| ">"); |
| } else { |
| return strings::StrCat("IteratorStateVariant<empty>"); |
| } |
| } |
| |
| private: |
| std::unique_ptr<IteratorStateReader> reader_; |
| std::unique_ptr<VariantTensorData> data_; |
| }; |
| |
| // Register the reader class in the global variant decode_fn registry |
| // so that a Variant containing a serialized representation of iterator |
| // state can be decoded using DecodeUnaryVariant. If we did not do this, we |
| // would need to manually decode the returned Variant using |
| // MaybeDecodeAndCopy in DeserializeIteratorOp, which is not recommended. |
| REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, |
| kIteratorVariantTypeName); |
| |
| } // namespace |
| |
| // Note that IteratorHandleOp holds a reference to the resource it creates. If |
| // cleaning up resources with DestroyResourceOp is important, consider creating |
| // resource containers with AnonymousIteratorHandleOp instead. |
| IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx) |
| : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); |
| } |
| |
| // The resource is deleted from the resource manager only when it is |
| // private to the kernel. Ideally, the resource should be deleted when it |
| // is no longer held by anyone, but doing so would break backward |
| // compatibility. |
| IteratorHandleOp::~IteratorHandleOp() { |
| if (resource_ != nullptr) { |
| resource_->Unref(); |
| if (cinfo_.resource_is_private_to_kernel()) { |
| if (!cinfo_.resource_manager() |
| ->template Delete<IteratorResource>(cinfo_.container(), |
| cinfo_.name()) |
| .ok()) { |
| // Do nothing; the resource may have been deleted by a session reset. |
| } |
| } |
| } |
| } |
| |
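| // On the first invocation, creates (or looks up) the IteratorResource in |
| // the resource manager under `mu_`; every invocation then outputs a |
| // resource handle to it. |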
| void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) { |
| { |
| mutex_lock l(mu_); |
| if (resource_ == nullptr) { |
| FunctionLibraryRuntime* flr; |
| std::unique_ptr<DeviceMgr> device_mgr(nullptr); |
| std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); |
| std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); |
| // If the iterator is shared then we construct a new FLR and pass that |
| // in. NOTE(mrry,rohanj): In this case it is not possible to call remote |
| // functions from the iterator. We may add this functionality if there |
| // is sufficient demand, but it will require significant refactoring. |
| if (!name_.empty()) { |
| flr = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr); |
| } else { |
| OP_REQUIRES_OK(context, context->function_library()->Clone( |
| &flib_def, &pflr, &flr, true)); |
| } |
| |
| ResourceMgr* mgr = context->resource_manager(); |
| OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); |
| |
| IteratorResource* resource; |
| OP_REQUIRES_OK( |
| context, |
| mgr->LookupOrCreate<IteratorResource>( |
| cinfo_.container(), cinfo_.name(), &resource, |
| [context, flr, &device_mgr, &flib_def, &pflr, |
| this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| *ret = new IteratorResource( |
| context->env(), output_dtypes_, output_shapes_, |
| graph_def_version_, std::move(device_mgr), |
| std::move(flib_def), std::move(pflr), flr); |
| return Status::OK(); |
| })); |
| |
| Status s = VerifyResource(resource); |
| if (TF_PREDICT_FALSE(!s.ok())) { |
| resource->Unref(); |
| context->SetStatus(s); |
| return; |
| } |
| |
| resource_ = resource; |
| } |
| } |
| OP_REQUIRES_OK(context, MakeResourceHandleToOutput( |
| context, 0, cinfo_.container(), cinfo_.name(), |
| MakeTypeIndex<IteratorResource>())); |
| } |
| |
| Status IteratorHandleOp::VerifyResource(IteratorResource* resource) { |
| TF_RETURN_IF_ERROR( |
| VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); |
| TF_RETURN_IF_ERROR( |
| VerifyShapesCompatible(output_shapes_, resource->output_shapes())); |
| return Status::OK(); |
| } |
| |
| FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR( |
| OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr, |
| std::unique_ptr<FunctionLibraryDefinition>* flib_def, |
| std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) { |
| // Wrap the existing device in order to see any captured resources |
| // in its resource manager. The existing device will outlive the |
| // IteratorResource, because we are storing the IteratorResource |
| // in that device's resource manager. |
| |
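| // The StaticDeviceMgr takes ownership of the renamed wrapper device only; |
| // the underlying device is not owned (owns_underlying = false). |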
| *device_mgr = |
| absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice( |
| ctx->device()->name(), down_cast<Device*>(ctx->device()), |
| false /* owns_underlying */, false /* isolate_session_state */)); |
| *flib_def = absl::make_unique<FunctionLibraryDefinition>( |
| *ctx->function_library()->GetFunctionLibraryDefinition()); |
| const auto* config = ctx->function_library()->config_proto(); |
| *pflr = absl::make_unique<ProcessFunctionLibraryRuntime>( |
| device_mgr->get(), ctx->env(), |
| /*config=*/config, graph_def_version_, flib_def->get(), |
| config->graph_options().optimizer_options()); |
| |
| return (*pflr)->GetFLR(ctx->device()->name()); |
| } |
| |
| // Like IteratorHandleOp, but creates handles which are never shared, and does |
| // not hold a reference to these handles. The latter is important for eager |
| // execution, since OpKernel instances generally live as long as the program |
| // running them. |
| AnonymousIteratorHandleOp::AnonymousIteratorHandleOp( |
| OpKernelConstruction* context) |
| : AnonymousResourceOp<IteratorResource>(context), |
| graph_def_version_(context->graph_def_version()) { |
| OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_)); |
| OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_)); |
| create_deleter_ = context->def().op() == kAnonymousIteratorV2; |
| } |
| |
| string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; } |
| |
| Status AnonymousIteratorHandleOp::CreateResource( |
| OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def, |
| std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, |
| FunctionLibraryRuntime* lib, IteratorResource** resource) { |
| std::unique_ptr<DeviceMgr> device_mgr(nullptr); |
| *resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_, |
| graph_def_version_, std::move(device_mgr), |
| std::move(flib_def), std::move(pflr), lib); |
| return Status::OK(); |
| } |
| |
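| // Creating the iterator may block (e.g. while running functions needed to |
| // initialize the iterator state), so the work is scheduled on the kernel's |
| // owned background worker rather than an inter-op thread pool thread. |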
| void MakeIteratorOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { |
| DatasetBase* dataset; |
| OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); |
| IteratorResource* iterator_resource; |
| OP_REQUIRES_OK( |
| ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource)); |
| background_worker_.Schedule(std::bind( |
| [ctx, iterator_resource, dataset](DoneCallback done) { |
| Status s = iterator_resource->SetIteratorFromDataset(ctx, dataset); |
| iterator_resource->Unref(); |
| if (!s.ok()) { |
| ctx->SetStatus(s); |
| } |
| done(); |
| }, |
| std::move(done))); |
| } |
| |
| void DeleteIteratorOp::Compute(OpKernelContext* ctx) { |
| ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0); |
| // The iterator resource is guaranteed to exist because the variant tensor |
| // wrapping the deleter is provided as an unused input to this op, which |
| // guarantees that the deleter has not yet run. |
| Status s = ctx->resource_manager()->Delete(handle); |
| if (errors::IsNotFound(s)) { |
| // TODO(b/135948230): Investigate why is the above statement not true and |
| // then get rid of the special case. |
| ctx->SetStatus(Status::OK()); |
| return; |
| } |
| ctx->SetStatus(s); |
| } |
| |
| namespace { |
| |
| class ToSingleElementOp : public AsyncOpKernel { |
| public: |
| explicit ToSingleElementOp(OpKernelConstruction* ctx) |
| : AsyncOpKernel(ctx), |
| background_worker_(ctx->env(), "tf_data_to_single_element") {} |
| |
| void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { |
| // The call to `iterator->GetNext()` may block and depend on an |
| // inter-op thread pool thread, so we issue the call from the |
| // owned thread pool. |
| background_worker_.Schedule(std::bind( |
| [ctx](std::function<void()>& done) { |
| DatasetBase* dataset; |
| OP_REQUIRES_OK_ASYNC( |
| ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done); |
| |
| IteratorContext::Params params(ctx); |
| FunctionHandleCache function_handle_cache(params.flr); |
| params.function_handle_cache = &function_handle_cache; |
| ResourceMgr resource_mgr; |
| params.resource_mgr = &resource_mgr; |
| CancellationManager cancellation_manager; |
| params.cancellation_manager = &cancellation_manager; |
| std::function<void()> deregister_fn; |
| OP_REQUIRES_OK_ASYNC(ctx, |
| ConnectCancellationManagers( |
| ctx->cancellation_manager(), |
| params.cancellation_manager, &deregister_fn), |
| done); |
| |
| // Update the `done` callback to deregister the cancellation callback. |
| done = std::bind( |
| [](const std::function<void()>& done, |
| const std::function<void()>& deregister_fn) { |
| deregister_fn(); |
| done(); |
| }, |
| std::move(done), std::move(deregister_fn)); |
| |
| IteratorContext iter_ctx(std::move(params)); |
| std::unique_ptr<IteratorBase> iterator; |
| OP_REQUIRES_OK_ASYNC( |
| ctx, |
| dataset->MakeIterator(&iter_ctx, "SingleElementIterator", |
| &iterator), |
| done); |
| |
| // Update the `done` callback to destroy the iterator before calling |
| // the actual callback to avoid destruction races. |
| IteratorBase* raw_iterator = iterator.release(); |
| done = std::bind( |
| [raw_iterator](const std::function<void()>& done) { |
| delete raw_iterator; |
| done(); |
| }, |
| std::move(done)); |
| |
| std::vector<Tensor> components; |
| components.reserve(dataset->output_dtypes().size()); |
| bool end_of_sequence = false; |
| |
| Status s = |
| raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence); |
| if (!s.ok()) { |
| ctx->SetStatus(s); |
| done(); |
| return; |
| } |
| if (end_of_sequence) { |
| ctx->SetStatus(errors::InvalidArgument("Dataset was empty.")); |
| done(); |
| return; |
| } |
| for (int i = 0; i < components.size(); ++i) { |
| // TODO(mrry): Check that the shapes match the shape attrs. |
| ctx->set_output(i, components[i]); |
| } |
| |
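| // Request a second element to verify that the dataset contains exactly |
| // one element; anything other than `end_of_sequence` here is an error. |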
| components.clear(); |
| s.Update( |
| raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence)); |
| if (!s.ok()) { |
| ctx->SetStatus(s); |
| done(); |
| return; |
| } |
| if (!end_of_sequence) { |
| ctx->SetStatus( |
| errors::InvalidArgument("Dataset had more than one element.")); |
| done(); |
| return; |
| } |
| done(); |
| }, |
| std::move(done))); |
| } |
| |
| private: |
| BackgroundWorker background_worker_; |
| }; |
| |
| class ReduceDatasetOp : public AsyncOpKernel { |
| public: |
| explicit ReduceDatasetOp(OpKernelConstruction* ctx) |
| : AsyncOpKernel(ctx), |
| background_worker_(ctx->env(), "tf_data_reduce_dataset") { |
| FunctionMetadata::Params params; |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", |
| ¶ms.use_inter_op_parallelism)); |
| params.is_multi_device_function = true; |
| OP_REQUIRES_OK(ctx, |
| FunctionMetadata::Create(ctx, "f", params, &func_metadata_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); |
| } |
| |
| void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { |
| // The call to `iterator->GetNext()` may block and depend on an |
| // inter-op thread pool thread, so we issue the call from the |
| // owned thread pool. |
| background_worker_.Schedule(std::bind( |
| [this, ctx](std::function<void()>& done) { |
| DatasetBase* dataset; |
| OP_REQUIRES_OK_ASYNC( |
| ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done); |
| OpInputList inputs; |
| OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs), |
| done); |
| std::vector<Tensor> state(inputs.begin(), inputs.end()); |
| |
| std::unique_ptr<CapturedFunction> captured_func; |
| OP_REQUIRES_OK_ASYNC( |
| ctx, |
| CapturedFunction::Create(ctx, func_metadata_, "other_arguments", |
| &captured_func), |
| done); |
| |
| IteratorContext::Params params(ctx); |
| auto function_handle_cache = |
| absl::make_unique<FunctionHandleCache>(params.flr); |
| params.function_handle_cache = function_handle_cache.get(); |
| ResourceMgr resource_mgr; |
| params.resource_mgr = &resource_mgr; |
| CancellationManager cancellation_manager; |
| params.cancellation_manager = &cancellation_manager; |
| std::function<void()> deregister_fn; |
| OP_REQUIRES_OK_ASYNC(ctx, |
| ConnectCancellationManagers( |
| ctx->cancellation_manager(), |
| params.cancellation_manager, &deregister_fn), |
| done); |
| |
| // Update the `done` callback to deregister the cancellation callback. |
| done = std::bind( |
| [](const std::function<void()>& done, |
| const std::function<void()>& deregister_fn) { |
| deregister_fn(); |
| done(); |
| }, |
| std::move(done), std::move(deregister_fn)); |
| |
| IteratorContext iter_ctx(std::move(params)); |
| std::unique_ptr<InstantiatedCapturedFunction> |
| instantiated_captured_func; |
| OP_REQUIRES_OK_ASYNC(ctx, |
| captured_func->Instantiate( |
| &iter_ctx, &instantiated_captured_func), |
| done); |
| |
| std::unique_ptr<IteratorBase> iterator; |
| OP_REQUIRES_OK_ASYNC( |
| ctx, |
| dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator), |
| done); |
| |
| // Update the `done` callback to destroy the iterator before calling |
| // the actual callback to avoid destruction races. |
| IteratorBase* raw_iterator = iterator.release(); |
| done = std::bind( |
| [raw_iterator](const std::function<void()>& done) { |
| delete raw_iterator; |
| done(); |
| }, |
| std::move(done)); |
| |
| // Iterate through the input dataset. |
| Status status; |
| while (true) { |
| OP_REQUIRES_ASYNC(ctx, !ctx->cancellation_manager()->IsCancelled(), |
| errors::Cancelled("Operation was cancelled"), |
| done); |
| std::vector<Tensor> next_input_element; |
| bool end_of_input; |
| status = raw_iterator->GetNext(&iter_ctx, &next_input_element, |
| &end_of_input); |
| if (!status.ok() || end_of_input) { |
| break; |
| } |
| |
| // Run the reduce function to update the current state. |
| std::vector<Tensor> args; |
| args.reserve(state.size() + next_input_element.size()); |
| std::copy(state.begin(), state.end(), std::back_inserter(args)); |
| std::copy(next_input_element.begin(), next_input_element.end(), |
| std::back_inserter(args)); |
| |
| std::vector<Tensor> reduce_func_output; |
| status = instantiated_captured_func->Run(&iter_ctx, std::move(args), |
| &reduce_func_output); |
| if (!status.ok()) { |
| break; |
| } |
| OP_REQUIRES_ASYNC( |
| ctx, reduce_func_output.size() == state.size(), |
| errors::InvalidArgument( |
| "The number of components of the initial state and the " |
| "reduce " |
| "function output does not match. (initial_state=", |
| state.size(), ", output=", reduce_func_output.size(), ")."), |
| done); |
| std::swap(reduce_func_output, state); |
| } |
| |
| if (!status.ok()) { |
| ctx->SetStatus(status); |
| done(); |
| return; |
| } |
| |
| OP_REQUIRES_ASYNC(ctx, state.size() == output_types_.size(), |
| errors::InvalidArgument( |
| "The number of result elements does not match " |
| "the size of output types: ", |
| state.size(), " vs. ", output_types_.size()), |
| done); |
| OP_REQUIRES_ASYNC(ctx, state.size() == output_shapes_.size(), |
| errors::InvalidArgument( |
| "The number of result elements does not match " |
| "the size of output shapes: ", |
| state.size(), " vs. ", output_shapes_.size()), |
| done); |
| for (int i = 0; i < state.size(); ++i) { |
| OP_REQUIRES_ASYNC( |
| ctx, state[i].dtype() == output_types_[i], |
| errors::InvalidArgument( |
| "The result does not match the expected type for " |
| "component ", |
| i, ". Expected: ", DataTypeString(output_types_[i]), |
| ". Actual: ", DataTypeString(state[i].dtype()), "."), |
| done); |
| OP_REQUIRES_ASYNC( |
| ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()), |
| errors::InvalidArgument( |
| "The result does not match the expected shape for " |
| "component ", |
| i, ". Expected: ", output_shapes_[i].DebugString(), |
| ". Actual: ", state[i].shape().DebugString(), "."), |
| done); |
| ctx->set_output(i, state[i]); |
| } |
| done(); |
| }, |
| std::move(done))); |
| } |
| |
| private: |
| std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr; |
| DataTypeVector output_types_; |
| std::vector<PartialTensorShape> output_shapes_; |
| BackgroundWorker background_worker_; |
| }; |
| |
| class OneShotIteratorOp : public AsyncOpKernel { |
| public: |
| explicit OneShotIteratorOp(OpKernelConstruction* ctx) |
| : AsyncOpKernel(ctx), |
| background_worker_(ctx->env(), "tf_data_one_shot_iterator"), |
| graph_def_version_(ctx->graph_def_version()) { |
| string shared_name; |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name)); |
| OP_REQUIRES(ctx, shared_name.empty(), |
| errors::InvalidArgument("OneShotIteratorOp does not currently " |
| "support the 'shared_name' attr.")); |
| OP_REQUIRES_OK(ctx, |
| ctx->GetAttr("dataset_factory", &dataset_factory_func_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); |
| } |
| |
| ~OneShotIteratorOp() override { |
| if (iterator_resource_ != nullptr) { |
| iterator_resource_->Unref(); |
| if (!cinfo_.resource_manager() |
| ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name()) |
| .ok()) { |
| // Do nothing; the resource may have been deleted by a session reset. |
| } |
| } |
| } |
| |
| // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`, but |
| // because `ResourceOpKernel<T>::CreateResource()` does not provide access |
| // to the `OpKernelContext*` that we need to invoke the factory function, |
| // it is not possible to implement this kernel by overriding |
| // `CreateResource()`. Furthermore, because this kernel may block while |
| // running the initialization function, we must implement it as an async |
| // kernel. |
| void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { |
| { |
| mutex_lock l(mu_); |
| if (iterator_resource_ == nullptr && initialization_status_.ok()) { |
| // The initialization thread will call `done`. |
| if (!initialization_started_) { |
| // TODO(mrry): Convert the initialization code to use |
| // callbacks instead of wasting a thread. |
| background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); }); |
| initialization_started_ = true; |
| } else { |
| done_callbacks_.emplace_back(ctx, std::move(done)); |
| } |
| return; |
| } |
| } |
| ProduceOutput(ctx, done); |
| } |
| |
| private: |
| void Init(OpKernelContext* ctx, const DoneCallback& done) { |
| IteratorResource* iterator = nullptr; |
| ContainerInfo cinfo; |
| Status s = TryInit(ctx, &iterator, &cinfo); |
| |
| std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run; |
| { |
| mutex_lock l(mu_); |
| if (s.ok()) { |
| iterator_resource_ = iterator; |
| cinfo_ = cinfo; |
| } |
| initialization_status_ = s; |
| std::swap(done_callbacks_, callbacks_to_run); |
| } |
| |
| for (auto&& ctx_done : callbacks_to_run) { |
| ProduceOutput(ctx_done.first, ctx_done.second); |
| } |
| ProduceOutput(ctx, done); |
| } |
| |
| Status TryInit(OpKernelContext* ctx, IteratorResource** iterator, |
| ContainerInfo* cinfo) { |
| TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def())); |
| |
| FunctionLibraryRuntime* flr; |
| std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); |
| std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); |
| TF_RETURN_IF_ERROR( |
| ctx->function_library()->Clone(&flib_def, &pflr, &flr, true)); |
| |
| // Create an IteratorResource that will hold the iterator for this op. |
| TF_RETURN_IF_ERROR( |
| ctx->resource_manager()->LookupOrCreate<IteratorResource>( |
| cinfo->container(), cinfo->name(), iterator, |
| [ctx, flr, this, &flib_def, &pflr](IteratorResource** ret) |
| EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| *ret = new IteratorResource( |
| ctx->env(), output_dtypes_, output_shapes_, |
| graph_def_version_, nullptr, std::move(flib_def), |
| std::move(pflr), flr); |
| return Status::OK(); |
| })); |
| |
| core::ScopedUnref unref_iterator(*iterator); |
| |
| TF_RETURN_IF_ERROR( |
| VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes())); |
| TF_RETURN_IF_ERROR( |
| VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes())); |
| |
| // Call the dataset_factory_func_ to create a new dataset, |
| // over which this op will iterate. |
| FunctionLibraryRuntime::Handle f_handle; |
| TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( |
| dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()), |
| &f_handle)); |
| FunctionLibraryRuntime::Options opts; |
| opts.cancellation_manager = ctx->cancellation_manager(); |
| ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) { |
| ctx->resource_manager()->Cleanup(name).IgnoreError(); |
| }); |
| opts.step_container = &step_container; |
| opts.runner = ctx->runner(); |
| Notification n; |
| Status factory_status; |
| std::vector<Tensor> return_values; |
| ctx->function_library()->Run(opts, f_handle, {}, &return_values, |
| [&n, &factory_status](Status s) { |
| factory_status.Update(s); |
| n.Notify(); |
| }); |
| n.WaitForNotification(); |
| TF_RETURN_IF_ERROR(factory_status); |
| if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT || |
| !TensorShapeUtils::IsScalar(return_values[0].shape())) { |
| return errors::InvalidArgument( |
| "The `dataset_factory` function must return " |
| "a single scalar of dtype DT_VARIANT."); |
| } |
| |
| // Create an iterator for the dataset that was created in the |
| // factory function. |
| DatasetBase* dataset; |
| TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset)); |
| TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset)); |
| (*iterator)->Ref(); |
| return Status::OK(); |
| } |
| |
| void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) { |
| Tensor* handle; |
| OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle), |
| done); |
| Status s; |
| { |
| mutex_lock l(mu_); |
| s = initialization_status_; |
| if (s.ok()) { |
| handle->scalar<ResourceHandle>()() = |
| MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(), |
| cinfo_.name()); |
| } |
| } |
| OP_REQUIRES_OK_ASYNC(ctx, s, done); |
| done(); |
| } |
| |
| NameAttrList dataset_factory_func_; |
| DataTypeVector output_dtypes_; |
| std::vector<PartialTensorShape> output_shapes_; |
| |
| BackgroundWorker background_worker_; |
| |
| mutex mu_; |
| ContainerInfo cinfo_ GUARDED_BY(mu_); |
| IteratorResource* iterator_resource_ GUARDED_BY(mu_) = nullptr; |
| |
| bool initialization_started_ GUARDED_BY(mu_) = false; |
| Status initialization_status_ GUARDED_BY(mu_); |
| std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_ |
| GUARDED_BY(mu_); |
| const int graph_def_version_; |
| }; |
| |
| } // namespace |
| |
| void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { |
| IteratorResource* iterator; |
| OP_REQUIRES_OK_ASYNC( |
| ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); |
| // The call to `iterator->GetNext()` may block and depend on an |
| // inter-op thread pool thread, so we issue the call from the |
| // owned thread pool. |
| background_worker_.Schedule(std::bind( |
| [ctx, iterator](DoneCallback done) { |
| std::vector<Tensor> components; |
| bool end_of_sequence = false; |
| |
| Status s = iterator->GetNext(ctx, &components, &end_of_sequence); |
| // NOTE(mrry): We must unref the iterator before calling `done()`, to |
| // avoid destruction races. |
| iterator->Unref(); |
| |
| if (!s.ok()) { |
| ctx->SetStatus(s); |
| } else if (end_of_sequence) { |
| ctx->SetStatus(errors::OutOfRange("End of sequence")); |
| } else { |
| for (int i = 0; i < components.size(); ++i) { |
| // TODO(mrry): Check that the shapes match the shape attrs. |
| ctx->set_output(i, components[i]); |
| } |
| } |
| done(); |
| }, |
| std::move(done))); |
| } |
| |
| void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) { |
| IteratorResource* iterator; |
| OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); |
| core::ScopedUnref unref_iterator(iterator); |
| std::vector<Tensor> components; |
| bool end_of_sequence = false; |
| |
| OP_REQUIRES_OK(ctx, iterator->GetNext(ctx, &components, &end_of_sequence)); |
| OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence")); |
| |
| for (int i = 0; i < components.size(); ++i) { |
| // TODO(mrry): Check that the shapes match the shape attrs. |
| ctx->set_output(i, components[i]); |
| } |
| } |
| |
| void IteratorGetNextAsOptionalOp::ComputeAsync(OpKernelContext* ctx, |
| DoneCallback done) { |
| IteratorResource* iterator; |
| OP_REQUIRES_OK_ASYNC( |
| ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); |
| // The call to `iterator->GetNext()` may block and depend on an |
| // inter-op thread pool thread, so we issue the call from the |
| // owned thread pool. |
| background_worker_.Schedule(std::bind( |
| [this, ctx, iterator](DoneCallback done) { |
| std::vector<Tensor> components; |
| bool end_of_sequence = false; |
| |
| Status s = iterator->GetNext(ctx, &components, &end_of_sequence); |
| // NOTE(mrry): We must unref the iterator before calling `done()`, to |
| // avoid destruction races. |
| iterator->Unref(); |
| |
| if (!s.ok()) { |
| ctx->SetStatus(s); |
| } else if (end_of_sequence) { |
| OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done); |
| } else { |
| for (int i = 0; i < components.size(); ++i) { |
| OP_REQUIRES_ASYNC( |
| ctx, components[i].dtype() == output_types_[i], |
| errors::InvalidArgument( |
| "The given optional does not match the expected type for " |
| "component ", |
| i, ". Expected: ", DataTypeString(output_types_[i]), |
| ". Actual: ", DataTypeString(components[i].dtype()), "."), |
| done); |
| OP_REQUIRES_ASYNC( |
| ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()), |
| errors::InvalidArgument( |
| "The given optional does not match the expected shape " |
| "for component ", |
| i, ". Expected: ", output_shapes_[i].DebugString(), |
| ". Actual: ", components[i].shape().DebugString(), "."), |
| done); |
| } |
| |
| OP_REQUIRES_OK_ASYNC( |
| ctx, |
| WriteOptionalWithValueToOutput(ctx, 0, std::move(components)), |
| done); |
| } |
| done(); |
| }, |
| std::move(done))); |
| } |
| |
| void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) { |
| const Tensor& resource_handle_t = ctx->input(0); |
| OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), |
| errors::InvalidArgument("resource_handle must be a scalar")); |
| |
| // Validate that the handle corresponds to a real resource, and |
| // that it is an IteratorResource. |
| IteratorResource* iterator_resource; |
| OP_REQUIRES_OK( |
| ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); |
| iterator_resource->Unref(); |
| |
| Tensor* string_handle_t; |
| OP_REQUIRES_OK(ctx, |
| ctx->allocate_output(0, TensorShape({}), &string_handle_t)); |
| string_handle_t->scalar<tstring>()() = |
| resource_handle_t.scalar<ResourceHandle>()().SerializeAsString(); |
| } |
| |
| IteratorFromStringHandleOp::IteratorFromStringHandleOp( |
| OpKernelConstruction* ctx) |
| : OpKernel(ctx) { |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); |
| OP_REQUIRES( |
| ctx, |
| output_dtypes_.empty() || output_shapes_.empty() || |
| output_dtypes_.size() == output_shapes_.size(), |
| errors::InvalidArgument("If both 'output_types' and 'output_shapes' " |
| "are set, they must have the same length.")); |
| } |
| |
| void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) { |
| const Tensor& string_handle_t = ctx->input(0); |
| OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()), |
| errors::InvalidArgument("string_handle must be a scalar")); |
| |
| ResourceHandle resource_handle; |
| OP_REQUIRES( |
| ctx, resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()), |
| errors::InvalidArgument( |
| "Could not parse string_handle as a valid ResourceHandle")); |
| |
| OP_REQUIRES( |
| ctx, resource_handle.device() == ctx->device()->attributes().name(), |
| errors::InvalidArgument("Attempted create an iterator on device \"", |
| ctx->device()->attributes().name(), |
| "\" from handle defined on device \"", |
| resource_handle.device(), "\"")); |
| |
| // Validate that the handle corresponds to a real resource, and |
| // that it is an IteratorResource. |
| IteratorResource* iterator_resource; |
| OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource)); |
| core::ScopedUnref unref_iterator(iterator_resource); |
| if (!output_dtypes_.empty()) { |
| OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_, |
| iterator_resource->output_dtypes())); |
| } |
| if (!output_shapes_.empty()) { |
| OP_REQUIRES_OK(ctx, |
| VerifyShapesCompatible(output_shapes_, |
| iterator_resource->output_shapes())); |
| } |
| |
| Tensor* resource_handle_t; |
| OP_REQUIRES_OK(ctx, |
| ctx->allocate_output(0, TensorShape({}), &resource_handle_t)); |
| resource_handle_t->scalar<ResourceHandle>()() = resource_handle; |
| } |
| |
| namespace { |
| |
| class SerializeIteratorOp : public OpKernel { |
| public: |
| explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
| |
| void Compute(OpKernelContext* ctx) override { |
| const Tensor& resource_handle_t = ctx->input(0); |
| OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), |
| errors::InvalidArgument("resource_handle must be a scalar")); |
| |
| // Validate that the handle corresponds to a real resource, and |
| // that it is an IteratorResource. |
| IteratorResource* iterator_resource; |
| OP_REQUIRES_OK( |
| ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); |
| core::ScopedUnref unref_iterator(iterator_resource); |
| Tensor* variant_t; |
| OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t)); |
| IteratorStateVariant v; |
| OP_REQUIRES_OK(ctx, v.InitializeFromIterator(ctx, iterator_resource)); |
| variant_t->scalar<Variant>()() = v; |
| } |
| }; |
| |
| class DeserializeIteratorOp : public OpKernel { |
| public: |
| explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
| |
| void Compute(OpKernelContext* ctx) override { |
| // Validate that the handle corresponds to a real resource, and |
| // that it is an IteratorResource. |
| IteratorResource* iterator_resource; |
| OP_REQUIRES_OK( |
| ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); |
| core::ScopedUnref unref_iterator(iterator_resource); |
| Variant variant = ctx->input(1).scalar<Variant>()(); |
| auto* wrapper = variant.get<IteratorStateVariant>(); |
| OP_REQUIRES(ctx, wrapper != nullptr, |
| errors::InvalidArgument( |
| "DeserializeIteratorOp: Unable to parse variant tensor.")); |
| OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get())); |
| } |
| }; |
| |
| REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp); |
| REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2), |
| IteratorHandleOp); |
| REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1), |
| IteratorHandleOp); |
| REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2), |
| MakeIteratorOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"), |
| MakeIteratorOp); |
| REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_CPU).Priority(2), |
| DeleteIteratorOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("DeleteIterator").Device(DEVICE_GPU).HostMemory("deleter").Priority(1), |
| DeleteIteratorOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2), |
| AnonymousIteratorHandleOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1), |
| AnonymousIteratorHandleOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("AnonymousIteratorV2").Device(DEVICE_CPU).Priority(2), |
| AnonymousIteratorHandleOp); |
| REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2") |
| .Device(DEVICE_GPU) |
| .HostMemory("deleter") |
| .Priority(1), |
| AnonymousIteratorHandleOp); |
| REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU), |
| ToSingleElementOp); |
| REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU), |
| ReduceDatasetOp); |
| REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU), |
| OneShotIteratorOp); |
| REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2), |
| IteratorGetNextOp); |
| REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1), |
| IteratorGetNextOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2), |
| IteratorGetNextSyncOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1), |
| IteratorGetNextSyncOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2), |
| IteratorGetNextAsOptionalOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1), |
| IteratorGetNextAsOptionalOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2), |
| IteratorToStringHandleOp); |
| REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") |
| .Device(DEVICE_GPU) |
| .HostMemory("string_handle") |
| .Priority(1), |
| IteratorToStringHandleOp); |
| REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU), |
| IteratorFromStringHandleOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2), |
| IteratorFromStringHandleOp); |
| REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") |
| .Device(DEVICE_GPU) |
| .HostMemory("string_handle") |
| .Priority(1), |
| IteratorFromStringHandleOp); |
| REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU), |
| SerializeIteratorOp); |
| REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU), |
| DeserializeIteratorOp); |
| |
| REGISTER_INPUT_COLOCATION_EXEMPTION("ReduceDataset"); |
| |
| } // namespace |
| |
| } // namespace data |
| } // namespace tensorflow |