/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/rewrite_utils.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/serialization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kDelimiter[] = "@@";
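// Appends an Identity node ("FakeSink") for each output of `function_def` and
// rewires the corresponding return value through it. This exposes the nodes
// feeding the function's return values to Grappler's optimizers; the fake
// sinks are stripped again by `RemoveFakeSinks` once rewriting is done.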
void AddFakeSinks(FunctionDef* function_def) {
int counter = 0;
for (const auto& output : function_def->signature().output_arg()) {
NodeDef* node = function_def->add_node_def();
tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
strings::StrCat("FakeSink", counter++), function_def, node);
node->set_op("Identity");
node->add_input(function_def->ret().at(output.name()));
(*node->mutable_attr())["T"].set_type(output.type());
(*function_def->mutable_ret())[output.name()] =
strings::StrCat(node->name(), ":output:0");
}
}
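// Undoes `AddFakeSinks`: rewires each return value that points at an Identity
// node back to that node's input, leaving the fake sink nodes unreferenced.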
void RemoveFakeSinks(FunctionDef* function_def) {
// Map from identity node names to their input tensor strings
std::map<string, string> identity_map;
for (const auto& node : function_def->node_def()) {
if (node.op() == "Identity" && node.input_size() == 1) {
identity_map[node.name()] = node.input(0);
}
}
for (const auto& output_arg : function_def->signature().output_arg()) {
const string& tensor = function_def->ret().at(output_arg.name());
const string& output_node = tensor.substr(0, tensor.find(':'));
if (identity_map.find(output_node) != identity_map.end()) {
(*function_def->mutable_ret())[output_arg.name()] =
identity_map.at(output_node);
}
}
}
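// Applies the rewrites described by the RewriterConfig returned from
// `config_factory` to `graph_def` using Grappler's meta optimizer.
// `*output_node` names the output dataset node on entry; on return it names
// the Identity fetch node that was added in front of it.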
Status ApplyRewrites(OpKernelContext* ctx,
const std::function<RewriterConfig(void)> config_factory,
bool optimize_function_library, GraphDef* graph_def,
string* output_node) {
  // Add an Identity node as the fetch node; otherwise we might get 'placeholder
  // is both fed and fetched' errors in some cases when using an input list with
  // placeholder dataset nodes.
NodeDef* node = graph_def->mutable_node()->Add();
tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def,
node);
node->set_op("Identity");
node->add_input(*output_node);
(*node->mutable_attr())["T"].set_type(DT_VARIANT);
*output_node = node->name();
// Add fake sink node to graph and functions to allow rewriting the actual
// sink nodes.
//
// TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
// to be optimizable, we will no longer need this.
for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
AddFakeSinks(&function_def);
}
// Create metagraph.
MetaGraphDef meta_graph_def;
(*meta_graph_def.mutable_graph_def()) = *graph_def;
// Grappler determines fetch ops from collection 'train_op'.
CollectionDef collection_def;
auto node_list = collection_def.mutable_node_list();
node_list->add_value(*output_node);
(*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
// Create Grappler item.
tensorflow::grappler::ItemConfig item_config;
item_config.apply_optimizations = true;
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
"graph", meta_graph_def, item_config);
grappler_item->optimization_options().optimize_function_library =
optimize_function_library;
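  // The tf.data rewrites do not need real device information, so an empty
  // virtual cluster is sufficient for the meta optimizer.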
std::unordered_map<string, tensorflow::DeviceProperties> device_map;
tensorflow::grappler::VirtualCluster cluster(device_map);
// Run data optimizer using grappler's meta optimizer.
tensorflow::ConfigProto config;
*config.mutable_graph_options()->mutable_rewrite_options() = config_factory();
TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
std::move(*grappler_item), config, ctx->device(), &cluster, graph_def));
// Remove fake sinks after optimizations are done.
//
// TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
// to be optimizable, we will no longer need this.
for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
RemoveFakeSinks(&function_def);
}
return Status::OK();
}
} // anonymous namespace
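// Rewrites the input pipeline rooted at `input` by serializing it to a
// GraphDef, running the Grappler rewrites configured by `config_factory`, and
// re-instantiating the optimized graph. On success, `*rewritten_input` holds a
// new reference to the rewritten dataset.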
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
std::function<RewriterConfig(void)> config_factory,
bool optimize_function_library, bool record_fingerprint,
DatasetBase** rewritten_input) {
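  // Serialize the input pipeline to a GraphDef, collecting any input tensors
  // into `input_list` so they can be fed back in when the rewritten graph is
  // run; external state and random seeds are not preserved.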
SerializationContext::Params params;
std::vector<std::pair<string, Tensor>> input_list;
params.input_list = &input_list;
params.external_state_policy =
SerializationContext::ExternalStatePolicy::kIgnore;
params.fail_if_unimplemented = false;
params.serialize_data_tensors = false;
params.preserve_random_seeds = false;
SerializationContext serialization_ctx(params);
GraphDef graph_def;
TF_RETURN_IF_ERROR(
AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def));
string output_node;
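  // The serialized pipeline ends in a `_Retval` node; its input names the
  // output dataset node that the rewrites should treat as the fetch target.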
for (const auto& node : graph_def.node()) {
if (node.op() == "_Retval") {
output_node = node.input(0);
}
}
VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
TF_RETURN_IF_ERROR(ApplyRewrites(ctx, config_factory,
optimize_function_library, &graph_def,
&output_node));
VLOG(3) << "After graph rewrites: " << graph_def.DebugString();
// Instantiate the optimized input pipeline by running the optimized graph
// using the optimized function library.
FunctionLibraryRuntime* flr = nullptr;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
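  // Create a private function library runtime for running the rewritten graph,
  // so the optimized functions do not pollute the caller's library.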
TF_RETURN_IF_ERROR(
ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));
// Some functions may have been modified without having their names changed
// (for example, nested dataset graphs from FlatMap or Interleave).
TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));
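  // Import the rewritten GraphDef and execute it with GraphRunner, feeding the
  // collected `input_list` and fetching the output dataset node.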
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
std::vector<Tensor> outputs;
GraphRunner graph_runner(flr->device());
TF_RETURN_IF_ERROR(
graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs));
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], rewritten_input));
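  // `GetDatasetFromVariantTensor` does not add a reference, so take one here to
  // keep the rewritten dataset alive beyond the lifetime of `outputs`.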
(*rewritten_input)->Ref();
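  // Optionally record a fingerprint of the rewritten pipeline. Hashing the
  // graph can be expensive, so it is done asynchronously on the runner.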
if (record_fingerprint) {
(*ctx->runner())([graph_def = std::move(graph_def),
input_list = std::move(input_list),
output_node = std::move(output_node)]() {
const NodeDef* node_def = nullptr;
for (const auto& node : graph_def.node()) {
if (node.name() == output_node) {
node_def = &node;
break;
}
}
if (node_def == nullptr) {
VLOG(3) << "Failed to find node: " << output_node;
return;
}
uint64 hash = 0;
Status s = HashNode(graph_def, *node_def, &hash);
if (!s.ok()) {
VLOG(3) << "Failed to hash graph: " << s.ToString();
return;
}
for (const auto& pair : input_list) {
hash = Hash64CombineUnordered(hash, Hash64(pair.first));
uint64 tensor_hash = 0;
Status s = HashTensor(pair.second, &tensor_hash);
if (s.ok()) {
hash = Hash64CombineUnordered(hash, tensor_hash);
} else {
VLOG(3) << "Failed to hash tensor: " << s.ToString();
}
}
string graph_hash =
strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
metrics::RecordTFDataFingerprint(graph_hash);
});
}
return Status::OK();
}
} // namespace data
} // namespace tensorflow