blob: beb9d0c5e82bfc0451e8ecab70b43d349c5c8dff [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
namespace data {
template <typename T>
class AnonymousResourceOp : public OpKernel {
public:
static std::atomic<int64> resource_id_counter_;
explicit AnonymousResourceOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
FunctionLibraryRuntime* lib;
std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
OP_REQUIRES_OK(
ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
T* resource;
OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
std::move(pflr), lib, &resource));
string container_name = name();
string unique_name =
strings::StrCat(container_name, resource_id_counter_.fetch_add(1));
ResourceMgr* mgr = ctx->resource_manager();
OP_REQUIRES_OK(ctx, mgr->Create<T>(container_name, unique_name, resource));
Tensor* handle_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
ResourceHandle handle = MakeResourceHandle(ctx, container_name, unique_name,
MakeTypeIndex<T>());
handle_t->scalar<ResourceHandle>()() = handle;
if (create_deleter_) {
Tensor* deleter_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t));
deleter_t->scalar<Variant>()() =
ResourceDeleter(handle, ctx->resource_manager());
}
}
protected:
virtual string name() = 0;
virtual Status CreateResource(
OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib, T** resource) = 0;
bool create_deleter_ = true;
};
template <typename T>
std::atomic<int64> AnonymousResourceOp<T>::resource_id_counter_;
// Returns a GraphDef representation of the given dataset.
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
SerializationContext&& serialization_ctx,
GraphDef* graph_def);
// Creates a connection between "child" and "parent" cancellation managers so
// that parent cancellations are propagated to the child, returning a function
// that can be used to remove the connection.
Status ConnectCancellationManagers(CancellationManager* parent,
CancellationManager* child,
std::function<void()>* deregister_fn);
// Rewrites the input dataset using the given config.
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
std::function<RewriterConfig(void)> config_factory,
bool optimize_function_library,
DatasetBase** rewritten_input);
// Returns Status::OK() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.
Status VerifyTypesMatch(const DataTypeVector& expected,
const DataTypeVector& received);
// Returns Status::OK() if `expected` and `received` shapes are compatible,
// errors::InvalidArgument otherwise.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
const std::vector<PartialTensorShape>& received);
// Returns a stable hash of the portion of the graph `g` rooted at
// `node`, by creating a Merkle tree-like structure.
//
// Specifically, this function recursively walks the graph from `node` by
// following its inputs.
//
// The hash is computed by hashing its op name, device, attributes, and hashes
// of its inputs (if applicable).
//
// There is currently no guarantee that the hash of a subgraph will stay the
// same between TensorFlow builds.
uint64 HashSubgraph(const GraphDef& g, const NodeDef* node);
// Returns a stable hash of the function `f`.
//
// This function computes the hash by hashing the metadata of the
// function (disregarding the auto-generated names and descriptions) and also
// hashing the subgraph rooted at each of the output nodes.
//
// There is currently no guarantee that the hash of a function will stay the
// same between TensorFlow builds.
uint64 HashSubgraphFunction(const FunctionDefLibrary& library,
const FunctionDef* f);
// Helper class for reading data from a VariantTensorData object.
class VariantTensorDataReader : public IteratorStateReader {
public:
explicit VariantTensorDataReader(const VariantTensorData* data);
// Returns OK iff the initialization was successful.
Status ReadScalar(StringPiece key, int64* val) override;
Status ReadScalar(StringPiece key, string* val) override;
Status ReadTensor(StringPiece key, Tensor* val) override;
bool Contains(StringPiece key) override;
private:
template <typename T>
Status ReadScalarInternal(StringPiece key, T* val);
Status ReadTensorInternal(StringPiece key, Tensor* val);
std::map<string, size_t> map_;
const VariantTensorData* data_; // Not owned.
};
// Helper class for writing data to a VariantTensorData object.
class VariantTensorDataWriter : public IteratorStateWriter {
public:
// Does not take ownership of data.
explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {}
Status WriteScalar(StringPiece key, const int64 val) override;
Status WriteScalar(StringPiece key, const string& val) override;
Status WriteTensor(StringPiece key, const Tensor& val) override;
// Writes the metadata to `data_`.
Status Flush();
private:
template <typename T>
Status WriteScalarInternal(StringPiece key, const T& val);
Status WriteTensorInternal(StringPiece key, const Tensor& val);
VariantTensorData* data_;
std::vector<string> keys_;
};
// Adds the functions in `to_add` to `base`. If a function with a matching
// signature already exists in `base`, replaces it with the function from
// `to_add`.
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
const FunctionLibraryDefinition& to_add);
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
const FunctionDefLibrary& to_add);
// Creates a runner that runs functions with limited parallelism.
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
std::function<void(std::function<void()>)> runner, int max_parallelism);
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_