blob: 5e01f4d2d33fd85a2970d6eab6fff2cd4fd04c24 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include <algorithm>
#include <queue>
#include <utility>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace {
// Looks up a node of `graph` by exact name.
//
// Returns the first node whose name equals `name`, or nullptr when no node
// matches. CHECK-fails if a null node is encountered before a match is found.
const Node* FindNodeByName(const string& name, const Graph& graph) {
  const Node* found = nullptr;
  for (const Node* node : graph.nodes()) {
    CHECK_NOTNULL(node);
    if (node->name() != name) {
      continue;
    }
    found = node;
    break;
  }
  return found;
}
// Collects the node-name part of every "name[:port]" entry into a set.
//
// Each entry is parsed with ParseTensorName(); only the name component is
// kept, so entries that differ only by port collapse into one element.
std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
    const std::vector<string>& node_names_and_ports) {
  std::unordered_set<string> node_names;
  for (const string& entry : node_names_and_ports) {
    node_names.emplace(ParseTensorName(entry).first);
  }
  return node_names;
}
// Mutable counterpart of the name lookup: returns the first non-null node of
// `graph` whose name equals `name`, or nullptr if there is none. Null entries
// are skipped rather than treated as an error.
Node* FindMutableNodeByName(const string& name, Graph* graph) {
  for (Node* candidate : graph->nodes()) {
    if (candidate == nullptr) {
      continue;
    }
    if (candidate->name() == name) {
      return candidate;
    }
  }
  return nullptr;
}
// Finds the NodeDef in `graph_def` that `input` refers to.
//
// `input` may carry a port suffix ("name:port"); it is parsed with
// ParseTensorName() and only the name part is used for the lookup.
// Returns a pointer into `graph_def`, or nullptr when no node matches.
const NodeDef* FindNodeDefByName(const string& input,
                                 const GraphDef& graph_def) {
  const string target(ParseTensorName(input).first);
  const auto& nodes = graph_def.node();
  const auto match =
      std::find_if(nodes.begin(), nodes.end(),
                   [&target](const NodeDef& candidate) {
                     return candidate.name() == target;
                   });
  return match == nodes.end() ? nullptr : &*match;
}
// Parses `node_name_and_port` into `*tid` and reports whether its name part
// equals the name of `node_def`.
//
// `tid` must not be null; it is always overwritten with the parsed result,
// regardless of the return value.
bool IsSameNodeName(const NodeDef& node_def, const string& node_name_and_port,
                    TensorId* tid) {
  CHECK_NOTNULL(tid);
  *tid = ParseTensorName(node_name_and_port);
  return node_def.name() == tid->first;
}
bool ContainsSameTensorId(const string& tensor_name,
const std::vector<string>& tensor_names) {
const TensorId tid0 = ParseTensorName(tensor_name);
for (const string& name : tensor_names) {
const TensorId tid1 = ParseTensorName(name);
if (tid0.first == tid1.first && tid0.second == tid1.second) {
return true;
}
}
return false;
}
// Appends a ":" separator to `*str` unless the string is still empty.
// `str` must not be null.
void AppendDeliminator(string* str) {
  CHECK_NOTNULL(str);
  if (str->empty()) {
    return;
  }
  str->append(":");
}
void ConvertMapToVector(const std::unordered_map<int, string>& in,
std::vector<string>* out) {
CHECK_NOTNULL(out);
out->resize(in.size());
for (size_t i = 0; i < in.size(); ++i) {
CHECK(in.count(i) > 0);
out->at(i) = in.at(i);
}
}
// Renders `graph_def` as human-readable text for logging: one "node:" line
// per node followed by a line listing that node's inputs.
string DumpGraphDef(const GraphDef& graph_def) {
  string dump;
  for (const NodeDef& node_def : graph_def.node()) {
    dump.append(strings::StrCat("node: ", node_def.name(), "\n input: "));
    for (const string& input_name : node_def.input()) {
      dump.append(strings::StrCat(input_name, ", "));
    }
    dump.append("\n");
  }
  return dump;
}
// Renders a ClusterInfo as text for logging. The three tuple elements are
// printed in order under the headings "Nodes", "Input border" and
// "Output border", each as a comma-separated list.
string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
  string dump("Nodes:\n");
  for (const string& node_name : std::get<0>(cluster)) {
    dump.append(node_name).append(", ");
  }
  dump.append("\nInput border:\n");
  for (const string& border_input : std::get<1>(cluster)) {
    dump.append(border_input).append(", ");
  }
  dump.append("\nOutput border:\n");
  for (const string& border_output : std::get<2>(cluster)) {
    dump.append(border_output).append(", ");
  }
  return dump;
}
} // namespace
// Out-of-class definitions of the static constexpr string constants of
// RemoteFusedGraphExecuteUtils. None of them has an initializer here, so the
// values come from the in-class declarations (presumably in
// remote_fused_graph_execute_utils.h -- not visible from this file). Before
// C++17 a namespace-scope definition like these is required whenever such a
// member is odr-used (e.g. bound to a const reference).
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::ATTR_NODE_TYPE;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES;
// Registers `executor_build_func` in the process-wide executor build registry
// under `name`. An existing entry with the same name is overwritten.
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar(
    const string& name, ExecutorBuildFunc executor_build_func) {
  ExecutorBuildRegistry* const registry = GetExecutorBuildRegistry();
  (*registry)[name] = std::move(executor_build_func);
}
// Returns a pointer to the build function registered under `name`, or nullptr
// when nothing has been registered with that name. The pointer refers to the
// entry stored inside the registry.
/* static */ const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc*
RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(const string& name) {
  ExecutorBuildRegistry* const registry = GetExecutorBuildRegistry();
  const auto entry = registry->find(name);
  if (entry == registry->end()) {
    return nullptr;
  }
  return &entry->second;
}
// Returns the single registry instance that maps executor names to build
// functions. The registry is a function-local static, so it is created on the
// first call and the same pointer is returned on every call.
/* static */ RemoteFusedGraphExecuteUtils::ExecutorBuildRegistry*
RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
  static ExecutorBuildRegistry registry;
  return &registry;
}
/**
 * - DryRunInference
 * To determine shapes of output tensors of all nodes, dryrun the graph.
 * This function supplies memory allocation information when loading
 * the graph. This function is used to verify shape inference and actual
 * output shape.
 *
 * graph_def:            graph to execute.
 * input_node_info_list: (name, tensor) pairs fed to the session. Every tensor
 *                       must be initialized (CHECK-enforced).
 * output_node_names:    tensors to fetch.
 * initialize_by_zero:   when true, each input is replaced by a zero-filled
 *                       tensor of the same dtype and shape. Only DT_INT32,
 *                       DT_FLOAT and DT_QUINT8 are supported; any other dtype
 *                       is a fatal error.
 * output_tensors:       receives the fetched tensors; must not be null.
 *
 * Returns the status of session creation / execution.
 */
/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInference(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names, const bool initialize_by_zero,
    std::vector<tensorflow::Tensor>* output_tensors) {
  // Validate the output pointer before doing any work.
  CHECK(output_tensors != nullptr);

  // Create input tensor vector. If "initialize_by_zero" is true,
  // input tensor fields are initialized by 0.
  std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
  input_tensors.reserve(input_node_info_list.size());
  for (const std::pair<string, Tensor>& input : input_node_info_list) {
    CHECK(input.second.IsInitialized());
    if (!initialize_by_zero) {
      input_tensors.push_back({input.first, input.second});
      continue;
    }
    // Zero-filling was requested: build a fresh tensor with the same dtype and
    // shape as the supplied one and fill it with zeros.
    const DataType data_type = input.second.dtype();
    const TensorShape& shape = input.second.shape();
    Tensor input_tensor(data_type, shape);
    switch (data_type) {
      case DT_INT32: {
        auto int_tensor = input_tensor.flat<int32>();
        int_tensor = int_tensor.constant(0);
        break;
      }
      case DT_FLOAT: {
        auto float_tensor = input_tensor.flat<float>();
        float_tensor = float_tensor.constant(0.0f);
        break;
      }
      case DT_QUINT8: {
        auto int_tensor = input_tensor.flat<quint8>();
        int_tensor = int_tensor.constant(0);
        break;
      }
      default:
        LOG(FATAL) << "Unsupported input type: " << data_type;
    }
    input_tensors.push_back({input.first, input_tensor});
  }

  // Setup session
  SessionOptions session_options;
  session_options.env = Env::Default();
  std::unique_ptr<Session> session(NewSession(session_options));
  // NewSession() reports failure by returning nullptr; surface that as an
  // error instead of dereferencing a null session below.
  if (session == nullptr) {
    return errors::Internal("Failed to create a session for dry run");
  }
  Status status = session->Create(graph_def);
  if (!status.ok()) {
    return status;
  }

  // Setup session arguments
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::FULL_TRACE);
  RunMetadata run_metadata;

  // Run inference, fetching every requested output tensor.
  status = session->Run(run_options, input_tensors, output_node_names, {},
                        output_tensors, &run_metadata);
  if (!status.ok()) {
    LOG(ERROR) << "Error during inference: " << status;
    return status;
  }
  return Status();
}
// Dry-runs `graph_def` fetching every output of every non-input node and
// records the dtype and shape of each tensor in `tensor_shape_map`.
//
// The tensors supplied in `input_node_info_list` are recorded as well, after
// the fetched ones. `initialize_by_zero` is forwarded to DryRunInference().
// `tensor_shape_map` must not be null. Returns the first error from graph
// import or from the dry run, OK otherwise.
/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const bool initialize_by_zero,
    RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) {
  CHECK(tensor_shape_map != nullptr);

  Graph graph(OpRegistry::Global());
  const Status import_status = ImportGraphDef({}, graph_def, &graph, nullptr);
  if (!import_status.ok()) {
    return import_status;
  }

  // Every output port of every node that is not fed as an input is fetched.
  std::vector<string> output_node_names;
  for (const Node* node : graph.nodes()) {
    if (IsInputNode(input_node_info_list, node->name())) {
      continue;
    }
    const int output_count = node->num_outputs();
    for (int port = 0; port < output_count; ++port) {
      output_node_names.emplace_back(strings::StrCat(node->name(), ":", port));
    }
  }

  std::vector<Tensor> output_tensors;
  output_tensors.reserve(graph_def.node_size());
  const Status run_status =
      DryRunInference(graph_def, input_node_info_list, output_node_names,
                      initialize_by_zero, &output_tensors);
  if (!run_status.ok()) {
    VLOG(1) << "Failed to dryrun " << run_status;
    return run_status;
  }
  CHECK_EQ(output_node_names.size(), output_tensors.size())
      << output_node_names.size() << ", " << output_tensors.size();

  // The input tensors are appended behind the fetched tensors so that both
  // groups can be recorded from the same vector.
  const size_t fetched_count = output_node_names.size();
  for (const std::pair<string, Tensor>& input_node_info :
       input_node_info_list) {
    output_tensors.push_back(input_node_info.second);
  }
  for (size_t i = 0; i < fetched_count; ++i) {
    EmplaceTensorShapeType(output_node_names[i], output_tensors[i],
                           tensor_shape_map);
  }
  for (size_t i = 0; i < input_node_info_list.size(); ++i) {
    EmplaceTensorShapeType(input_node_info_list[i].first,
                           output_tensors[fetched_count + i],
                           tensor_shape_map);
  }
  CHECK_EQ(output_node_names.size() + input_node_info_list.size(),
           output_tensors.size());
  return run_status;
}
/* static */ bool RemoteFusedGraphExecuteUtils::IsInputNode(
const std::vector<std::pair<string, Tensor>>& input_tensor_vector,
const string& node_name) {
for (const std::pair<string, Tensor>& pair : input_tensor_vector) {
const TensorId tid = ParseTensorName(pair.first);
if (node_name == tid.first) {
return true;
}
}
return false;
}
// Rebuilds `tensor_shape_map` from parallel lists of output names and tensors.
//
// The map is cleared first, then one entry per (output_node_names[i],
// output_tensors[i]) pair is added. Both lists must have the same length and
// `tensor_shape_map` must not be null (CHECK-enforced).
// NOTE(review): `input_node_info_list` only contributes to the reserved
// capacity; no entry is added for the inputs here -- confirm that is intended.
/* static */ void RemoteFusedGraphExecuteUtils::ConvertToTensorShapeMap(
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names,
    const std::vector<tensorflow::Tensor>& output_tensors,
    TensorShapeMap* tensor_shape_map) {
  CHECK_NE(tensor_shape_map, nullptr);
  tensor_shape_map->clear();
  tensor_shape_map->reserve(input_node_info_list.size() +
                            output_node_names.size());
  const int output_node_count = output_node_names.size();
  CHECK_EQ(output_node_count, output_tensors.size());
  int index = 0;
  for (const string& node_name : output_node_names) {
    EmplaceTensorShapeType(node_name, output_tensors[index], tensor_shape_map);
    ++index;
  }
}
// Parses `tensor_proto` into `*tensor`.
//
// Returns OK and assigns `*tensor` on success. Returns InvalidArgument, and
// leaves `*tensor` untouched, when the proto's dtype is outside
// (0, DataType_MAX] or when Tensor::FromProto() rejects the proto.
/* static */ Status RemoteFusedGraphExecuteUtils::MakeTensorFromProto(
    const TensorProto& tensor_proto, Tensor* tensor) {
  const bool dtype_in_range =
      tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX;
  if (!dtype_in_range) {
    return errors::InvalidArgument("Cannot parse tensor from proto");
  }
  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto");
  }
  *tensor = parsed;
  return Status::OK();
}
// Stores the output dtypes and shapes of a node as attributes on `node_def`:
// `data_types` under ATTR_OUTPUT_DATA_TYPES and `shapes` under
// ATTR_OUTPUT_SHAPES. Always returns true.
/* static */ bool RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
const std::vector<DataType>& data_types,
const std::vector<TensorShape>& shapes, NodeDef* node_def) {
AddNodeAttr(ATTR_OUTPUT_DATA_TYPES, data_types, node_def);
AddNodeAttr(ATTR_OUTPUT_SHAPES, shapes, node_def);
return true;
}
// Writes the output dtypes/shapes recorded for `node_def` in
// `tensor_shape_map` onto the node as attributes, ordered by output port.
//
// All map entries keyed by the node's name are collected. Their ports must be
// exactly 0..n-1 with no gaps or repeats; anything else is a CHECK failure.
// A node without entries gets empty attribute lists. Always returns OK.
/* static */ Status
RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
    const TensorShapeMap& tensor_shape_map, NodeDef* node_def) {
  CHECK_NE(node_def, nullptr);
  // Max-heap on the port number, so entries come out highest port first.
  std::priority_queue<std::tuple<int, const TensorShapeType*>> queue;
  const auto range = tensor_shape_map.equal_range(node_def->name());
  for (auto entry = range.first; entry != range.second; ++entry) {
    queue.emplace(entry->second.first, &entry->second.second);
  }
  int last_port = queue.size();
  std::vector<DataType> data_types;
  std::vector<TensorShape> shapes;
  for (; !queue.empty(); queue.pop()) {
    const int port = std::get<0>(queue.top());
    const TensorShapeType* tst = std::get<1>(queue.top());
    CHECK_NE(tst, nullptr);
    data_types.push_back(tst->first);
    shapes.push_back(tst->second);
    // Each popped port must be exactly one below the previous one.
    CHECK_EQ(last_port - 1, port);
    last_port = port;
  }
  // Entries were gathered from the highest port down; flip to ascending order.
  std::reverse(data_types.begin(), data_types.end());
  std::reverse(shapes.begin(), shapes.end());
  AddOutputTensorShapeType(data_types, shapes, node_def);
  return Status::OK();
}
// Reads the output dtype / shape attributes from `attrs`.
//
// Either output pointer may be null, in which case that attribute is not
// read. The dtype attribute is read first and a failure there is returned
// immediately. When both lists are read successfully they must have the same
// length (CHECK-enforced). Returns the status of the last attribute lookup
// performed, or OK when nothing was requested.
/* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    AttrSlice attrs, std::vector<DataType>* data_types,
    std::vector<TensorShape>* shapes) {
  const bool want_types = data_types != nullptr;
  if (want_types) {
    const Status type_status =
        GetNodeAttr(attrs, ATTR_OUTPUT_DATA_TYPES, data_types);
    if (!type_status.ok()) {
      return type_status;
    }
  }
  if (shapes == nullptr) {
    return Status();
  }
  const Status shape_status = GetNodeAttr(attrs, ATTR_OUTPUT_SHAPES, shapes);
  if (shape_status.ok() && want_types) {
    CHECK_EQ(data_types->size(), shapes->size());
  }
  return shape_status;
}
/* static */ bool RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
const GraphDef& graph_def, const string& name_and_port, DataType* data_type,
TensorShape* shape) {
std::vector<DataType> data_types;
std::vector<TensorShape> shapes;
const TensorId tid = ParseTensorName(name_and_port);
const string node_name(tid.first);
const int port = tid.second;
const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
CHECK_NOTNULL(node_def);
GetOutputTensorShapeType(*node_def, &data_types, &shapes).IgnoreError();
if (data_types.empty()) {
return false;
}
CHECK(data_types.size() > port);
*data_type = data_types.at(port);
*shape = shapes.at(port);
return true;
}
// Walks `graph` with ReverseDFS and refreshes shape information in
// `shape_refiner`.
//
// For a node whose name equals the name of an entry in
// `input_node_info_list`, output 0 is set to the shape of the supplied
// tensor. For every other node ShapeRefiner::AddNode() is called. Processing
// stops at the first error (later nodes are skipped) and that error is
// returned. `graph_def` is not used by this function.
/* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    Graph* graph, ShapeRefiner* shape_refiner) {
  Status status;
  auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
    // Once anything has failed, remaining nodes are left untouched.
    if (!status.ok()) {
      return;
    }
    CHECK_NE(node, nullptr);

    // If we visit an input node, we use the shape provided and set the
    // shape accordingly.
    bool is_input_node = false;
    for (const std::pair<string, Tensor>& input_node_info :
         input_node_info_list) {
      if (node->name() != input_node_info.first) {
        continue;
      }
      shape_inference::InferenceContext* context =
          shape_refiner->GetContext(node);
      shape_inference::ShapeHandle handle;
      status = context->MakeShapeFromTensorShape(
          input_node_info.second.shape(), &handle);
      if (status.ok()) {
        status = shape_refiner->SetShape(node, 0, handle);
      }
      if (!status.ok()) {
        break;
      }
      is_input_node = true;
    }

    // If not an input node call AddNode() that recomputes the shape.
    if (status.ok() && !is_input_node) {
      status = shape_refiner->AddNode(node);
    }
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node: " << node->name();
    }
  };
  ReverseDFS(*graph, {}, visit);
  return status;
}
// Fills `tensor_shape_map` with the dtype and inferred shape of every output
// of every node in `graph`, using the contexts held by `shape_refiner`.
//
// One entry per (node name, output port) is added. Returns InvalidArgument as
// soon as an output with unknown rank is found. CHECK-fails when a node id
// has no node, when the refiner has no context for a node with outputs, when
// a dimension value is unknown, or when the map already holds an entry for
// the same (node, port).
/* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
    const Graph& graph, const ShapeRefiner& shape_refiner,
    TensorShapeMap* tensor_shape_map) {
  for (int i = 0; i < graph.num_node_ids(); ++i) {
    const Node* node = graph.FindNodeId(i);
    CHECK_NE(node, nullptr);
    const string& node_name = node->name();
    for (int j = 0; j < node->num_outputs(); ++j) {
      const DataType dt = node->output_type(j);
      shape_inference::InferenceContext* context =
          shape_refiner.GetContext(node);
      CHECK_NE(context, nullptr);
      shape_inference::ShapeHandle shape_handle = context->output(j);
      if (!context->RankKnown(shape_handle)) {
        return errors::InvalidArgument("Graph contains unknow shapes");
      }
      TensorShape ts;
      for (int k = 0; k < context->Rank(shape_handle); ++k) {
        shape_inference::DimensionHandle dh = context->Dim(shape_handle, k);
        CHECK(context->ValueKnown(dh));
        ts.AddDim(context->Value(dh));
      }
      // The map is keyed by node name and holds one entry per output port, so
      // duplicates must be detected per (node, port). Requiring that no entry
      // exists for the node at all would abort on the second output of every
      // multi-output node.
      const auto existing = tensor_shape_map->equal_range(node_name);
      for (auto it = existing.first; it != existing.second; ++it) {
        CHECK_NE(it->second.first, j)
            << "Duplicate shape entry for " << node_name << ":" << j;
      }
      tensor_shape_map->emplace(node_name,
                                std::make_pair(j, std::make_pair(dt, ts)));
    }
  }
  return Status::OK();
}
// Looks up the dtype/shape entry for `node_name`, which may be either a bare
// node name (port 0 is assumed) or "name:port". Returns nullptr when the map
// has no matching entry.
/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
RemoteFusedGraphExecuteUtils::GetTensorShapeType(
    const TensorShapeMap& tensor_shape_map, const string& node_name) {
  const bool has_port_suffix = node_name.find(':') != string::npos;
  if (!has_port_suffix) {
    return GetTensorShapeType(tensor_shape_map, node_name, 0);
  }
  const TensorId tid = ParseTensorName(node_name);
  return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second);
}
// Looks up the dtype/shape entry recorded for output `port` of `node_name`.
//
// `node_name` must be a bare node name without a ':' (CHECK-enforced).
// Returns a pointer into `tensor_shape_map`, or nullptr when the node has no
// entry for that port.
/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
RemoteFusedGraphExecuteUtils::GetTensorShapeType(
    const TensorShapeMap& tensor_shape_map, const string& node_name,
    const int port) {
  CHECK_EQ(node_name.find(':'), string::npos);
  // An absent node yields an empty range, so the loop simply does not run.
  const auto range = tensor_shape_map.equal_range(node_name);
  for (auto entry = range.first; entry != range.second; ++entry) {
    if (entry->second.first != port) {
      continue;
    }
    return &entry->second.second;
  }
  return nullptr;
}
// Expands the input / output description stored in `proto`.
//
// For every graph input name a (name, Tensor) pair is appended to `inputs`,
// the tensor being constructed from the dtype and shape of the matching
// default input shape entry. The number of input names must equal the number
// of default input shapes (CHECK-enforced). Every graph output name is
// appended to `outputs`. Neither vector is cleared first.
/* static */ void
RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
    const RemoteFusedGraphExecuteInfo& proto,
    std::vector<std::pair<string, Tensor>>* inputs,
    std::vector<string>* outputs) {
  CHECK_EQ(proto.graph_input_node_name_size(),
           proto.default_graph_input_tensor_shape_size());
  const int input_count = proto.graph_input_node_name_size();
  for (int i = 0; i < input_count; ++i) {
    const auto& shape_type = proto.default_graph_input_tensor_shape(i);
    inputs->emplace_back(
        proto.graph_input_node_name(i),
        Tensor(shape_type.dtype(), TensorShape(shape_type.shape())));
  }
  for (const string& output_node_name : proto.graph_output_node_name()) {
    outputs->emplace_back(output_node_name);
  }
}
/* static */ void RemoteFusedGraphExecuteUtils::EmplaceTensorShapeType(
const string& name, const Tensor& tensor,
TensorShapeMap* tensor_shape_map) {
const TensorId tid = ParseTensorName(name);
CHECK_EQ(tensor_shape_map->count(name), 0);
tensor_shape_map->emplace(
string(tid.first),
std::make_pair(tid.second,
std::make_pair(tensor.dtype(), tensor.shape())));
}
// Determines the output dtype/shape of every tensor in `graph_def` and stores
// the result as attributes on each node.
//
// With `dry_run_inference` the graph is actually executed with zero-filled
// inputs shaped like `input_tensors`; otherwise the shapes are derived through
// static shape inference. Returns the first error encountered.
/* static */ Status RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
    const std::vector<std::pair<string, Tensor>>& input_tensors,
    const bool dry_run_inference, GraphDef* graph_def) {
  TensorShapeMap tensor_shape_map;
  if (!dry_run_inference) {
    // Static path: import the graph with a shape refiner, push the known input
    // shapes through it and read the result back.
    ImportGraphDefOptions opts;
    Graph graph(OpRegistry::Global());
    ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
    TF_RETURN_IF_ERROR(
        ImportGraphDef(opts, *graph_def, &graph, &shape_refiner));
    TF_RETURN_IF_ERROR(PropagateShapeInference(*graph_def, input_tensors,
                                               &graph, &shape_refiner));
    TF_RETURN_IF_ERROR(
        BuildTensorShapeMapFromGraph(graph, shape_refiner, &tensor_shape_map));
  } else {
    // Dynamic path: run the graph once and record what it produces.
    TF_RETURN_IF_ERROR(DryRunInferenceForAllNode(*graph_def, input_tensors,
                                                 /*initialize_by_zero=*/true,
                                                 &tensor_shape_map));
  }
  const int node_count = graph_def->node_size();
  for (int i = 0; i < node_count; ++i) {
    TF_RETURN_IF_ERROR(AddOutputTensorShapeTypeByTensorShapeMap(
        tensor_shape_map, graph_def->mutable_node(i)));
  }
  return Status::OK();
}
// Populates `execute_info` for a fused subgraph.
//
// `execute_info` is cleared, then filled with the executor name, a copy of
// `subgraph_def`, and the given input / output tensor names. For every name
// whose dtype and shape are recorded in `subgraph_def`, a default tensor
// shape entry is added and the dtype is appended to `input_types` /
// `output_types`. When nothing is recorded the dtype is assumed to be
// DT_FLOAT and no shape entry is added; if `require_shape_type` is set that
// case is a CHECK failure instead. All three output pointers must be
// non-null. Always returns OK.
/* static */ Status
RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
    const string& executor_name, const GraphDef& subgraph_def,
    const std::vector<string>& inputs, const std::vector<string>& outputs,
    const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info,
    DataTypeVector* input_types, DataTypeVector* output_types) {
  CHECK_NOTNULL(execute_info);
  CHECK_NOTNULL(input_types);
  CHECK_NOTNULL(output_types);
  execute_info->Clear();
  execute_info->set_executor_name(executor_name);

  // copy graph
  *execute_info->mutable_remote_graph() = subgraph_def;

  // Writes one dtype + dims record; shared by the input and output loops.
  const auto fill_shape_type =
      [](const DataType dt, const TensorShape& shape,
         RemoteFusedGraphExecuteInfo::TensorShapeTypeProto* shape_type) {
        shape_type->set_dtype(dt);
        TensorShapeProto* shape_proto = shape_type->mutable_shape();
        for (const int64 dim : shape.dim_sizes()) {
          shape_proto->add_dim()->set_size(dim);
        }
      };

  for (const string& input : inputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, input, &dt, &shape);
    execute_info->add_graph_input_node_name(input);
    if (!has_shapetype) {
      CHECK(!require_shape_type)
          << "No shape type found for " << input << DumpGraphDef(subgraph_def);
      // Assuming input type is float if no data provided.
      input_types->push_back(DT_FLOAT);
      continue;
    }
    fill_shape_type(dt, shape,
                    execute_info->add_default_graph_input_tensor_shape());
    input_types->push_back(dt);
  }

  for (const string& output : outputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, output, &dt, &shape);
    execute_info->add_graph_output_node_name(output);
    if (!has_shapetype) {
      CHECK(!require_shape_type)
          << "No shape type found for " << output << DumpGraphDef(subgraph_def);
      // Assuming output type is float if no data provided.
      output_types->push_back(DT_FLOAT);
      continue;
    }
    fill_shape_type(dt, shape,
                    execute_info->add_default_graph_output_tensor_shape());
    output_types->push_back(dt);
  }
  return Status::OK();
}
// Creates a "RemoteFusedGraphExecute" node named `node_name` in `graph`.
//
// The node's inputs are wired to the tensors listed in `inputs` (each
// "node[:port]" must name a node already present in `graph`, otherwise this
// CHECK-fails). Its attributes carry the input / output dtypes and the
// serialized RemoteFusedGraphExecuteInfo built from `executor_name`,
// `subgraph_def`, `inputs` and `outputs`. `*created_node` receives the new
// node. `graph` and `created_node` must be non-null.
//
// Returns the first error from building the execute info or from finalizing
// the node.
/* static */ Status
RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
    const string& node_name, const string& executor_name,
    const GraphDef& subgraph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs, const bool require_shape_type,
    Graph* graph, Node** created_node) {
  CHECK_NOTNULL(graph);
  CHECK_NOTNULL(created_node);
  RemoteFusedGraphExecuteInfo execute_info;
  DataTypeVector input_types;
  DataTypeVector output_types;
  // This function returns Status, so a failure here is propagated to the
  // caller rather than aborting the process.
  TF_RETURN_IF_ERROR(
      RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
          executor_name, subgraph_def, inputs, outputs, require_shape_type,
          &execute_info, &input_types, &output_types));

  std::vector<NodeBuilder::NodeOut> node_out_list;
  node_out_list.reserve(inputs.size());
  for (const string& input : inputs) {
    const TensorId tid = ParseTensorName(input);
    Node* node = FindMutableNodeByName(string(tid.first), graph);
    CHECK_NOTNULL(node);
    node_out_list.emplace_back(node, tid.second);
  }

  const string execute_info_str = execute_info.SerializeAsString();
  auto builder =
      NodeBuilder(node_name, "RemoteFusedGraphExecute")
          .Input(node_out_list)
          .Attr("Tinputs", input_types)
          .Attr("Toutputs", output_types)
          .Attr("serialized_remote_fused_graph_execute_info", execute_info_str);
  TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
  return Status::OK();
}
// Adds an "Identity" node named `node_name` to `graph`, fed by output
// `input_node_port` of the node called `input_node_name`, with attribute "T"
// set to `dt`. The input node must exist in `graph` (CHECK-enforced).
// `*created_node` receives the new node. Returns the status of
// NodeBuilder::Finalize().
/* static */ Status RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
    const string& node_name, const string& input_node_name,
    const int input_node_port, const DataType dt, Graph* graph,
    Node** created_node) {
  Node* node = FindMutableNodeByName(input_node_name, graph);
  CHECK_NOTNULL(node);
  const NodeBuilder::NodeOut source(node, input_node_port);
  NodeBuilder identity_builder(node_name, "Identity");
  identity_builder.Input(source).Attr("T", dt);
  const Status finalize_status = identity_builder.Finalize(graph, created_node);
  if (!finalize_status.ok()) {
    return finalize_status;
  }
  return Status::OK();
}
// Groups the nodes named in `node_names` into connected clusters and appends
// one ClusterInfo per cluster to `cluster_infos`.
//
// A ClusterInfo is a tuple of (node names, border inputs, border outputs);
// border entries are "node:port" strings. `graph_def` is imported into a
// temporary Graph that is used for all connectivity queries.
//
// Returns an error only if the graph import fails. Structural violations
// (a requested name missing from the graph, a border input node with more than
// one output, a node mixing inside and outside inputs, a multi-output node
// feeding the sink) are CHECK failures.
/* static */ Status RemoteFusedGraphExecuteUtils::ClusterizeNodes(
const std::unordered_set<string>& node_names, const GraphDef& graph_def,
std::vector<ClusterInfo>* cluster_infos) {
Graph graph(OpRegistry::Global());
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));
// Names not yet assigned to a cluster; one cluster is carved out per pass.
std::unordered_set<string> remaining_nodes = node_names;
while (!remaining_nodes.empty()) {
ClusterInfo ci;
// Determine one cluster nodes
// Breadth-first walk over in- and out-neighbours starting from an arbitrary
// remaining node. Only nodes listed in `node_names` are expanded.
std::unordered_set<const Node*> visited;
std::deque<const Node*> queue;
queue.emplace_back(FindNodeByName(*remaining_nodes.begin(), graph));
while (!queue.empty()) {
const Node* node = queue.front();
CHECK_NOTNULL(node);
queue.pop_front();
const string& node_name = node->name();
if (node_names.count(node_name) > 0) {
std::get<0>(ci).emplace(node_name);
remaining_nodes.erase(node_name);
} else {
// Edge of subgraph. Do nothing.
continue;
}
for (const Node* in : node->in_nodes()) {
if (visited.insert(in).second) {
queue.push_back(in);
}
}
for (const Node* out : node->out_nodes()) {
if (visited.insert(out).second) {
queue.push_back(out);
}
}
}
// Determine one cluster border
// NOTE(review): this loop walks every name in `node_names`, not only the
// nodes collected into this cluster (std::get<0>(ci)), so each ClusterInfo
// receives the borders of all requested nodes. Confirm this is intended when
// more than one cluster is produced.
std::vector<string>& border_inputs = std::get<1>(ci);
std::vector<string>& border_outputs = std::get<2>(ci);
for (const string& node_name : node_names) {
Node* node = FindMutableNodeByName(node_name, &graph);
CHECK_NOTNULL(node);
// Counts in-edges whose source is NOT treated as outside (a requested node
// or the graph's source node).
int input_count = 0;
for (const Edge* in_edge : node->in_edges()) {
const Node* src_node = in_edge->src();
const bool src_is_outside =
node_names.count(src_node->name()) <= 0 && !src_node->IsSource();
if (src_is_outside) {
const string src_name =
strings::StrCat(src_node->name(), ":", in_edge->src_output());
CHECK_EQ(1, src_node->num_outputs())
<< "output count of input border node must be one."
<< src_node->name();
// Border lists are kept free of duplicates.
if (std::find(border_inputs.begin(), border_inputs.end(), src_name) ==
border_inputs.end()) {
border_inputs.emplace_back(src_name);
}
} else {
++input_count;
}
}
// A node must take either all or none of its inputs from outside.
CHECK(input_count == 0 || input_count == node->in_edges().size())
<< "Invalid input_count(" << input_count << ", "
<< node->in_edges().size() << ") " << node_name;
for (const Edge* out_edge : node->out_edges()) {
const Node* dst_node = out_edge->dst();
CHECK_NOTNULL(dst_node);
const bool dst_is_outside = node_names.count(dst_node->name()) <= 0;
const string dst_name =
strings::StrCat(node->name(), ":", out_edge->src_output());
if (dst_is_outside) {
if (dst_node->IsSink()) {
// An edge to the sink is recorded as port 0 of this node.
CHECK_EQ(1, node->num_outputs())
<< "If you want to specify output node as subgraph output node "
<< "the output count of the node must be 1 "
<< "because that node is replaced by identity node.";
const string identity_dst_name =
strings::StrCat(node->name(), ":", 0);
if (std::find(border_outputs.begin(), border_outputs.end(),
identity_dst_name) == border_outputs.end()) {
border_outputs.emplace_back(identity_dst_name);
}
} else {
if (std::find(border_outputs.begin(), border_outputs.end(),
dst_name) == border_outputs.end()) {
border_outputs.emplace_back(dst_name);
}
}
}
}
}
cluster_infos->emplace_back(ci);
VLOG(1) << DumpCluster(ci);
}
return Status::OK();
}
// Extracts the subgraph described by `cluster` from `graph_def` into
// `subgraph_def`.
//
// Nodes that are neither cluster members nor border input nodes are dropped
// (the graph's source and sink nodes are never removed). Every border input
// node is then replaced by a placeholder, using the dtype / shape recorded on
// the original node and falling back to DT_FLOAT / scalar shape when nothing
// is recorded. Finally the nodes of `subgraph_def` are sorted to follow the
// node order of `graph_def`.
//
// Returns the first error from graph import or placeholder replacement.
/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    const ClusterInfo& cluster, const GraphDef& graph_def,
    GraphDef* subgraph_def) {
  const std::unordered_set<string>& node_names = std::get<0>(cluster);
  const std::vector<string>& border_inputs = std::get<1>(cluster);
  const std::unordered_set<string> border_input_names =
      BuildNodeSetFromNodeNamesAndPorts(border_inputs);

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  // Drop everything that belongs neither to the cluster nor to its inputs.
  for (Node* node : graph.nodes()) {
    if (node == nullptr || node->IsSource() || node->IsSink()) {
      continue;
    }
    const bool in_cluster = node_names.count(node->name()) > 0;
    const bool is_border_input = border_input_names.count(node->name()) > 0;
    if (!in_cluster && !is_border_input) {
      graph.RemoveNode(node);
    }
  }
  graph.ToGraphDef(subgraph_def);

  for (const string& subgraph_input : border_inputs) {
    const TensorId tid = ParseTensorName(subgraph_input);
    const string subgraph_input_name(tid.first);
    const int subgraph_input_port = tid.second;
    const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
    CHECK_NOTNULL(node_def);
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    GetOutputTensorShapeType(*node_def, &dt_vec, &shape_vec).IgnoreError();
    // Missing attributes fall back to a float scalar.
    const DataType dt =
        dt_vec.empty() ? DT_FLOAT : dt_vec.at(subgraph_input_port);
    const TensorShape shape =
        shape_vec.empty() ? TensorShape({}) : shape_vec.at(subgraph_input_port);
    TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt,
                                                     shape, subgraph_def));
  }

  // sort subgraph_def to align order in graph_def
  std::unordered_map<string, int> name_to_id_map;
  int next_id = 0;
  for (const NodeDef& original_node : graph_def.node()) {
    name_to_id_map.emplace(original_node.name(), next_id);
    ++next_id;
  }
  auto* subgraph_nodes = subgraph_def->mutable_node();
  std::sort(subgraph_nodes->begin(), subgraph_nodes->end(),
            [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) {
              CHECK(name_to_id_map.count(node0.name()) > 0);
              CHECK(name_to_id_map.count(node1.name()) > 0);
              return name_to_id_map.at(node0.name()) <
                     name_to_id_map.at(node1.name());
            });
  VLOG(1) << DumpGraphDef(*subgraph_def);
  return Status::OK();
}
// Builds a ClusterInfo from explicit border tensors.
//
// Starting at the nodes named in `border_outputs`, in-edges are followed
// backwards. An edge whose source tensor (node and port) appears in
// `border_inputs` marks its source node as a border input node. A node is not
// expanded further if it was first reached through such an edge. Every node
// reached that is not a border input node, the source or the sink is added
// to the cluster's node set. The border lists are copied into `*cluster`
// unchanged.
//
// Returns an error only if importing `graph_def` fails.
/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs, const GraphDef& graph_def,
    ClusterInfo* cluster) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  std::unordered_set<const Node*> visited;
  std::deque<const Node*> queue;
  // Seed the traversal with the nodes that produce the border outputs.
  for (const string& output : border_outputs) {
    const string output_node_name(ParseTensorName(output).first);
    for (const Node* node : graph.nodes()) {
      if (output_node_name != node->name()) {
        continue;
      }
      queue.push_back(node);
      visited.insert(node);
    }
  }

  std::unordered_set<const Node*> border_input_nodes;
  // propagate visit to parent nodes until input nodes
  while (!queue.empty()) {
    const Node* node = queue.front();
    queue.pop_front();
    for (const Edge* edge : node->in_edges()) {
      const Node* src_node = edge->src();
      CHECK_NOTNULL(src_node);
      const int src_port = edge->src_output();
      const bool input_found = std::any_of(
          border_inputs.begin(), border_inputs.end(),
          [src_node, src_port](const string& input) {
            const TensorId tid = ParseTensorName(input);
            return tid.first == src_node->name() && tid.second == src_port;
          });
      if (input_found) {
        border_input_nodes.insert(src_node);
      }
      const bool first_visit = visited.insert(src_node).second;
      if (first_visit && !input_found) {
        queue.push_back(src_node);
      }
    }
  }

  std::unordered_set<string>& cluster_node_names = std::get<0>(*cluster);
  for (const Node* node : visited) {
    if (node == nullptr || node->IsSource() || node->IsSink()) {
      continue;
    }
    if (border_input_nodes.count(node) > 0) {
      continue;
    }
    cluster_node_names.insert(node->name());
  }
  std::get<1>(*cluster) = border_inputs;
  std::get<2>(*cluster) = border_outputs;
  return Status::OK();
}
// Replaces the subgraph described by `cluster` with a single
// RemoteFusedGraphExecute node and writes the resulting graph to
// `output_graph_def`.
//
// input_graph_def:              graph to transform.
// inputs / outputs:             tensor names of the whole graph's inputs and
//                               outputs; used to prune unreachable nodes.
// remote_fused_graph_node_name: name given to the fused node.
// cluster:                      (nodes, border inputs, border outputs) of the
//                               subgraph to fuse.
// remote_graph_executor_name:   executor name stored in the fused node; must
//                               not be empty (CHECK-enforced).
// require_shape_type:           forwarded to the fused node builder.
//
// Steps: build the subgraph, add the fused node, re-route every edge that
// consumed a border output to the matching fused node output, replace graph
// output nodes that are also subgraph outputs by identity nodes fed from the
// fused node, then drop everything not reachable between `inputs` and
// `outputs`. Returns the first error encountered.
/* static */ Status RemoteFusedGraphExecuteUtils::FuseCluster(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name, const ClusterInfo& cluster,
    const string& remote_graph_executor_name, const bool require_shape_type,
    GraphDef* output_graph_def) {
  LOG(INFO) << "Transforming quantized stripped model to a remote fused "
               "graph execute op by fusing a specified subgraph...";
  CHECK(!remote_graph_executor_name.empty());

  const std::vector<string>& border_inputs = std::get<1>(cluster);
  const std::vector<string>& border_outputs = std::get<2>(cluster);

  GraphDef subgraph_def;
  TF_RETURN_IF_ERROR(
      BuildClusterSubgraphDef(cluster, input_graph_def, &subgraph_def));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(
      ImportGraphDef({}, input_graph_def, &graph, &shape_refiner));

  Node* fused_node = nullptr;
  TF_RETURN_IF_ERROR(BuildRemoteFusedGraphExecuteOpNode(
      remote_fused_graph_node_name, remote_graph_executor_name, subgraph_def,
      border_inputs, border_outputs, require_shape_type, &graph, &fused_node));
  CHECK_NOTNULL(fused_node);

  // Re-route every data edge that consumed a border output so that it is fed
  // by the corresponding output of the fused node instead.
  for (const Node* node : graph.nodes()) {
    for (int i = 0; i < node->num_inputs(); ++i) {
      const Edge* edge = nullptr;
      TF_RETURN_IF_ERROR(node->input_edge(i, &edge));
      for (size_t j = 0; j < border_outputs.size(); ++j) {
        const TensorId tid = ParseTensorName(border_outputs.at(j));
        const string output_name(tid.first);
        Node* src_node = edge->src();
        if (src_node == nullptr || src_node->name() != output_name ||
            edge->src_output() != tid.second) {
          continue;
        }
        // Source node is replaced by new fused node.
        Node* dst_node = edge->dst();
        const int dst_input = edge->dst_input();
        LOG(INFO) << "Removing existing edge to " << edge->dst()->name()
                  << " from " << edge->src()->name();
        graph.RemoveEdge(edge);
        graph.AddEdge(fused_node, static_cast<int>(j), dst_node, dst_input);
        // `edge` has been handed back to the graph and must not be inspected
        // again; this input is now wired to the fused node.
        break;
      }
    }
  }

  // Replace output nodes by identity nodes which forward outputs from
  // RemoteFusedGraphExecuteOpNode
  for (const string& output : outputs) {
    const TensorId output_tid = ParseTensorName(output);
    const string output_name(output_tid.first);
    for (size_t i = 0; i < border_outputs.size(); ++i) {
      const TensorId subgraph_output_tid =
          ParseTensorName(border_outputs.at(i));
      const string subgraph_output_name(subgraph_output_tid.first);
      if (output_name != subgraph_output_name) {
        continue;
      }
      LOG(INFO) << "As graph output and subgraph output are same, "
                << "the graph output node is replaced by identity node";
      Node* original_output_node = FindMutableNodeByName(output_name, &graph);
      CHECK_NOTNULL(original_output_node);
      CHECK_EQ(1, original_output_node->num_outputs())
          << "Num outputs should be 1 for " << output << ".";
      graph.RemoveNode(original_output_node);
      // The identity node must carry the dtype of the fused node output it
      // forwards; a fixed DT_FLOAT only fits float outputs.
      const int fused_output_index = static_cast<int>(i);
      const DataType output_dt = fused_node->output_type(fused_output_index);
      Node* new_node = nullptr;
      TF_RETURN_IF_ERROR(BuildIdentityOpNode(
          output_name, remote_fused_graph_node_name, fused_output_index,
          output_dt, &graph, &new_node));
      CHECK_NOTNULL(new_node);
    }
  }

  GraphDef result_graph_def;
  graph.ToGraphDef(&result_graph_def);

  ClusterInfo graph_cluster;
  TF_RETURN_IF_ERROR(
      BuildClusterByBorder(inputs, outputs, result_graph_def, &graph_cluster));

  // Remove unvisited nodes
  TF_RETURN_IF_ERROR(BuildClusterSubgraphDef(graph_cluster, result_graph_def,
                                             output_graph_def));
  return Status::OK();
}
// Groups `subgraph_nodes` into clusters (via ClusterizeNodes) and fuses each
// cluster into a remote fused graph execute node named
// "<remote_fused_graph_node_name_prefix>/<cluster index>".
//
// When clustering yields no cluster at all, the input graph is returned
// unchanged so that `output_graph_def` is always populated on success.
// Returns the first error reported by clustering or fusing.
/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name_prefix,
    const std::unordered_set<string>& subgraph_nodes,
    const string& remote_fused_graph_executor_name,
    const bool require_shape_type, GraphDef* output_graph_def) {
  std::vector<ClusterInfo> ci_vec;
  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
      subgraph_nodes, input_graph_def, &ci_vec));

  if (ci_vec.empty()) {
    // Nothing to fuse: hand back the original graph instead of reporting
    // success while leaving *output_graph_def untouched.
    *output_graph_def = input_graph_def;
    return Status::OK();
  }

  // NOTE(review): every iteration fuses from the unmodified input_graph_def
  // into the same output_graph_def, so with several clusters it looks like the
  // fusions are not chained -- confirm the intended multi-cluster behavior.
  for (size_t i = 0; i < ci_vec.size(); ++i) {
    const string remote_fused_graph_node_name =
        strings::StrCat(remote_fused_graph_node_name_prefix, "/", i);
    TF_RETURN_IF_ERROR(FuseCluster(input_graph_def, inputs, outputs,
                                   remote_fused_graph_node_name, ci_vec.at(i),
                                   remote_fused_graph_executor_name,
                                   require_shape_type, output_graph_def));
  }
  return Status::OK();
}
// Fuses the subgraph delimited by `border_inputs` / `border_outputs` into a
// remote fused graph execute node named `remote_fused_graph_node_name`.
//
// The cluster is derived from the borders with BuildClusterByBorder and then
// handed to FuseCluster; an error from either step is returned as is.
/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name,
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs,
    const string& remote_graph_executor_name, const bool require_shape_type,
    GraphDef* output_graph_def) {
  // Step 1: turn the border description into a cluster.
  ClusterInfo border_cluster;
  TF_RETURN_IF_ERROR(BuildClusterByBorder(border_inputs, border_outputs,
                                          input_graph_def, &border_cluster));

  // Step 2: fuse that cluster.
  const Status fuse_status =
      FuseCluster(input_graph_def, inputs, outputs,
                  remote_fused_graph_node_name, border_cluster,
                  remote_graph_executor_name, require_shape_type,
                  output_graph_def);
  return fuse_status;
}
// Fuses every node whose op type is listed in `fused_op_types`.
//
// The op types are first resolved to node names with BuildNodeMapFromOpTypes;
// the actual fusion is delegated to FuseRemoteGraphByNodeNames, whose status
// is returned.
/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name_prefix,
    const std::unordered_set<string>& fused_op_types,
    const string& remote_fused_graph_executor_name,
    const bool require_shape_type, GraphDef* output_graph_def) {
  const std::unordered_set<string> nodes_to_fuse =
      BuildNodeMapFromOpTypes(input_graph_def, fused_op_types);

  const Status fuse_status = FuseRemoteGraphByNodeNames(
      input_graph_def, inputs, outputs, remote_fused_graph_node_name_prefix,
      nodes_to_fuse, remote_fused_graph_executor_name, require_shape_type,
      output_graph_def);
  return fuse_status;
}
// Lets the executor registered under `executor_name` fuse the graph.
//
// Returns InvalidArgument when no build function is registered for the name,
// and propagates any error from building the executor. If the built executor
// reports that it is not enabled, `output_graph_def` receives an unmodified
// copy of `input_graph_def`.
/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs, const string& executor_name,
    GraphDef* output_graph_def) {
  const ExecutorBuildFunc* const builder = GetExecutorBuildFunc(executor_name);
  if (builder == nullptr) {
    return errors::InvalidArgument("Unknown executor name: " + executor_name);
  }

  std::unique_ptr<IRemoteFusedGraphExecutor> executor;
  TF_RETURN_IF_ERROR((*builder)(&executor));
  CHECK_NOTNULL(executor.get());

  if (executor->IsEnabled()) {
    return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
                                     output_graph_def);
  }

  // A disabled executor leaves the graph exactly as it came in.
  *output_graph_def = input_graph_def;
  return Status::OK();
}
// Annotates every node of `graph_def` with an ATTR_NODE_TYPE attribute that
// records the role(s) the node plays in a later fusion.
//
// For each node the roles are appended in this order: graph input, graph
// output, fused node (listed in `fused_node_names`), fused node (op type
// listed in `fused_op_types`), border input, border output. Each role is
// encoded with BuildNodeTypeAttr and AppendDeliminator is called before each
// one. A node without any role is tagged UNUSED.
//
// `graph_def` must not be null (CHECK-enforced). Always returns OK.
/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
    const std::vector<string>& inputs, const std::vector<string>& outputs,
    const std::unordered_set<string>& fused_node_names,
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs,
    const std::unordered_set<string>& fused_op_types,
    const string& remote_fused_graph_node_name,
    const string& remote_graph_executor_name, GraphDef* graph_def) {
  CHECK_NOTNULL(graph_def);

  const std::unordered_set<string> fused_nodes_filtered_by_op_types =
      BuildNodeMapFromOpTypes(*graph_def, fused_op_types);

  for (NodeDef& node_def : *graph_def->mutable_node()) {
    string attr_str;
    TensorId tid;

    for (size_t i = 0; i < inputs.size(); ++i) {
      if (IsSameNodeName(node_def, inputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(GRAPH_INPUT, tid.second, i,
                                      remote_graph_executor_name,
                                      remote_fused_graph_node_name);
      }
    }
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (IsSameNodeName(node_def, outputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(GRAPH_OUTPUT, tid.second, i);
      }
    }

    // Both containers are hash sets, so a membership lookup gives the same
    // answer as scanning every element for an equal name.
    if (fused_node_names.count(node_def.name()) > 0) {
      AppendDeliminator(&attr_str);
      attr_str += BuildNodeTypeAttr(FUSED_NODE);
    }
    if (fused_nodes_filtered_by_op_types.count(node_def.name()) > 0) {
      AppendDeliminator(&attr_str);
      attr_str += BuildNodeTypeAttr(FUSED_NODE);
    }

    for (size_t i = 0; i < border_inputs.size(); ++i) {
      if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(BORDER_INPUT, tid.second, i);
      }
    }
    for (size_t i = 0; i < border_outputs.size(); ++i) {
      if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(BORDER_OUTPUT, tid.second, i);
      }
    }

    if (attr_str.empty()) {
      attr_str += BuildNodeTypeAttr(UNUSED);
    }
    AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def);
  }
  return Status::OK();
}
// Fuses a graph using the ATTR_NODE_TYPE annotations stored on its nodes
// (the attribute written by PlaceRemoteGraphArguments).
//
// The attribute of every node is split on ":" into entries and each entry on
// "," into fields; field 0 is the integer value of a RemoteFusedGraphNodeType.
// From these entries the graph inputs/outputs, the fused node names, the
// border inputs/outputs, the executor name and the fused node name are
// recovered, and the matching Fuse* routine is invoked.
//
// `output_graph_def` receives an unmodified copy of `input_graph_def` when
//   - the executor named by a GRAPH_INPUT entry is not registered,
//   - `input_tensors` is non-empty and the input check below does not match, or
//   - there are neither fused node names nor border inputs/outputs.
//
// Returns the error from GetNodeAttr if it fails for ATTR_NODE_TYPE on any
// node, InvalidArgument for an empty attribute, or any error from the
// selected Fuse* routine. Malformed numeric fields and unexpected field
// counts are CHECK failures, not returned errors.
/* static */ Status
RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
const GraphDef& input_graph_def,
const std::vector<std::pair<string, Tensor>>& input_tensors,
GraphDef* output_graph_def) {
// index -> "node_name:port", collected per role while scanning the nodes.
std::unordered_map<int, string> input_map;
std::unordered_map<int, string> output_map;
std::unordered_set<string> fused_node_names;
std::unordered_map<int, string> border_input_map;
std::unordered_map<int, string> border_output_map;
string remote_graph_executor_name;
string remote_fused_graph_node_name;
for (const NodeDef& node_def : input_graph_def.node()) {
string attr_str;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, ATTR_NODE_TYPE, &attr_str));
// Outer split: one entry per role; inner split: the fields of that role.
std::vector<std::vector<string>> attr_strs;
for (const string& str : str_util::Split(attr_str, ":")) {
attr_strs.emplace_back(str_util::Split(str, ","));
}
if (attr_strs.empty()) {
return errors::InvalidArgument("Remote graph node type not found.");
}
for (const std::vector<string>& attr : attr_strs) {
if (attr.empty()) {
return errors::InvalidArgument("Empty remote graph node type attr.");
}
int node_type_int;
CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0);
const RemoteFusedGraphNodeType node_type =
static_cast<RemoteFusedGraphNodeType>(node_type_int);
const string& name = node_def.name();
int port;
int index;
switch (node_type) {
case GRAPH_INPUT:
// Fields: type, port, index, executor name, fused node name.
VLOG(2) << "Graph input: " << name;
CHECK_EQ(5, attr.size());
CHECK(strings::safe_strto32(attr.at(1), &port));
CHECK(strings::safe_strto32(attr.at(2), &index));
CHECK(!attr.at(3).empty());
remote_graph_executor_name = attr.at(3);
CHECK(!attr.at(4).empty());
remote_fused_graph_node_name = attr.at(4);
input_map.emplace(index, strings::StrCat(name, ":", port));
// Bail out of the whole fusion as soon as an unknown executor is seen;
// nodes after this one are not parsed.
if (GetExecutorBuildFunc(remote_graph_executor_name) == nullptr) {
LOG(INFO) << "Executor for " << remote_graph_executor_name
<< " not registered. Do not fuse.";
*output_graph_def = input_graph_def;
return Status::OK();
}
break;
case GRAPH_OUTPUT:
// Fields: type, port, index.
VLOG(2) << "Graph output: " << name;
CHECK_EQ(3, attr.size());
CHECK(strings::safe_strto32(attr.at(1), &port));
CHECK(strings::safe_strto32(attr.at(2), &index));
output_map.emplace(index, strings::StrCat(name, ":", port));
break;
case FUSED_NODE:
// Fields: type only.
VLOG(2) << "Fused node: " << name;
CHECK_EQ(1, attr.size());
fused_node_names.emplace(name);
break;
case BORDER_INPUT:
// Fields: type, port, index.
VLOG(2) << "Border input: " << name;
CHECK_EQ(3, attr.size());
CHECK(strings::safe_strto32(attr.at(1), &port));
CHECK(strings::safe_strto32(attr.at(2), &index));
border_input_map.emplace(index, strings::StrCat(name, ":", port));
break;
case BORDER_OUTPUT:
// Fields: type, port, index.
VLOG(2) << "Border output: " << name;
CHECK_EQ(3, attr.size());
CHECK(strings::safe_strto32(attr.at(1), &port));
CHECK(strings::safe_strto32(attr.at(2), &index));
border_output_map.emplace(index, strings::StrCat(name, ":", port));
break;
case UNUSED:
// do nothing
break;
default:
// unsupported value
LOG(FATAL);
}
}
}
bool require_shape_type = false;
std::vector<string> inputs;
std::vector<string> outputs;
std::vector<string> border_inputs;
std::vector<string> border_outputs;
// Presumably orders the "name:port" strings by their recorded index --
// confirm against ConvertMapToVector.
ConvertMapToVector(input_map, &inputs);
ConvertMapToVector(output_map, &outputs);
ConvertMapToVector(border_input_map, &border_inputs);
ConvertMapToVector(border_output_map, &border_outputs);
// When the caller supplies input tensors, compare them with the recorded
// graph inputs before fusing.
if (!input_tensors.empty()) {
bool input_match = false;
if (inputs.size() == input_tensors.size()) {
for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
if (!ContainsSameTensorId(input_tensor.first, inputs)) {
break;
}
DataType data_type;
TensorShape shape;
if (GetOutputTensorShapeType(input_graph_def, input_tensor.first,
&data_type, &shape)) {
if (data_type == input_tensor.second.dtype() &&
shape == input_tensor.second.shape()) {
VLOG(2) << "Input matched!";
// Shape type matched.
input_match = true;
require_shape_type = true;
}
} else {
// Shape type not required.
input_match = true;
}
}
}
// NOTE(review): input_match is only ever set to true and never reset, and
// the `break` above keeps a `true` set by an earlier tensor, so one matching
// tensor is enough to proceed -- confirm whether all inputs were meant to
// match.
if (!input_match) {
// Input mismatch. Just copy original graph
*output_graph_def = input_graph_def;
return Status::OK();
}
}
// Explicit fused node names take precedence over border-based fusion.
if (!fused_node_names.empty()) {
TF_RETURN_IF_ERROR(FuseRemoteGraphByNodeNames(
input_graph_def, inputs, outputs, remote_fused_graph_node_name,
fused_node_names, remote_graph_executor_name, require_shape_type,
output_graph_def));
} else if (!border_inputs.empty() || !border_outputs.empty()) {
TF_RETURN_IF_ERROR(FuseRemoteGraphByBorder(
input_graph_def, inputs, outputs, remote_fused_graph_node_name,
border_inputs, border_outputs, remote_graph_executor_name,
require_shape_type, output_graph_def));
} else {
// Nothing was marked for fusion.
*output_graph_def = input_graph_def;
}
return Status::OK();
}
/* static */ bool RemoteFusedGraphExecuteUtils::IsFuseReady(
const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& input_tensors) {
for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
const NodeDef* node_def = FindNodeDefByName(input_tensor.first, graph_def);
if (node_def == nullptr) {
return false;
}
string attr;
const Status status = GetNodeAttr(*node_def, ATTR_NODE_TYPE, &attr);
if (!status.ok() || attr.empty()) {
return false;
}
}
return true;
}
// Copies `src_size` raw bytes from `src_ptr` into the flat buffer of `tensor`.
//
// CHECK-fails when the tensor reports fewer total bytes than `src_size` or
// when the destination pointer is null; a dtype without a case below is a
// fatal error. Returns OK once the bytes are copied.
//
// NOTE(review): for DT_STRING the bytes are copied over the tstring elements
// themselves -- confirm that callers never pass string tensors here.
/* static */ Status RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
    const void* src_ptr, const int src_size, Tensor* tensor) {
  CHECK(tensor->TotalBytes() >= src_size)
      << tensor->TotalBytes() << ", " << src_size;

  // Resolve the typed flat view for the tensor's dtype to obtain the
  // destination address.
  void* dst_ptr = nullptr;
  switch (tensor->dtype()) {
    // Floating point.
    case DT_FLOAT:
      dst_ptr = tensor->flat<float>().data();
      break;
    case DT_DOUBLE:
      dst_ptr = tensor->flat<double>().data();
      break;
    case DT_BFLOAT16:
      dst_ptr = tensor->flat<bfloat16>().data();
      break;

    // Signed integers.
    case DT_INT8:
      dst_ptr = tensor->flat<int8>().data();
      break;
    case DT_INT16:
      dst_ptr = tensor->flat<int16>().data();
      break;
    case DT_INT32:
      dst_ptr = tensor->flat<int32>().data();
      break;
    case DT_INT64:
      dst_ptr = tensor->flat<int64>().data();
      break;

    // Unsigned integers.
    case DT_UINT8:
      dst_ptr = tensor->flat<uint8>().data();
      break;
    case DT_UINT16:
      dst_ptr = tensor->flat<uint16>().data();
      break;

    // Quantized types.
    case DT_QINT8:
      dst_ptr = tensor->flat<qint8>().data();
      break;
    case DT_QUINT8:
      dst_ptr = tensor->flat<quint8>().data();
      break;
    case DT_QINT16:
      dst_ptr = tensor->flat<qint16>().data();
      break;
    case DT_QUINT16:
      dst_ptr = tensor->flat<quint16>().data();
      break;
    case DT_QINT32:
      dst_ptr = tensor->flat<qint32>().data();
      break;

    // Everything else that has a case.
    case DT_BOOL:
      dst_ptr = tensor->flat<bool>().data();
      break;
    case DT_STRING:
      dst_ptr = tensor->flat<tstring>().data();
      break;

    default:
      LOG(FATAL) << "type " << tensor->dtype() << " is not supported.";
      break;
  }

  CHECK_NOTNULL(dst_ptr);
  std::memcpy(dst_ptr, src_ptr, src_size);
  return Status::OK();
}
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
const GraphDef& graph_def, const std::unordered_set<string>& op_types) {
std::unordered_set<string> retval;
for (const NodeDef& node_def : graph_def.node()) {
if (op_types.count(node_def.op()) > 0) {
retval.emplace(node_def.name());
}
}
return retval;
}
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
const GraphDef& graph_def,
const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
std::unordered_set<string> retval;
for (const NodeDef& node_def : graph_def.node()) {
std::vector<DataType> dt_vec;
std::vector<TensorShape> shape_vec;
const Status status =
GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
if (!status.ok()) {
shape_vec.clear();
}
if (ops_definitions.GetOpIdFor(
node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
retval.emplace(node_def.name());
}
}
return retval;
}
// Turns the node named by `input` into a Placeholder with the given `type`
// and `shape`.
//
// `input` must refer to output port 0 (CHECK-enforced). The first node whose
// name matches is handled: an existing Placeholder is left untouched, any
// other node is overwritten with a freshly built Placeholder of the same
// name. Returns InvalidArgument when no node has that name.
/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
    const string& input, const DataType type, const TensorShape& shape,
    GraphDef* graph_def) {
  const TensorId tid = ParseTensorName(input);
  CHECK_EQ(0, tid.second);
  const string node_name(tid.first);

  for (int idx = 0; idx < graph_def->node_size(); ++idx) {
    NodeDef* const target = graph_def->mutable_node(idx);
    if (target->name() != node_name) {
      continue;
    }

    if (target->op() != "Placeholder") {
      NodeDef replacement;
      replacement.set_op("Placeholder");
      replacement.set_name(node_name);
      AddNodeAttr("dtype", type, &replacement);
      AddNodeAttr("shape", shape, &replacement);
      // TODO(satok): Remove once we merge attributes
      AddOutputTensorShapeType({type}, {shape}, &replacement);
      target->Clear();
      *target = replacement;
    }
    // Either the node already was a Placeholder or it has just become one.
    return Status::OK();
  }

  return errors::InvalidArgument(
      strings::StrCat(node_name, " not found for replacement."));
}
// Encodes a node-type entry with five comma-separated fields:
// "<node_type as int>,<port>,<index>,<executor_name>,<node_name>".
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type, const int port, const int index,
    const string& executor_name, const string& node_name) {
  const int type_code = static_cast<int>(node_type);
  string encoded = strings::StrCat(type_code, ",", port, ",", index);
  strings::StrAppend(&encoded, ",", executor_name, ",", node_name);
  return encoded;
}
// Encodes a node-type entry with three comma-separated fields:
// "<node_type as int>,<port>,<index>".
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type, const int port, const int index) {
  const int type_code = static_cast<int>(node_type);
  string encoded = strings::StrCat(type_code);
  strings::StrAppend(&encoded, ",", port, ",", index);
  return encoded;
}
// Encodes a node-type entry that consists of the integer value of
// `node_type` only.
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type) {
  const int type_code = static_cast<int>(node_type);
  return strings::StrCat(type_code);
}
} // namespace tensorflow