blob: e46f5249a9702ce3762bd8b2a6676175df03c624 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace data {
namespace {
// Separator written in front of every key when VariantTensorDataWriter::Flush()
// builds the metadata string, and used by the VariantTensorDataReader
// constructor to split that string back into keys.
constexpr char kDelimiter[] = "@@";
// For every output declared in the signature of `function_def`, appends an
// Identity node (named from the prefix "FakeSink<i>") that consumes the tensor
// currently returned for that output, and re-points the return value at the
// new node's ":output:0".
void AddFakeSinks(FunctionDef* function_def) {
  const auto& declared_outputs = function_def->signature().output_arg();
  for (int index = 0; index < declared_outputs.size(); ++index) {
    const auto& output = declared_outputs.Get(index);
    NodeDef* sink = function_def->add_node_def();
    tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
        strings::StrCat("FakeSink", index), function_def, sink);
    sink->set_op("Identity");
    // The sink reads whatever tensor was returned for this output so far.
    sink->add_input(function_def->ret().at(output.name()));
    (*sink->mutable_attr())["T"].set_type(output.type());
    (*function_def->mutable_ret())[output.name()] =
        strings::StrCat(sink->name(), ":output:0");
  }
}
// Re-points every return value of `function_def` that is produced by a
// single-input Identity node at that node's input tensor instead. The Identity
// nodes themselves are left in the function body.
void RemoveFakeSinks(FunctionDef* function_def) {
  // Identity node name -> the tensor string it forwards.
  std::map<string, string> forwarded_input;
  for (const auto& node : function_def->node_def()) {
    if (node.op() != "Identity" || node.input_size() != 1) {
      continue;
    }
    forwarded_input[node.name()] = node.input(0);
  }
  for (const auto& output_arg : function_def->signature().output_arg()) {
    const string& ret_tensor = function_def->ret().at(output_arg.name());
    // The producing node's name is everything before the first ':'.
    const string producer = ret_tensor.substr(0, ret_tensor.find(':'));
    const auto match = forwarded_input.find(producer);
    if (match == forwarded_input.end()) {
      continue;
    }
    (*function_def->mutable_ret())[output_arg.name()] = match->second;
  }
}
// Runs Grappler's meta optimizer over `graph_def` using the rewriter
// configuration produced by `config_factory`.
//
// On entry `*output_node` names the tensor to fetch. An Identity "Sink" node
// consuming it is appended to the graph and `*output_node` is updated to that
// node's name. On success `graph_def` holds the optimized graph.
//
// Returns an Internal error if no Grappler item can be built from the graph,
// or the error reported by the meta optimizer.
Status ApplyRewrites(OpKernelContext* ctx,
                     const std::function<RewriterConfig(void)> config_factory,
                     bool optimize_function_library, GraphDef* graph_def,
                     string* output_node) {
  // Add an identity node as the fetch node, otherwise we might get 'placeholder
  // is both fed and fetched' errors in some cases when using input list with
  // placeholder dataset nodes.
  NodeDef* node = graph_def->mutable_node()->Add();
  tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def,
                                                            node);
  node->set_op("Identity");
  node->add_input(*output_node);
  (*node->mutable_attr())["T"].set_type(DT_VARIANT);
  *output_node = node->name();

  // Add fake sink node to graph and functions to allow rewriting the actual
  // sink nodes.
  //
  // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
  // to be optimizable, we will no longer need this.
  for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
    AddFakeSinks(&function_def);
  }

  // Create metagraph.
  MetaGraphDef meta_graph_def;
  (*meta_graph_def.mutable_graph_def()) = *graph_def;

  // Grappler determines fetch ops from collection 'train_op'.
  CollectionDef collection_def;
  auto node_list = collection_def.mutable_node_list();
  node_list->add_value(*output_node);
  (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;

  // Create Grappler item.
  tensorflow::grappler::ItemConfig item_config;
  item_config.apply_optimizations = true;
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      tensorflow::grappler::GrapplerItemFromMetaGraphDef(
          "graph", meta_graph_def, item_config);
  // The factory hands back a null pointer when it cannot build an item from
  // the metagraph; report that instead of dereferencing it below.
  if (grappler_item == nullptr) {
    return errors::Internal("Failed to create a Grappler item for fetch node '",
                            *output_node, "'.");
  }
  grappler_item->optimization_options().optimize_function_library =
      optimize_function_library;
  std::unordered_map<string, tensorflow::DeviceProperties> device_map;
  tensorflow::grappler::VirtualCluster cluster(device_map);

  // Run data optimizer using grappler's meta optimizer.
  tensorflow::ConfigProto config;
  *config.mutable_graph_options()->mutable_rewrite_options() = config_factory();
  TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
      *grappler_item, config, ctx->device(), &cluster, graph_def));

  // Remove fake sinks after optimizations are done.
  //
  // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
  // to be optimizable, we will no longer need this.
  for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
    RemoveFakeSinks(&function_def);
  }

  return Status::OK();
}
// Returns the fixed hash of the string "DependencyLoopNode". HashSubgraphImpl
// mixes it into the hash of a node that is reached again while that node is
// still being hashed.
uint64 DefaultDependencyLoopNodeHash() {
  static const uint64 loop_node_hash = Hash64("DependencyLoopNode");
  return loop_node_hash;
}
// Returns the fixed hash of the string "DependencyLoopFn".
// HashSubgraphFunctionImpl mixes it into the hash of a function that is
// reached again while that function is still being hashed.
uint64 DefaultDependencyLoopFnHash() {
  static const uint64 loop_fn_hash = Hash64("DependencyLoopFn");
  return loop_fn_hash;
}
// Strips the fields of `op` that should not influence a hash of its signature:
// the name and description of every input and output argument, and the op's
// own summary, description and name.
void ClearOpDefForHashing(OpDef* op) {
  for (auto& input : *op->mutable_input_arg()) {
    input.clear_description();
    input.clear_name();
  }
  for (auto& output : *op->mutable_output_arg()) {
    output.clear_description();
    output.clear_name();
  }
  op->clear_summary();
  op->clear_description();
  op->clear_name();
}
// Forward declaration for use in HashAttr, which is defined before the
// function body and calls it for attributes that refer to a function.
// `visited` holds the canonical names of the nodes/functions currently being
// hashed (used for loop detection) and `cache` maps canonical names to hashes
// that were already computed.
uint64 HashSubgraphFunctionImpl(
const FunctionDefLibrary& library, const FunctionDef* f,
std::vector<std::string>* visited,
absl::flat_hash_map<std::string, uint64>* cache);
// Produces a hash of an attribute from an op or a function. Since attributes
// may refer to functions present in the graph, we may need to hash the function
// referred to by the attribute, and thus we need the FunctionDefLibrary.
//
// For a function-valued attribute the hash covers "<attr_key>=<hash of the
// function with that name in `library`>"; if `library` contains no function
// with that name the result is 0. For any other attribute the hash covers
// "<attr_key>=<deterministic hash of the attribute proto>".
uint64 HashAttr(const FunctionDefLibrary& library, const std::string& attr_key,
                const AttrValue& attr_value, std::vector<std::string>* visited,
                absl::flat_hash_map<std::string, uint64>* cache) {
  const uint64 empty_hash = 0;

  if (!attr_value.has_func()) {
    const uint64 proto_hash = DeterministicProtoHash64(attr_value);
    return Hash64CombineUnordered(
        empty_hash, Hash64(absl::StrCat(attr_key, "=", proto_hash)));
  }

  const auto& wanted_name = attr_value.func().name();
  for (const auto& candidate : library.function()) {
    if (candidate.signature().name() != wanted_name) {
      continue;
    }
    // Only the first function with a matching name is considered.
    const uint64 function_hash =
        HashSubgraphFunctionImpl(library, &candidate, visited, cache);
    return Hash64CombineUnordered(
        empty_hash, Hash64(absl::StrCat(attr_key, "=", function_hash)));
  }
  return empty_hash;
}
// This function hashes a subgraph (rooted at node) by traversing all possible
// dependency paths from that node.
//
// The hash combines the node's op, device, attributes, regular inputs (in
// order, including the port part of each input string) and control inputs. The
// node's own name is used only as the cache / loop-detection key and is not
// mixed into the hash.
//
// `visited` is the stack of canonical names currently being hashed; `cache`
// maps canonical names to hashes that were already computed. Both are shared
// with HashSubgraphFunctionImpl and HashAttr.
//
// NOTE(review): the results of g.GetNode(...) are passed to the recursive call
// without a null check, and this function dereferences `node` immediately --
// presumably every input is expected to name a node present in `g`; confirm.
uint64 HashSubgraphImpl(const grappler::GraphView& g, const NodeDef* node,
std::vector<std::string>* visited,
absl::flat_hash_map<std::string, uint64>* cache) {
uint64 input_hash = 0;
uint64 control_dep_hash = 0;
// The "node-" prefix keeps node keys apart from the "function-" keys used by
// HashSubgraphFunctionImpl in the same `visited` / `cache` containers.
// NOTE(review): the key contains only the node name, while the same cache is
// also used for function-body graphs; nodes with equal names in different
// graphs would share an entry -- confirm this is acceptable.
std::string canonical_node_name = absl::StrCat("node-", node->name());
auto it = cache->find(canonical_node_name);
if (it != cache->end()) {
return it->second;
}
uint64 op_hash = Hash64(node->op());
// Checks to make sure we won't get stuck in an infinite loop (especially in
// loops with control dependencies).
// If this node is already on the `visited` stack, a placeholder hash built
// from the loop constant and the op is cached and returned instead of
// recursing again. `visited` is left untouched on this path.
for (const std::string& visited_node_name : *visited) {
if (visited_node_name == canonical_node_name) {
uint64 final_hash =
Hash64Combine(DefaultDependencyLoopNodeHash(), op_hash);
(*cache)[canonical_node_name] = final_hash;
return final_hash;
}
}
visited->push_back(canonical_node_name);
for (int i = 0; i < node->input_size(); ++i) {
DCHECK_GT(node->input(i).length(), 0);
if (node->input(i)[0] == '^') {
// TODO(frankchn): Investigate if control dependencies are necessary
// inputs to the hash.
// Control dependency node names start with '^', and order of appearance
// for the control dependencies does not matter.
control_dep_hash = Hash64CombineUnordered(
control_dep_hash,
HashSubgraphImpl(g, g.GetNode(node->input(i).substr(1)), visited,
cache));
} else {
// The output port is significant and is optionally delimited by a ':'
// for non-zero ports.
// The input string is split at the first ':' only: `first` is the producer
// node name and `second` is the remainder (empty when there is no ':').
std::pair<std::string, std::string> node_spec =
absl::StrSplit(node->input(i), absl::MaxSplits(':', 1));
uint64 child_node_hash =
HashSubgraphImpl(g, g.GetNode(node_spec.first), visited, cache);
uint64 child_port_hash = Hash64(node_spec.second);
input_hash = Hash64Combine(
input_hash, Hash64Combine(child_node_hash, child_port_hash));
}
}
// Attributes are combined with the unordered combiner; function-valued
// attributes are resolved against the library of the graph behind `g`.
uint64 attr_hash = 0;
for (const auto& attr : node->attr()) {
attr_hash = Hash64CombineUnordered(
attr_hash, HashAttr(g.graph()->library(), attr.first, attr.second,
visited, cache));
}
uint64 device_hash = Hash64(node->device());
uint64 final_hash = Hash64Combine(
Hash64Combine(attr_hash, op_hash),
Hash64Combine(device_hash, Hash64Combine(input_hash, control_dep_hash)));
// Overwrites any placeholder hash that a loop-detection hit stored for this
// node while its inputs were being processed.
(*cache)[canonical_node_name] = final_hash;
visited->pop_back();
return final_hash;
}
// This function hashes a function by traversing all possible dependency paths
// from all output nodes declared by the function in its definition.
//
// The hash combines: the function signature with names/descriptions cleared
// (see ClearOpDefForHashing), the function attributes, the per-argument
// attributes, and, for every return value and control return value, the hash
// of the subgraph rooted at the node that produces it.
//
// `visited` is the stack of canonical names currently being hashed; `cache`
// maps canonical names to hashes that were already computed. Both are shared
// with HashSubgraphImpl and HashAttr.
uint64 HashSubgraphFunctionImpl(
const FunctionDefLibrary& library, const FunctionDef* f,
std::vector<std::string>* visited,
absl::flat_hash_map<std::string, uint64>* cache) {
// The "function-" prefix keeps function keys apart from the "node-" keys used
// by HashSubgraphImpl in the same containers.
std::string canonical_function_name =
absl::StrCat("function-", f->signature().name());
auto it = cache->find(canonical_function_name);
if (it != cache->end()) {
return it->second;
}
// Hash a copy of the signature from which the name, summary and descriptions
// were removed.
OpDef op = f->signature();
ClearOpDefForHashing(&op);
uint64 signature_hash = OpDefHash(op);
// Checks to make sure we won't get stuck in an infinite loop (especially when
// functions depend on other function ops as a control dependency).
// If this function is already on the `visited` stack, a placeholder hash built
// from the loop constant and the signature hash is cached and returned.
for (const std::string& visited_node_name : *visited) {
if (visited_node_name == canonical_function_name) {
uint64 final_hash =
Hash64Combine(DefaultDependencyLoopFnHash(), signature_hash);
(*cache)[canonical_function_name] = final_hash;
return final_hash;
}
}
visited->push_back(canonical_function_name);
uint64 attr_hash = 0;
for (const auto& attr : f->attr()) {
attr_hash = Hash64CombineUnordered(
attr_hash, HashAttr(library, attr.first, attr.second, visited, cache));
}
// Per-argument attributes: each attribute hash is combined with the key of
// its arg_attr entry before being folded in.
uint64 arg_attr_hash = 0;
for (const auto& arg_attr : f->arg_attr()) {
for (const auto& attr : arg_attr.second.attr()) {
arg_attr_hash = Hash64CombineUnordered(
arg_attr_hash,
Hash64Combine(arg_attr.first, HashAttr(library, attr.first,
attr.second, visited, cache)));
}
}
// Build a temporary GraphDef from the function body so that the body can be
// traversed with a GraphView and HashSubgraphImpl.
GraphDef node_graph;
for (const auto& node : f->node_def()) {
NodeDef* node_graph_node = node_graph.add_node();
*node_graph_node = node;
}
for (const auto& input_arg : f->signature().input_arg()) {
// We add dummy input nodes for the inputs to the function.
// Each dummy node is named after the input argument and given the op
// "_Retval", so that body nodes reading the argument resolve to a node.
NodeDef* node_graph_node = node_graph.add_node();
node_graph_node->set_name(input_arg.name());
node_graph_node->set_op("_Retval");
}
// The library is copied in so that function-valued attributes of body nodes
// can be resolved through g.graph()->library() in HashSubgraphImpl.
*(node_graph.mutable_library()) = library;
grappler::GraphView node_gv(&node_graph);
// TODO(frankchn): Investigate whether we need to hash the name of the
// return argument / control return argument or whether we can relax it and
// hash the index (etc...)
// The running hash is seeded with the number of return values.
uint64 ret_hash = f->ret_size();
for (const auto& ret : f->ret()) {
// Split the returned tensor string at the first ':' into the producing node
// name (`first`) and the remainder (`second`).
std::pair<std::string, std::string> node_spec =
absl::StrSplit(ret.second, absl::MaxSplits(':', 1));
// For every return value, we need to hash the output node (and the subgraph
// rooted at the output node) to ensure that the computation graph that
// ends at the output node has not changed.
// NOTE(review): the result of GetNode() is not checked for null before being
// handed to HashSubgraphImpl -- presumably every ret names a body node or an
// input argument; confirm.
uint64 node_hash = HashSubgraphImpl(
node_gv, node_gv.GetNode(node_spec.first), visited, cache);
uint64 node_port_hash = Hash64(node_spec.second);
ret_hash = Hash64CombineUnordered(
ret_hash, Hash64Combine(Hash64(ret.first),
Hash64Combine(node_hash, node_port_hash)));
}
// Control return values are hashed the same way as regular return values,
// seeded with the number of control return values.
uint64 control_ret_hash = f->control_ret_size();
for (const auto& ret : f->control_ret()) {
std::pair<std::string, std::string> node_spec =
absl::StrSplit(ret.second, absl::MaxSplits(':', 1));
uint64 node_hash = HashSubgraphImpl(
node_gv, node_gv.GetNode(node_spec.first), visited, cache);
uint64 node_port_hash = Hash64(node_spec.second);
control_ret_hash = Hash64CombineUnordered(
control_ret_hash,
Hash64Combine(Hash64(ret.first),
Hash64Combine(node_hash, node_port_hash)));
}
uint64 final_hash = Hash64Combine(
Hash64Combine(Hash64Combine(signature_hash, attr_hash), arg_attr_hash),
Hash64Combine(ret_hash, control_ret_hash));
// Overwrites any placeholder hash stored by a loop-detection hit for this
// function while its body was being processed.
(*cache)[canonical_function_name] = final_hash;
visited->pop_back();
return final_hash;
}
} // anonymous namespace
// Serializes `dataset` into `graph_def`.
//
// If `serialization_ctx` requests it, the dataset's external state is checked
// first and a failure is returned to the caller. The resulting graph contains
// a "_Retval" node named "dataset" whose input is the node that represents
// `dataset`. `ctx` is not used by this function.
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
                  SerializationContext&& serialization_ctx,
                  GraphDef* graph_def) {
  const bool must_check_external_state =
      serialization_ctx.check_external_state();
  if (must_check_external_state) {
    TF_RETURN_IF_ERROR(dataset->CheckExternalState());
  }

  GraphDefBuilder builder;
  DatasetBase::DatasetGraphDefBuilder dataset_builder(&builder);
  Node* dataset_node = nullptr;
  TF_RETURN_IF_ERROR(dataset_builder.AddInputDataset(&serialization_ctx,
                                                     dataset, &dataset_node));

  // Insert a purely symbolic _Retval node to indicate to consumers which node
  // represents `dataset`.
  const auto retval_options = builder.opts()
                                  .WithName("dataset")
                                  .WithAttr("T", DT_VARIANT)
                                  .WithAttr("index", 0);
  ops::UnaryOp("_Retval", dataset_node, retval_options);

  return builder.ToGraphDef(graph_def);
}
// Arranges for cancellation of `parent` to start cancellation of `child`.
//
// `*deregister_fn` is set to a callable that removes the registered callback
// from `parent`. When `parent` is null nothing is registered and
// `*deregister_fn` is set to a no-op. Returns a Cancelled error if `parent`
// refuses to register the callback.
Status ConnectCancellationManagers(CancellationManager* parent,
                                   CancellationManager* child,
                                   std::function<void()>* deregister_fn) {
  if (!parent) {
    VLOG(1) << "Parent cancellation manager is not set. Cancellation will "
               "not be propagated to the child cancellation manager.";
    *deregister_fn = []() {};
    return Status::OK();
  }

  CancellationToken token = parent->get_cancellation_token();
  const bool registered =
      parent->RegisterCallback(token, [child]() { child->StartCancel(); });
  if (!registered) {
    return errors::Cancelled("Operation was cancelled");
  }
  *deregister_fn = [parent, token]() { parent->DeregisterCallback(token); };
  return Status::OK();
}
// Serializes `input` to a graph, rewrites the graph with the configuration
// produced by `config_factory`, runs the rewritten graph and stores the
// resulting dataset in `*rewritten_input`.
//
// A reference is taken on the returned dataset (Ref() is called on it) before
// returning. Any error from serialization, rewriting, graph import or graph
// execution is returned to the caller.
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
                      std::function<RewriterConfig(void)> config_factory,
                      bool optimize_function_library,
                      DatasetBase** rewritten_input) {
  // Tensors collected in `feeds` during serialization are fed back when the
  // rewritten graph is executed below.
  std::vector<std::pair<string, Tensor>> feeds;
  SerializationContext::Params serialization_params;
  serialization_params.input_list = &feeds;
  serialization_params.check_external_state = false;
  serialization_params.fail_if_unimplemented = false;
  serialization_params.serialize_data_tensors = false;
  SerializationContext serialization_ctx(serialization_params);

  GraphDef graph_def;
  TF_RETURN_IF_ERROR(
      AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def));

  // The dataset to fetch is the input of the "_Retval" node; if there are
  // several such nodes the last one in the graph wins.
  string fetch_node;
  for (const auto& node : graph_def.node()) {
    if (node.op() != "_Retval") {
      continue;
    }
    fetch_node = node.input(0);
  }

  VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
  TF_RETURN_IF_ERROR(ApplyRewrites(ctx, config_factory,
                                   optimize_function_library, &graph_def,
                                   &fetch_node));
  VLOG(3) << "After graph rewrites: " << graph_def.DebugString();

  // Instantiate the optimized input pipeline by running the optimized graph
  // using the optimized function library.
  std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
  FunctionLibraryRuntime* flr = nullptr;
  TF_RETURN_IF_ERROR(
      ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));

  // Some functions may have been modified without having their names
  // changed (for example, nested dataset graphs from FlatMap or
  // Interleave).
  TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));

  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));

  std::vector<Tensor> fetched;
  GraphRunner graph_runner(flr->device());
  TF_RETURN_IF_ERROR(
      graph_runner.Run(&graph, flr, feeds, {fetch_node}, &fetched));
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(fetched[0], rewritten_input));
  (*rewritten_input)->Ref();
  return Status::OK();
}
// Returns OK when `expected` and `received` have the same length and equal
// data types at every position; otherwise returns InvalidArgument describing
// the length mismatch or the first component whose types differ.
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  const auto num_expected = expected.size();
  const auto num_received = received.size();
  if (num_expected != num_received) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", num_expected,
        " types but got ", num_received, ".");
  }
  for (size_t component = 0; component < num_expected; ++component) {
    const auto want = expected[component];
    const auto got = received[component];
    if (want == got) {
      continue;
    }
    return errors::InvalidArgument("Data type mismatch at component ",
                                   component, ": expected ",
                                   DataTypeString(want), " but got ",
                                   DataTypeString(got), ".");
  }
  return Status::OK();
}
// Returns OK when `expected` and `received` have the same length and every
// expected shape reports IsCompatibleWith() the received shape at the same
// position; otherwise returns InvalidArgument describing the length mismatch
// or the first incompatible component.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  const auto num_expected = expected.size();
  const auto num_received = received.size();
  if (num_expected != num_received) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", num_expected,
        " shapes but got ", num_received, ".");
  }
  for (size_t component = 0; component < num_expected; ++component) {
    const auto& want = expected[component];
    const auto& got = received[component];
    if (want.IsCompatibleWith(got)) {
      continue;
    }
    return errors::InvalidArgument("Incompatible shapes at component ",
                                   component, ": expected ",
                                   want.DebugString(), " but got ",
                                   got.DebugString(), ".");
  }
  return Status::OK();
}
// Hashes function `f`, resolving function-valued attributes against `library`.
// Starts HashSubgraphFunctionImpl with an empty loop-detection stack and an
// empty cache.
uint64 HashSubgraphFunction(const FunctionDefLibrary& library,
                            const FunctionDef* f) {
  absl::flat_hash_map<std::string, uint64> hash_cache;
  std::vector<std::string> in_progress;
  return HashSubgraphFunctionImpl(library, f, &in_progress, &hash_cache);
}
// Hashes the subgraph of `g` rooted at `node`. Starts HashSubgraphImpl with an
// empty loop-detection stack and an empty cache.
uint64 HashSubgraph(const GraphDef& g, const NodeDef* node) {
  absl::flat_hash_map<std::string, uint64> hash_cache;
  std::vector<std::string> in_progress;
  const grappler::GraphView graph_view(&g);
  return HashSubgraphImpl(graph_view, node, &in_progress, &hash_cache);
}
// Builds the key -> tensor-index map from the metadata string of `data`.
// The metadata is split on kDelimiter with empty pieces skipped, and each key
// is mapped to its position in the resulting list; a key that occurs more than
// once ends up mapped to its last position.
VariantTensorDataReader::VariantTensorDataReader(
    const tensorflow::VariantTensorData* data)
    : data_(data) {
  string serialized_keys;
  data_->get_metadata(&serialized_keys);
  const auto keys =
      str_util::Split(serialized_keys, kDelimiter, str_util::SkipEmpty());
  size_t position = 0;
  for (const auto& key : keys) {
    map_[key] = position;
    ++position;
  }
}
// Reads the int64 scalar stored under `key` into `*val` by forwarding to
// ReadScalarInternal, which returns NotFound for an unknown key.
Status VariantTensorDataReader::ReadScalar(StringPiece key, int64* val) {
return ReadScalarInternal(key, val);
}
// Reads the string scalar stored under `key` into `*val` by forwarding to
// ReadScalarInternal, which returns NotFound for an unknown key.
Status VariantTensorDataReader::ReadScalar(StringPiece key, string* val) {
return ReadScalarInternal(key, val);
}
// Copies the tensor stored under `key` into `*val` by forwarding to
// ReadTensorInternal, which returns NotFound for an unknown key.
Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) {
return ReadTensorInternal(key, val);
}
// Returns true if `key` was present in the metadata this reader was built from.
bool VariantTensorDataReader::Contains(StringPiece key) {
  const string lookup_key(key);
  return map_.count(lookup_key) != 0;
}
// Reads the scalar of type T stored under `key` into `*val`.
// Returns NotFound (carrying the key) if the key is unknown.
template <typename T>
Status VariantTensorDataReader::ReadScalarInternal(StringPiece key, T* val) {
  // A single lookup serves both the existence check and the index retrieval;
  // using find() also avoids operator[], which could insert into the map.
  const auto entry = map_.find(string(key));
  if (entry == map_.end()) {
    return errors::NotFound(key);
  }
  *val = data_->tensors(entry->second).scalar<T>()();
  return Status::OK();
}
// Copies the tensor stored under `key` into `*val`.
// Returns NotFound (carrying the key) if the key is unknown.
Status VariantTensorDataReader::ReadTensorInternal(StringPiece key,
                                                   Tensor* val) {
  // A single lookup serves both the existence check and the index retrieval;
  // using find() also avoids operator[], which could insert into the map.
  const auto entry = map_.find(string(key));
  if (entry == map_.end()) {
    return errors::NotFound(key);
  }
  *val = data_->tensors(entry->second);
  return Status::OK();
}
// Records `val` as an int64 scalar under `key` by forwarding to
// WriteScalarInternal.
Status VariantTensorDataWriter::WriteScalar(StringPiece key, const int64 val) {
return WriteScalarInternal(key, val);
}
// Records `val` as a string scalar under `key` by forwarding to
// WriteScalarInternal.
Status VariantTensorDataWriter::WriteScalar(StringPiece key,
const string& val) {
return WriteScalarInternal(key, val);
}
// Records tensor `val` under `key` by forwarding to WriteTensorInternal.
Status VariantTensorDataWriter::WriteTensor(StringPiece key,
const Tensor& val) {
return WriteTensorInternal(key, val);
}
// Stores the list of written keys as the metadata of the underlying data
// object. Every key is preceded by kDelimiter, in the order the keys were
// written. Always returns OK.
Status VariantTensorDataWriter::Flush() {
  string serialized_keys;
  for (const auto& key : keys_) {
    strings::StrAppend(&serialized_keys, kDelimiter, key);
  }
  data_->set_metadata(serialized_keys);
  return Status::OK();
}
// Wraps `val` in a scalar (rank-0) tensor of the data type that corresponds to
// T and records it under `key` via WriteTensorInternal.
template <typename T>
Status VariantTensorDataWriter::WriteScalarInternal(StringPiece key,
                                                    const T& val) {
  Tensor scalar_tensor(DataTypeToEnum<T>::v(), TensorShape({}));
  scalar_tensor.scalar<T>()() = val;
  return WriteTensorInternal(key, scalar_tensor);
}
// Appends `key` to the list of written keys and a copy of `val` to the
// underlying data object, so that both lists stay index-aligned. The key must
// not contain kDelimiter; this is only verified by a DCHECK. Always returns
// OK.
Status VariantTensorDataWriter::WriteTensorInternal(StringPiece key,
                                                    const Tensor& val) {
  DCHECK_EQ(key.find(kDelimiter), string::npos);
  keys_.push_back(string(key));
  Tensor* stored = data_->add_tensors();
  *stored = val;
  return Status::OK();
}
// Adds every function of `to_add` to `base`, replacing functions of the same
// name that are already present in `base`.
//
// A replacement is only allowed when the existing function and the new one
// have equal signatures; otherwise InvalidArgument is returned. All signatures
// are validated before anything is removed, so a signature conflict leaves
// `base` untouched. Errors from removing or adding functions are returned
// as-is.
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  const auto names = to_add.ListFunctionNames();

  // Pass 1: reject conflicting definitions without mutating `base`.
  for (const auto& fn : names) {
    const auto existing = base->Find(fn);
    if (existing &&
        !OpDefEqual(existing->signature(), to_add.Find(fn)->signature())) {
      return errors::InvalidArgument("Cannot add function '", fn,
                                     "' because a different function with "
                                     "the same signature already exists.");
    }
  }

  // Pass 2: drop the functions that are about to be replaced.
  for (const auto& fn : names) {
    if (base->Find(fn)) {
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}
// Adds every function of `to_add` to `base`, replacing functions of the same
// name that are already present in `base`.
//
// A replacement is only allowed when the existing function and the new one
// have equal signatures; otherwise InvalidArgument is returned. All signatures
// are validated before anything is removed, so a signature conflict leaves
// `base` untouched. Errors from removing or adding functions are returned
// as-is.
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  // Pass 1: reject conflicting definitions without mutating `base`.
  for (const auto& fd : to_add.function()) {
    const auto existing = base->Find(fd.signature().name());
    if (existing && !OpDefEqual(existing->signature(), fd.signature())) {
      return errors::InvalidArgument("Cannot add function '",
                                     fd.signature().name(),
                                     "' because a different function with "
                                     "the same signature already exists.");
    }
  }

  // Pass 2: drop the functions that are about to be replaced. The lookup is
  // repeated so that a name listed twice in `to_add` is only removed once.
  for (const auto& fd : to_add.function()) {
    if (base->Find(fd.signature().name())) {
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}
// Wraps `runner` so that every closure handed to the returned runner executes
// with a ScopedPerThreadMaxParallelism(max_parallelism) object alive for the
// duration of the call.
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  using Closure = std::function<void()>;
  using Runner = std::function<void(Closure)>;

  // std::bind is used to move `runner` (and, below, each closure) into the
  // bound object; the bound value is then received by const reference so it is
  // not copied again on each call.
  auto dispatch = [max_parallelism](const Runner& wrapped_runner,
                                    Closure work) {
    auto run_limited = [max_parallelism](const Closure& inner) {
      ScopedPerThreadMaxParallelism scope(max_parallelism);
      inner();
    };
    Closure scoped_work = std::bind(run_limited, std::move(work));
    wrapped_runner(std::move(scoped_work));
  };
  return std::bind(dispatch, std::move(runner), std::placeholders::_1);
}
} // namespace data
} // namespace tensorflow