| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ |
| #define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ |
| |
| #include <deque> |
| #include <memory> |
| #include <unordered_map> |
| |
| #include "absl/memory/memory.h" |
| #include "tensorflow/core/framework/attr_value.pb.h" |
| #include "tensorflow/core/framework/attr_value_util.h" |
| #include "tensorflow/core/framework/cancellation.h" |
| #include "tensorflow/core/framework/dataset_stateful_op_whitelist.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/graph.pb.h" |
| #include "tensorflow/core/framework/model.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/register_types.h" |
| #include "tensorflow/core/framework/thread_factory.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/framework/variant_encode_decode.h" |
| #include "tensorflow/core/framework/variant_tensor_data.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/threadpool.h" |
| #include "tensorflow/core/lib/core/threadpool_interface.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/platform/cpu_info.h" |
| #include "tensorflow/core/platform/env.h" |
| #include "tensorflow/core/platform/tracing.h" |
| |
// Datasets are polymorphic over their element type and are expected to handle
// every primitive TensorFlow type. `TF_CALL_DATASET_TYPES(m)` expands `m(T)`
// once per primitive type `T` (the regular types followed by the quantized
// ones), which is handy for e.g. generating the cases of a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) \
  TF_CALL_ALL_TYPES(m)           \
  TF_CALL_QUANTIZED_TYPES(m)
| |
| namespace tensorflow { |
| |
| // Forward declarations to avoid introducing a dependency on headers in |
| // "tensorflow/core/graph/...". |
| class GraphDefBuilder; |
| class Node; |
| |
| namespace data { |
| |
// Sentinel values reported in place of an actual element count.
constexpr int kInfiniteCardinality{-1};  // The dataset has infinitely many elements.
constexpr int kUnknownCardinality{-2};   // The element count is not known.
| |
| class DatasetBase; |
| class SerializationContext; |
| |
| // Interface for reading values from a key-value store. |
| // Used for restoring iterator state. |
| class IteratorStateReader { |
| public: |
| virtual Status ReadScalar(StringPiece key, int64* val) = 0; |
| #ifdef USE_TSTRING |
| // TODO(dero): Temp guard to prevent duplicate declaration during tstring |
| // migration. |
| virtual Status ReadScalar(StringPiece key, string* val) = 0; |
| #endif |
| virtual Status ReadScalar(StringPiece key, tstring* val) = 0; |
| virtual Status ReadTensor(StringPiece key, Tensor* val) = 0; |
| virtual bool Contains(StringPiece key) = 0; |
| |
| virtual ~IteratorStateReader() {} |
| }; |
| |
| // Interface for writing values to a key-value store. |
| // Used for saving iterator state. |
| class IteratorStateWriter { |
| public: |
| virtual Status WriteScalar(StringPiece key, const int64 val) = 0; |
| #ifdef USE_TSTRING |
| // TODO(dero): Temp guard to prevent duplicate declaration during tstring |
| // migration. |
| virtual Status WriteScalar(StringPiece key, const string& val) = 0; |
| #endif |
| virtual Status WriteScalar(StringPiece key, const tstring& val) = 0; |
| virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0; |
| |
| virtual ~IteratorStateWriter() {} |
| }; |
| |
| // Wrapper around GraphDefBuilder. Used to serialize Dataset graph. |
| class GraphDefBuilderWrapper { |
| public: |
| explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {} |
| |
| // Adds a Const node with scalar value to the Graph. |
| // `*output` contains a pointer to the output `Node`. It is guaranteed to be |
| // non-null if the method returns with an OK status. |
| // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. |
| template <typename T> |
| Status AddScalar(const T& val, Node** output) { |
| Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({})); |
| val_t.scalar<T>()() = val; |
| AddTensorInternal(val_t, output); |
| if (*output == nullptr) { |
| return errors::Internal("AddScalar: Failed to build Const op."); |
| } |
| return Status::OK(); |
| } |
| |
| // Adds a Const node with vector value to the Graph. |
| // `*output` contains a pointer to the output `Node`. It is guaranteed to be |
| // non-null if the method returns with an OK status. |
| // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. |
| // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice? |
| template <typename T> |
| Status AddVector(const std::vector<T>& val, Node** output) { |
| Tensor val_t = Tensor(DataTypeToEnum<T>::v(), |
| TensorShape({static_cast<int64>(val.size())})); |
| for (size_t i = 0; i < val.size(); i++) { |
| val_t.flat<T>()(i) = val[i]; |
| } |
| AddTensorInternal(val_t, output); |
| if (*output == nullptr) { |
| return errors::Internal("AddVector: Failed to build Const op."); |
| } |
| return Status::OK(); |
| } |
| |
| #ifdef USE_TSTRING |
| // TODO(dero): Temp guard to prevent duplicate declaration during tstring |
| // migration. |
| Status AddVector(const std::vector<string>& val, Node** output) { |
| Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(), |
| TensorShape({static_cast<int64>(val.size())})); |
| for (size_t i = 0; i < val.size(); i++) { |
| val_t.flat<tstring>()(i) = val[i]; |
| } |
| AddTensorInternal(val_t, output); |
| if (*output == nullptr) { |
| return errors::Internal("AddVector: Failed to build Const op."); |
| } |
| return Status::OK(); |
| } |
| #endif // USE_TSTRING |
| |
| // Adds a `Const` node for the given tensor value to the graph. |
| // |
| // `*output` contains a pointer to the output `Node`. It is guaranteed to be |
| // non-null if the method returns with an OK status. The returned `Node` |
| // pointer is owned by the backing graph of `GraphDefBuilder`. |
| Status AddTensor(const Tensor& val, Node** output) { |
| AddTensorInternal(val, output); |
| if (*output == nullptr) { |
| return errors::Internal("AddTensor: Failed to build Const op."); |
| } |
| return Status::OK(); |
| } |
| |
| // Adds a `Placeholder` node for the given tensor value to the graph. |
| // |
| // `*output` contains a pointer to the output `Node`. It is guaranteed to be |
| // non-null if the method returns with an OK status. The returned `Node` |
| // pointer is owned by the backing graph of `GraphDefBuilder`. |
| Status AddPlaceholder(const Tensor& val, Node** output) { |
| AddPlaceholderInternal(val, output); |
| if (*output == nullptr) { |
| return errors::Internal( |
| "AddPlaceholder: Failed to build Placeholder op."); |
| } |
| return Status::OK(); |
| } |
| |
| Status AddDataset(const DatasetBase* dataset, |
| const std::vector<Node*>& inputs, Node** output) { |
| return AddDataset(dataset, inputs, {}, output); |
| } |
| |
| // Adds a node corresponding to the `DatasetType` to the Graph. |
| // Return value of `DatasetType::type_string()` is used as the op type for the |
| // node. |
| // Values for the output_types and output_shapes node attributes are also |
| // written if those attributes are defined in the OpDef. |
| // `*output` contains a pointer to the output `Node`. It is guaranteed to be |
| // non-null if the method returns with an OK status. |
| // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. |
| Status AddDataset(const DatasetBase* dataset, |
| const std::vector<Node*>& inputs, |
| const std::vector<std::pair<StringPiece, AttrValue>>& attrs, |
| Node** output) { |
| std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size()); |
| for (size_t i = 0; i < inputs.size(); i++) { |
| enumerated_inputs[i] = std::make_pair(i, inputs[i]); |
| } |
| return AddDataset(dataset, enumerated_inputs, {}, attrs, output); |
| } |
| |
| Status AddDataset( |
| const DatasetBase* dataset, |
| const std::vector<std::pair<size_t, Node*>>& inputs, |
| const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs, |
| const std::vector<std::pair<StringPiece, AttrValue>>& attrs, |
| Node** output); |
| |
| // Adds a user-defined function with name `function_name` to the graph and |
| // recursively adds all functions it references. If a function with a matching |
| // name has already been added, returns with OK status. If a user-defined with |
| // name `function_name` is not found in the context's function library, |
| // returns an InvalidArgumentError. If the function with name `function_name` |
| // or any of its dependent functions are stateful, and the context does not |
| // explicitly permit stateful functions, returns an InvalidArgument error. |
| Status AddFunction(SerializationContext* ctx, const string& function_name, |
| const FunctionLibraryDefinition& lib_def); |
| |
| template <typename T> |
| void BuildAttrValue(const T& value, AttrValue* attr) { |
| SetAttrValue(value, attr); |
| } |
| |
| private: |
| void AddPlaceholderInternal(const Tensor& val, Node** output); |
| void AddTensorInternal(const Tensor& val, Node** output); |
| bool HasAttr(const string& op_type_name, const string& attr_name) const; |
| |
| bool HasAttr(const OpDef* op_def, const string& attr_name) const { |
| for (auto attr : op_def->attr()) { |
| if (attr.name() == attr_name) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| Status AddAttrFunctions(SerializationContext* ctx, |
| const AttrValue& attr_value, |
| const FunctionLibraryDefinition& lib_def) { |
| if (attr_value.has_func()) { |
| TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def)); |
| } else if (attr_value.has_list()) { |
| for (const NameAttrList& name_attr_list : attr_value.list().func()) { |
| TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def)); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| GraphDefBuilder* b_; |
| }; |
| |
| class StatsAggregator; |
| class FunctionHandleCache; |
| |
// Helper for executing a closure such that a `tensorflow::data` symbol is
// always present on the call stack while it runs.
class Runner {
 public:
  virtual ~Runner() = default;

  // Executes `f`.
  virtual void Run(const std::function<void()>& f) = 0;

  // Accessor for the global singleton Runner.
  static Runner* get();
};
| |
| // A cut-down version of `OpKernelContext` for running computations in |
| // iterators. Note that we cannot simply use `OpKernelContext` here because we |
| // might run computation in an iterator whose lifetime is not nested within the |
| // lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching). |
| // |
| // TODO(mrry): We're making some daring assumptions about the lifetime of the |
| // runner passed in here. A runner will be deleted when the original step ends, |
| // but all existing runners only close over session-lifetime (or longer-lived) |
| // state, so we can make a copy of the function. There's nothing in the |
| // definition of the API from which we took the runner to guarantee that what we |
| // are doing is safe. We should formalize the properties here. |
| class IteratorContext { |
| public: |
| struct Params { |
| explicit Params(IteratorContext* ctx) |
| : allocator_getter(ctx->allocator_getter()), |
| cancellation_manager(ctx->cancellation_manager()), |
| env(ctx->env()), |
| flr(ctx->flr()), |
| function_handle_cache(ctx->function_handle_cache()), |
| resource_mgr(ctx->resource_mgr()), |
| model(ctx->model()), |
| runner(*(ctx->runner())), |
| runner_threadpool_size(ctx->runner_threadpool_size()), |
| stats_aggregator(ctx->stats_aggregator()), |
| thread_factory(ctx->thread_factory()), |
| thread_pool(ctx->thread_pool()) {} |
| |
| explicit Params(OpKernelContext* ctx) |
| : env(ctx->env()), flr(ctx->function_library()) { |
| // NOTE: need reinterpret_cast because function.h forward-declares Device. |
| DeviceBase* device = |
| reinterpret_cast<DeviceBase*>(ctx->function_library()->device()); |
| allocator_getter = [device](AllocatorAttributes attrs) { |
| return device->GetAllocator(attrs); |
| }; |
| thread::ThreadPool* thread_pool = |
| ctx->device()->tensorflow_device_thread_pool(); |
| if (thread_pool) { |
| runner_threadpool_size = thread_pool->NumThreads(); |
| } else { |
| runner_threadpool_size = port::MaxParallelism(); |
| } |
| |
| // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so |
| // that a symbol in the tensorflow::data namespace is always on the stack |
| // when executing a function inside a Dataset. |
| runner = std::bind( |
| []( |
| // Note: `runner` is a const reference to avoid copying it. |
| const std::function<void(std::function<void()>)>& ctx_runner, |
| std::function<void()> fn) { |
| std::function<void()> wrapped_fn = std::bind( |
| [](const std::function<void()>& fn) { Runner::get()->Run(fn); }, |
| std::move(fn)); |
| ctx_runner(std::move(wrapped_fn)); |
| }, |
| *ctx->runner(), std::placeholders::_1); |
| } |
| |
| // The Allocator to be used to allocate the output of an iterator. |
| std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr; |
| |
| // The CancellationManager to be used to cancel execution of ops. |
| CancellationManager* cancellation_manager; |
| |
| // Interface to operating system functionality. |
| Env* env = nullptr; |
| |
| // The FunctionLibraryRuntime object to be used to make function calls. |
| FunctionLibraryRuntime* flr = nullptr; |
| |
| // A FunctionHandleCache that owns all the function handles. Not owned. |
| FunctionHandleCache* function_handle_cache = nullptr; |
| |
| // A resource manager for storing dataset-related state, e.g. random |
| // seeds or cached tensors. Not owned. |
| ResourceMgr* resource_mgr = nullptr; |
| |
| // If non-null, identifies the object used for performance modeling. |
| std::shared_ptr<model::Model> model = nullptr; |
| |
| // Function call support. |
| std::function<void(std::function<void()>)> runner = nullptr; |
| |
| // Number of threads used for executing user-defined functions. |
| int32 runner_threadpool_size = 0; |
| |
| // The `StatsAggregator` object to record statistics about the iterator. |
| std::shared_ptr<StatsAggregator> stats_aggregator = nullptr; |
| |
| // A factory for creating threads to perform blocking work. |
| std::shared_ptr<ThreadFactory> thread_factory = nullptr; |
| |
| // A shared thread pool to schedule computation into. |
| thread::ThreadPoolInterface* thread_pool = nullptr; |
| }; |
| |
| explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {} |
| |
| explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {} |
| |
| explicit IteratorContext(Params params) : params_(std::move(params)) {} |
| |
| Allocator* allocator(AllocatorAttributes attrs) { |
| return params_.allocator_getter(attrs); |
| } |
| |
| std::function<Allocator*(AllocatorAttributes)> allocator_getter() { |
| return params_.allocator_getter; |
| } |
| |
| CancellationManager* cancellation_manager() { |
| return params_.cancellation_manager; |
| } |
| |
| Env* env() const { return params_.env; } |
| |
| FunctionLibraryRuntime* flr() { return params_.flr; } |
| |
| FunctionHandleCache* function_handle_cache() { |
| return params_.function_handle_cache; |
| } |
| |
| ResourceMgr* resource_mgr() { return params_.resource_mgr; } |
| |
| const std::shared_ptr<model::Model>& model() { return params_.model; } |
| |
| std::function<void(std::function<void()>)>* runner() { |
| return ¶ms_.runner; |
| } |
| |
| int32 runner_threadpool_size() { return params_.runner_threadpool_size; } |
| |
| std::shared_ptr<StatsAggregator> stats_aggregator() { |
| return params_.stats_aggregator; |
| } |
| |
| const std::shared_ptr<ThreadFactory>& thread_factory() { |
| return params_.thread_factory; |
| } |
| |
| thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; } |
| |
| Params params() { return params_; } |
| |
| std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name, |
| int num_threads) { |
| if (params_.thread_pool) { |
| // Create a `ThreadPool` instance by wrapping `params_.thread_pool` (which |
| // is an instance of `thread::ThreadPoolInterface`). Notably, the |
| // ownership of `params_.thread_pool` is *not* transferred onto the newly |
| // created `ThreadPool` instance. |
| return absl::make_unique<thread::ThreadPool>(params_.thread_pool); |
| } else { |
| return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(), |
| name, num_threads, |
| /*low_latency_hint=*/false); |
| } |
| } |
| |
| std::unique_ptr<Thread> StartThread(const string& name, |
| std::function<void()> fn) { |
| if (params_.thread_factory) { |
| return params_.thread_factory->StartThread(name, std::move(fn)); |
| } else { |
| return absl::WrapUnique( |
| Env::Default()->StartThread({}, name, std::move(fn))); |
| } |
| } |
| |
| private: |
| Params params_; |
| }; |
| |
| // Aggregates runtime support needed for dataset and iterator serialization. |
| class SerializationContext { |
| public: |
| struct Params { |
| std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned. |
| |
| // Indicates whether serialization should check if the dataset depends on |
| // external state. If the check is enabled and external state is |
| // encountered, then the serialization will fail. |
| bool check_external_state = true; |
| |
| // Indicates whether an attempt to serialize a dataset that does not |
| // implement serialization should result in an error. If set to `false`, the |
| // serialized graph will replace the dataset with a placeholder returned in |
| // `input_list`. |
| bool fail_if_unimplemented = true; |
| |
| // Indicates whether (potentionally large) data tensors should be |
| // serialized, or replaced with a placeholder returned in `input_list`. The |
| // latter makes sense to do when performing data agnostic graph rewrites to |
| // reduce the memory usage. |
| bool serialize_data_tensors = true; |
| }; |
| |
| explicit SerializationContext(Params params) : params_(std::move(params)) {} |
| |
| std::vector<std::pair<string, Tensor>>* input_list() { |
| return params_.input_list; |
| } |
| |
| bool check_external_state() const { return params_.check_external_state; } |
| |
| bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; } |
| |
| bool serialize_data_tensors() const { return params_.serialize_data_tensors; } |
| |
| private: |
| Params params_; |
| |
| TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext); |
| }; |
| |
| // Represents the current position in a range of outputs, where the |
| // range of outputs is typically represented by an `DatasetBase`, |
| // defined below. |
| class IteratorBase { |
| public: |
| virtual ~IteratorBase() { |
| for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) { |
| (*rit)(); |
| } |
| } |
| |
| // Gets the next output from the range that this iterator is traversing. |
| // |
| // If at least one output remains in this iterator's range, that |
| // output will be stored in `*out_tensors` and `false` will be |
| // stored in `*end_of_sequence`. |
| // |
| // If no more outputs remain in this iterator's range, `true` will |
| // be stored in `*end_of_sequence`, and the content of |
| // `*out_tensors` will be undefined. |
| // |
| // Implementations should never return `OutOfRange` error. If at end of |
| // sequence, set `*end_of_sequence = true` and return `Status::OK()`. |
| // Internally raised `OutOfRange` errors that do not imply end of sequence |
| // should be converted to a different error type before being propagated to |
| // the caller. |
| // |
| // This method is thread-safe. |
| // |
| // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and |
| // potentially remove this method. |
| virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, |
| bool* end_of_sequence) = 0; |
| |
| Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors, |
| bool* end_of_sequence) { |
| return GetNext(&ctx, out_tensors, end_of_sequence); |
| } |
| |
| // Returns a vector of DataType values, representing the respective |
| // element types of each tuple component in the outputs of this |
| // iterator. |
| virtual const DataTypeVector& output_dtypes() const = 0; |
| |
| // Returns a vector of tensor shapes, representing the respective |
| // (and possibly partially defined) shapes of each tuple component |
| // in the outputs of this iterator. |
| virtual const std::vector<PartialTensorShape>& output_shapes() const = 0; |
| |
| // Returns a string that identifies the sequence of iterators leading up to |
| // this iterator. |
| virtual const string& prefix() const = 0; |
| |
| // Performs initialization that needs to happen outside of a constructor to |
| // properly propagate errors. |
| virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); } |
| |
| // Saves the state of this iterator. |
| virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) { |
| return SaveInternal(writer); |
| } |
| |
| // Restores the state of this iterator. |
| virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) { |
| return RestoreInternal(ctx, reader); |
| } |
| |
| protected: |
| // Returns a node that models this iterator. |
| virtual std::shared_ptr<model::Node> CreateNode( |
| IteratorContext* ctx, model::Node::Args args) const = 0; |
| |
| // This is needed so that sub-classes of IteratorBase can call |
| // `SaveInternal` on their input iterators. |
| Status SaveInput(IteratorStateWriter* writer, |
| const std::unique_ptr<IteratorBase>& input) { |
| return input->SaveInternal(writer); |
| } |
| |
| // This is needed so that sub-classes of IteratorBase can call |
| // `RestoreInternal` on their input iterators. |
| Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader, |
| const std::unique_ptr<IteratorBase>& input) { |
| return input->RestoreInternal(ctx, reader); |
| } |
| |
| // Saves the state of this iterator. |
| // |
| // This method is used to store the state of the iterator in a checkpoint. |
| // |
| // TODO(jsimsa): Make this method pure virtual once all `IteratorBase` |
| // implementations have an override. |
| virtual Status SaveInternal(IteratorStateWriter* writer) { |
| return errors::Unimplemented("SaveInternal"); |
| } |
| |
| // Restores the state of this iterator. |
| // |
| // This method is used to restore the state of the iterator from a checkpoint. |
| // |
| // TODO(jsimsa): Make this method pure virtual once all `IteratorBase` |
| // implementations have an override. |
| virtual Status RestoreInternal(IteratorContext* ctx, |
| IteratorStateReader* reader) { |
| return errors::Unimplemented("RestoreInternal"); |
| } |
| |
| // Returns the number of elements produced by this itertaor. |
| int64 num_elements() const { |
| if (node_) return node_->num_elements(); |
| return 0; |
| } |
| |
| private: |
| friend class DatasetBase; // for access to `AddCleanupFunction` |
| friend class DatasetBaseIterator; // for access to `node_` |
| |
| // Registers a cleanup function to be called upon object destruction. |
| // |
| // Registered functions are invoked in the reserve order of registration. |
| void AddCleanupFunction(std::function<void()>&& cleanup_fn) { |
| cleanup_fns_.push_back(std::move(cleanup_fn)); |
| } |
| |
| // Associates the given performance modeling `Node` with this iterator. |
| void SetNode(std::shared_ptr<model::Node> node) { node_ = node.get(); } |
| |
| std::vector<std::function<void()>> cleanup_fns_; |
| model::Node* node_ = nullptr; // Not owned. |
| }; |
| |
| // Represents runtime information needed to construct a dataset. |
| class DatasetContext { |
| public: |
| struct Params { |
| string type_string; // op type name of this dataset. |
| string node_name; // graph node name of this dataset op, uniquely |
| // identifying the dataset in the graph. |
| }; |
| |
| explicit DatasetContext(Params params) : params_(std::move(params)) {} |
| |
| explicit DatasetContext(OpKernelContext* ctx) { |
| params_.type_string = ctx->op_kernel().type_string(); |
| params_.node_name = ctx->op_kernel().name(); |
| } |
| |
| const string& type_string() const { return params_.type_string; } |
| const string& node_name() const { return params_.node_name; } |
| |
| private: |
| Params params_; |
| }; |
| |
| // Returns the number of bytes allocated for the given tensor. |
| int64 GetAllocatedBytes(const std::vector<Tensor>& element); |
| |
// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to StoreDatasetInVariantTensor().
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor. The consumer must either acquire its own reference to the
// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
// destroyed or mutated while the retrieved pointer is in use.
| Status GetDatasetFromVariantTensor(const Tensor& tensor, |
| DatasetBase** out_dataset); |
| |
| // Stores a `DatasetBase` object in `tensor`. |
| // |
| // The ownership of `dataset` is transferred to `tensor`. |
| Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor); |
| |
| // Represents a (potentially infinite) range of outputs, where each |
| // output is a tuple of tensors. |
| class DatasetBase : public core::RefCounted { |
| public: |
| // Key for storing the Dataset graph in the serialized format. |
| TF_EXPORT static const char kDatasetGraphKey[]; |
| |
| // Key for storing the output node of the Dataset graph in the serialized |
| // format. |
| TF_EXPORT static const char kDatasetGraphOutputNodeKey[]; |
| |
| explicit DatasetBase(DatasetContext&& ctx) |
| : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {} |
| |
| // Op type name of this dataset. |
| const string& type_string() const { return type_string_; } |
| |
| // Graph node name of this dataset op, uniquely identifying the dataset in |
| // the graph. |
| const string& node_name() const { return node_name_; } |
| |
| // Returns a new iterator for iterating over the range of elements in |
| // this dataset. |
| // |
| // This method may be called multiple times on the same instance, |
| // and the resulting iterators will have distinct state. Each |
| // iterator will traverse all elements in this dataset from the |
| // start. |
| // |
| // The prefix identifies the sequence of iterators leading up to the newly |
| // created iterator. |
| Status MakeIterator(IteratorContext* ctx, const string& output_prefix, |
| std::unique_ptr<IteratorBase>* iterator) const { |
| *iterator = MakeIteratorInternal(output_prefix); |
| if (const auto& model = ctx->model()) { |
| const string& prefix = (*iterator)->prefix(); |
| (*iterator)->SetNode(model->AddNode(MakeNodeFactory(ctx, iterator->get()), |
| prefix, output_prefix)); |
| (*iterator)->AddCleanupFunction( |
| [model, prefix]() { model->RemoveNode(prefix); }); |
| } |
| return (*iterator)->Initialize(ctx); |
| } |
| |
| Status MakeIterator(IteratorContext&& ctx, const string& output_prefix, |
| std::unique_ptr<IteratorBase>* iterator) const { |
| return MakeIterator(&ctx, output_prefix, iterator); |
| } |
| |
| // Returns a vector of DataType values, representing the respective |
| // element types of each tuple component in the outputs of this |
| // dataset. |
| virtual const DataTypeVector& output_dtypes() const = 0; |
| |
| // Returns a vector of tensor shapes, representing the respective |
| // (and possibly partially defined) shapes of each tuple component |
| // in the outputs of this dataset. |
| virtual const std::vector<PartialTensorShape>& output_shapes() const = 0; |
| |
| // Returns the number of bytes allocated for tensors of this dataset. |
| virtual int64 AllocatedBytes() const { return 0; } |
| |
| // Returns the cardinality of this dataset. |
| virtual int64 Cardinality() const { return kUnknownCardinality; } |
| |
| // A human-readable debug string for this dataset. |
| virtual string DebugString() const = 0; |
| |
| // If the dataset is stateful it will not be possible to save its graph or |
| // checkpoint the state of its iterators. |
| // |
| // TODO(jsimsa): Remove this method once all `DatasetBase` implementations are |
| // migrated over to `CheckExternalState`. |
| virtual bool IsStateful() const { return false; } |
| |
| // Indicates whether the dataset depends on any external state. If so, the |
| // method returns `errors::FailedPrecondition` with a message that identifies |
| // the external state. Otherwise, the method returns `Status::OK()`. |
| // |
| // TODO(jsimsa): Make this method pure virtual once all `DatasetBase` |
| // implementations have an override. |
| virtual Status CheckExternalState() const { |
| if (IsStateful()) { |
| return errors::FailedPrecondition("Dataset cannot be serialized."); |
| } |
| return Status::OK(); |
| } |
| |
| protected: |
| friend Status AsGraphDef( |
| OpKernelContext* ctx, const DatasetBase* dataset, |
| SerializationContext&& serialization_ctx, |
| GraphDef* graph_def); // For access to graph related members. |
| friend class CapturedFunction; |
| |
| class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { |
| public: |
| explicit DatasetGraphDefBuilder(GraphDefBuilder* b) |
| : GraphDefBuilderWrapper(b) {} |
| Status AddInputDataset(SerializationContext* ctx, |
| const DatasetBase* dataset, Node** output); |
| }; |
| |
| // Serializes the dataset into a `GraphDef`, which has two uses: |
| // |
| // 1) To perform static input pipeline optimizations, tf.data serializes the |
| // dataset graph, applies graph rewrites, and then deserializes the graph. |
| // If a subclass of `DatasetBase` does not implement this method, then it will |
| // be excluded from static optimizations (and so will any upstream datasets). |
| // |
| // 2) To save the dataset so that it can restore at a later point (possibly in |
| // different environment). If a subclass of `DatasetBase` does not implement |
| // this method, then this migration will not be possible. |
| virtual Status AsGraphDefInternal(SerializationContext* ctx, |
| DatasetGraphDefBuilder* b, |
| Node** node) const = 0; |
| |
| virtual std::unique_ptr<IteratorBase> MakeIteratorInternal( |
| const string& prefix) const = 0; |
| |
| private: |
| // Returns a factory for nodes that represent the given iterator. |
| static model::Node::Factory MakeNodeFactory(IteratorContext* ctx, |
| IteratorBase* iterator) { |
| return [ctx, iterator](model::Node::Args args) { |
| return iterator->CreateNode(ctx, std::move(args)); |
| }; |
| } |
| |
| const string type_string_; |
| const string node_name_; |
| }; |
| |
| // Represents an iterator that is associated with a particular dataset. |
| // |
| // The iterator holds a reference on its dataset for as long as it lives: the |
| // constructor calls `Ref()` and the destructor calls `Unref()`. |
| class DatasetBaseIterator : public IteratorBase { |
|  public: |
|   struct BaseParams { |
|     // Owns one reference on the shared dataset object. |
|     const DatasetBase* dataset; |
| |
|     // Identifies the sequence of iterators leading up to this iterator. |
|     const string prefix; |
|   }; |
| |
|   // Stores a copy of `params` and takes a reference on `params.dataset`. |
|   explicit DatasetBaseIterator(const BaseParams& params) : params_(params) { |
|     params_.dataset->Ref(); |
|   } |
| |
|   // Not copyable: a copy would share `params_.dataset` without taking its own |
|   // reference, yet its destructor would still call `Unref()`. |
|   DatasetBaseIterator(const DatasetBaseIterator&) = delete; |
|   DatasetBaseIterator& operator=(const DatasetBaseIterator&) = delete; |
| |
|   // Drops the reference taken by the constructor. |
|   ~DatasetBaseIterator() override { params_.dataset->Unref(); } |
| |
|   // Output dtypes, as reported by the dataset. |
|   const DataTypeVector& output_dtypes() const override { |
|     return params_.dataset->output_dtypes(); |
|   } |
| |
|   // Output shapes, as reported by the dataset. |
|   const std::vector<PartialTensorShape>& output_shapes() const override { |
|     return params_.dataset->output_shapes(); |
|   } |
| |
|   // The sequence of iterators leading up to this iterator. |
|   const string& prefix() const override { return params_.prefix; } |
| |
|   // Returns a name to be used for the TraceMe event. The default |
|   // implementation returns the iterator prefix. |
|   // |
|   // NOTE: TraceMe supports passing key-value pairs of "arguments" using the |
|   // following format: "name#arg_1=value_1,...,arg_n=value_n". |
|   virtual string BuildTraceMeName() { return params_.prefix; } |
| |
|   // Declared `final`; subclasses implement `GetNextInternal()` instead. |
|   Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, |
|                  bool* end_of_sequence) final; |
| |
|   // Returns the dataset's `CheckExternalState()` error if there is one; |
|   // otherwise delegates to `IteratorBase::Save()`. |
|   Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final { |
|     TF_RETURN_IF_ERROR(params_.dataset->CheckExternalState()); |
|     return IteratorBase::Save(ctx, writer); |
|   } |
| |
|  protected: |
|   // Internal implementation of GetNext that is wrapped in tracing logic. |
|   virtual Status GetNextInternal(IteratorContext* ctx, |
|                                  std::vector<Tensor>* out_tensors, |
|                                  bool* end_of_sequence) = 0; |
| |
|   // Returns `name` qualified by this iterator's prefix: "<prefix>:<name>". |
|   string full_name(const string& name) const { |
|     return strings::StrCat(params_.prefix, ":", name); |
|   } |
| |
|   // By default we model iterators using an unknown node, which acts as |
|   // pass-through with respect to performance modeling. |
|   std::shared_ptr<model::Node> CreateNode( |
|       IteratorContext* ctx, model::Node::Args args) const override { |
|     return model::MakeUnknownNode(std::move(args)); |
|   } |
| |
|   // When modeling is enabled, this method disables autotuning for the given |
|   // iterator (and the transitive closure of its inputs). Does nothing when |
|   // `iterator` has no model node. |
|   void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) { |
|     if (iterator->node_) { |
|       iterator->node_->set_autotune(false); |
|     } |
|   } |
| |
|   // When modeling is enabled, this method enables autotuning for the given |
|   // iterator (and the transitive closure of its inputs). Does nothing when |
|   // `iterator` has no model node. |
|   void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) { |
|     if (iterator->node_) { |
|       iterator->node_->set_autotune(true); |
|     } |
|   } |
| |
|   // When modeling is enabled, this method records the fact that this iterator |
|   // has dequeued an element from an internal buffer: the buffered-bytes count |
|   // of the node is decreased by the allocated size of `element`. |
|   void RecordBufferDequeue(IteratorContext* ctx, |
|                            const std::vector<Tensor>& element) { |
|     if (collect_resource_usage(ctx)) { |
|       node_->add_buffered_bytes(-GetAllocatedBytes(element)); |
|     } |
|   } |
| |
|   // When modeling is enabled, this method records the fact that this iterator |
|   // has enqueued an element in an internal buffer: the buffered-bytes count of |
|   // the node is increased by the allocated size of `element`. |
|   void RecordBufferEnqueue(IteratorContext* ctx, |
|                            const std::vector<Tensor>& element) { |
|     if (collect_resource_usage(ctx)) { |
|       node_->add_buffered_bytes(GetAllocatedBytes(element)); |
|     } |
|   } |
| |
|   // When modeling is enabled, this method records the fact that this iterator |
|   // has produced an element. Only requires the iterator to have a node. |
|   void RecordElement(IteratorContext* ctx) { |
|     if (node_) { |
|       node_->record_element(); |
|     } |
|   } |
| |
|   // When modeling is enabled, this method records the fact that a thread of |
|   // this iterator has started work. If `stop_output` is true and the node has |
|   // an output, a stop is first recorded on that output with the same |
|   // timestamp. |
|   void RecordStart(IteratorContext* ctx, bool stop_output = false) { |
|     if (collect_resource_usage(ctx)) { |
|       int64 now_nanos = Env::Default()->NowNanos(); |
|       if (stop_output && node_->output()) { |
|         node_->output()->record_stop(now_nanos); |
|       } |
|       node_->record_start(now_nanos); |
|     } |
|   } |
| |
|   // When modeling is enabled, this method records the fact that a thread of |
|   // this iterator has stopped work. If `start_output` is true and the node has |
|   // an output, a start is then recorded on that output with the same |
|   // timestamp. |
|   void RecordStop(IteratorContext* ctx, bool start_output = false) { |
|     if (collect_resource_usage(ctx)) { |
|       int64 now_nanos = Env::Default()->NowNanos(); |
|       node_->record_stop(now_nanos); |
|       if (start_output && node_->output()) { |
|         node_->output()->record_start(now_nanos); |
|       } |
|     } |
|   } |
| |
|  private: |
|   // True only when `ctx` has a model, that model collects resource usage, and |
|   // this iterator has a node. |
|   inline bool collect_resource_usage(IteratorContext* ctx) { |
|     auto model = ctx->model(); |
|     return model && model->collect_resource_usage() && node_; |
|   } |
| |
|   BaseParams params_; |
| }; |
| |
| // Represents an iterator that is associated with a particular dataset |
| // with a particular type. |
| // |
| // In addition to what `DatasetBaseIterator` stores, the iterator keeps the |
| // dataset pointer with its concrete type so that `dataset()` needs no cast. |
| template <class DatasetType> |
| class DatasetIterator : public DatasetBaseIterator { |
|  public: |
|   struct Params { |
|     // Borrowed pointer to the dataset. |
|     const DatasetType* dataset; |
| |
|     // Identifies the sequence of iterators leading up to this iterator. |
|     const string prefix; |
|   }; |
| |
|   // Forwards the dataset and prefix to `DatasetBaseIterator` and remembers the |
|   // typed dataset pointer. |
|   explicit DatasetIterator(const Params& params) |
|       : DatasetBaseIterator({params.dataset, params.prefix}), |
|         typed_dataset_(params.dataset) {} |
| |
|   // The dataset from which this iterator was created. |
|   const DatasetType* dataset() const { return typed_dataset_; } |
| |
|  protected: |
|   // Still pure virtual; `override` makes the compiler verify that the |
|   // signature matches `DatasetBaseIterator::GetNextInternal()`. |
|   Status GetNextInternal(IteratorContext* ctx, |
|                          std::vector<Tensor>* out_tensors, |
|                          bool* end_of_sequence) override = 0; |
| |
|  private: |
|   const DatasetType* const typed_dataset_;  // Not owned. |
| }; |
| |
| // Reads the kernel input named `argument_name` and stores its single value in |
| // `*output`. |
| // |
| // Returns the error from the input lookup if it fails, and an |
| // `InvalidArgument` error if the input is not a scalar. `*output` is written |
| // only on success. |
| template <typename T> |
| Status ParseScalarArgument(OpKernelContext* ctx, |
|                            const StringPiece& argument_name, T* output) { |
|   const Tensor* scalar_input = nullptr; |
|   TF_RETURN_IF_ERROR(ctx->input(argument_name, &scalar_input)); |
|   const bool is_scalar = TensorShapeUtils::IsScalar(scalar_input->shape()); |
|   if (is_scalar) { |
|     *output = scalar_input->scalar<T>()(); |
|     return Status::OK(); |
|   } |
|   return errors::InvalidArgument(argument_name, " must be a scalar"); |
| } |
| |
| // Reads the kernel input named `argument_name` and appends its elements, in |
| // order, to `*output`. |
| // |
| // Returns the error from the input lookup if it fails, and an |
| // `InvalidArgument` error if the input is not a vector. Elements already in |
| // `*output` are kept. |
| template <typename T> |
| Status ParseVectorArgument(OpKernelContext* ctx, |
|                            const StringPiece& argument_name, |
|                            std::vector<T>* output) { |
|   const Tensor* argument_t; |
|   TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); |
|   if (!TensorShapeUtils::IsVector(argument_t->shape())) { |
|     return errors::InvalidArgument(argument_name, " must be a vector"); |
|   } |
|   // Build the vector view once instead of on every loop iteration. |
|   const auto values = argument_t->vec<T>(); |
|   // Keep the element count 64-bit wide so that it is not narrowed to `int`. |
|   const int64 num_values = values.size(); |
|   // The loop appends, so make room for the existing elements as well. |
|   output->reserve(output->size() + static_cast<size_t>(num_values)); |
|   for (int64 i = 0; i < num_values; ++i) { |
|     output->push_back(values(i)); |
|   } |
|   return Status::OK(); |
| } |
| |
| // Encapsulates the work required to plug a DatasetBase into the core TensorFlow |
| // graph execution engine. |
| class DatasetOpKernel : public OpKernel { |
|  public: |
|   // Single-argument constructor, hence `explicit`. |
|   explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
| |
|   // Declared `final`; subclasses customize the kernel through `MakeDataset()`. |
|   void Compute(OpKernelContext* ctx) final; |
| |
|  protected: |
|   // Subclasses should implement this method. It will be called during Compute |
|   // execution. |
|   virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0; |
| }; |
| |
| // Encapsulates the work required to plug unary Datasets into the core |
| // TensorFlow graph execution engine. |
| class UnaryDatasetOpKernel : public DatasetOpKernel { |
|  public: |
|   // Single-argument constructor, hence `explicit`. |
|   explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx) |
|       : DatasetOpKernel(ctx) {} |
| |
|  protected: |
|   // Declared `final`; subclasses implement the overload below, which also |
|   // receives the input dataset. |
|   // NOTE(review): presumably this overload obtains `input` from `ctx` and then |
|   // calls the overload below; confirm in the implementation file. |
|   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; |
| |
|   // Subclasses implement this method to create the dataset from `input`. |
|   virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, |
|                            DatasetBase** output) = 0; |
| }; |
| |
| // Encapsulates the work required to plug binary Datasets into the core |
| // TensorFlow graph execution engine. |
| class BinaryDatasetOpKernel : public DatasetOpKernel { |
|  public: |
|   // Single-argument constructor, hence `explicit`. |
|   explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx) |
|       : DatasetOpKernel(ctx) {} |
| |
|  protected: |
|   // Declared `final`; subclasses implement the overload below, which also |
|   // receives the two input datasets. |
|   // NOTE(review): presumably this overload obtains both inputs from `ctx` and |
|   // then calls the overload below; confirm in the implementation file. |
|   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; |
| |
|   // Subclasses implement this method to create the dataset from `input` and |
|   // `another_input`. |
|   virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, |
|                            DatasetBase* another_input, |
|                            DatasetBase** output) = 0; |
| }; |
| |
| // A simple background worker that executes closures asynchronously and without |
| // blocking. |
| // |
| // A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel` |
| // to avoid blocking an executor thread that may be required by the blocking |
| // work. |
| // |
| // NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this |
| // purpose because its current implementation (in Eigen) uses a finite-length |
| // queue and will block the caller when full. This can lead to deadlock under |
| // heavy load. Since the number of concurrent work items in each user of a |
| // `BackgroundWorker` is at most one per op invocation, the dynamic allocation |
| // overhead is tolerable. |
| class BackgroundWorker { |
| public: |
| // NOTE(review): `env` is presumably used to start the worker thread and `name` |
| // to label it; confirm in the implementation file. |
| BackgroundWorker(Env* env, const string& name); |
| |
| ~BackgroundWorker(); |
| |
| // Hands `work_item` to the worker for asynchronous execution. |
| void Schedule(std::function<void()> work_item); |
| |
| private: |
| // NOTE(review): presumably the function run by `thread_`; confirm in the |
| // implementation file. |
| void WorkerLoop(); |
| |
| // The thread owned by this worker. |
| std::unique_ptr<Thread> thread_; |
| // Guards `cancelled_` and `work_queue_`. |
| mutex mu_; |
| condition_variable cond_var_; |
| // Initially false. |
| bool cancelled_ GUARDED_BY(mu_) = false; |
| // Closures held by the worker. |
| std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_); |
| }; |
| |
| } // namespace data |
| |
| // The using-declarations below make a selection of `tensorflow::data` names |
| // available directly in the `tensorflow` namespace. |
| // |
| // TODO(b/114112161): Remove these aliases when all users have moved over to the |
| // `tensorflow::data` namespace. |
| using data::DatasetBase; |
| using data::DatasetContext; |
| using data::DatasetIterator; |
| using data::DatasetOpKernel; |
| using data::IteratorBase; |
| using data::IteratorContext; |
| using data::IteratorStateReader; |
| using data::IteratorStateWriter; |
| using data::SerializationContext; |
| using data::UnaryDatasetOpKernel; |
| |
| } // namespace tensorflow |
| |
| #endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ |