| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"

#include <algorithm>
#include <iterator>
#include <vector>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/util/ptr_util.h"
| |
| namespace tensorflow { |
| namespace grappler { |
| namespace function_utils { |
| namespace { |
| |
// Returns, in ascending order, the zero-based positions of every element of
// `collection` for which `predicate(element)` is true. Returns an empty vector
// when nothing matches (or `collection` is empty).
template <typename Predicate, typename Collection>
std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
                                                const Collection& collection) {
  std::vector<int> indices;
  // Signed counter: the result type is std::vector<int>, so counting with an
  // unsigned value would force an implicit unsigned->int conversion on every
  // push_back.
  int idx = 0;
  for (const auto& element : collection) {
    if (predicate(element)) {
      indices.push_back(idx);
    }
    ++idx;
  }
  return indices;
}
| |
| } // namespace |
| |
| FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name, |
| const string& output, int position) |
| : node_name(node_name), node_output(output), position(position) { |
| full_str = strings::StrCat(node_name, ":", node_output, ":", position); |
| } |
| |
| FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) { |
| // Parses node_name:node_output:position string into its components. |
| full_str = input; |
| StringPiece capture; |
| StringPiece remaining; |
| |
| // Parse "node_name" |
| if (strings::Scanner(input) |
| .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) |
| .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) |
| .GetResult(&remaining, &capture)) { |
| node_name = string(capture.data(), capture.size()); |
| } |
| |
| // Parse "node_output" if it exists |
| if (strings::Scanner(remaining) |
| .OneLiteral(":") |
| .RestartCapture() |
| .One(strings::Scanner::LETTER) |
| .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE) |
| .GetResult(&remaining, &capture)) { |
| node_output = string(capture.data(), capture.size()); |
| } |
| |
| // Parse "position" if it exists |
| if (strings::Scanner(remaining) |
| .OneLiteral(":") |
| .RestartCapture() |
| .Many(strings::Scanner::DIGIT) |
| .GetResult(nullptr, &capture)) { |
| CHECK(strings::safe_strto32(capture, &position)); |
| } |
| } |
| |
// TODO(rachelim): Create a utility class similar to MutableGraphView for
// FunctionDefs, and use that to manipulate functions. It would be more
// performant to keep mappings of nodes->inputs/outputs, so that we don't
// have to search over all nodes each time.
// Note that we're not using GrapplerFunctionItem because it doesn't cover
// some of our desired uses (e.g. changing the outputs of a function), and the
// FunctionDef -> GraphDef conversion isn't really necessary in this case.
| void ReplaceReferences(const string& from, const string& to, |
| FunctionDef* func) { |
| for (NodeDef& n : *func->mutable_node_def()) { |
| std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from, |
| to); |
| } |
| |
| for (auto& p : *func->mutable_ret()) { |
| if (p.second == from) { |
| p.second = to; |
| } |
| } |
| } |
| |
| void AddFunctionOutputWithUniqueName(StringPiece prefix, |
| StringPiece output_tensor_name, |
| FunctionDef* function, DataType dt) { |
| string name = string(prefix); |
| int id = function->signature().output_arg_size(); |
| while (ContainsFunctionOutputWithName(name, *function)) { |
| name = strings::StrCat(prefix, "/_", id); |
| ++id; |
| } |
| auto* output = function->mutable_signature()->mutable_output_arg()->Add(); |
| output->set_name(name); |
| output->set_type(dt); |
| |
| (*function->mutable_ret())[name] = string(output_tensor_name); |
| } |
| |
| NodeDef* AddNode(StringPiece name, StringPiece op, |
| const std::vector<string>& inputs, |
| const std::vector<std::pair<string, AttrValue>>& attributes, |
| FunctionDef* fd) { |
| NodeDef* node = fd->add_node_def(); |
| if (!name.empty()) { |
| node->set_name(string(name)); |
| } else { |
| SetUniqueFunctionNodeName(op, fd, node); |
| } |
| node->set_op(string(op)); |
| for (const string& input : inputs) { |
| node->add_input(input); |
| } |
| for (auto attr : attributes) { |
| (*node->mutable_attr())[attr.first] = attr.second; |
| } |
| return node; |
| } |
| |
| bool ContainsFunctionNodeWithName(StringPiece name, |
| const FunctionDef& function) { |
| return FindFunctionNodeWithName(name, function) != -1; |
| } |
| |
| bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { |
| return FindFunctionNodeWithOp(op, function) != -1; |
| } |
| |
| bool ContainsFunctionOutputWithName(StringPiece name, |
| const FunctionDef& function) { |
| return FindFunctionOutputWithName(name, function) != -1; |
| } |
| |
| int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) { |
| std::vector<int> indices = GetElementIndicesWithPredicate( |
| [&name](const OpDef_ArgDef& arg) { return arg.name() == name; }, |
| function.signature().input_arg()); |
| return indices.empty() ? -1 : indices.front(); |
| } |
| |
| int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) { |
| std::vector<int> indices = GetElementIndicesWithPredicate( |
| [&name](const OpDef_ArgDef& arg) { return arg.name() == name; }, |
| function.signature().output_arg()); |
| return indices.empty() ? -1 : indices.front(); |
| } |
| |
| int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) { |
| std::vector<int> indices = GetElementIndicesWithPredicate( |
| [&name](const NodeDef& node) { return node.name() == name; }, |
| function.node_def()); |
| return indices.empty() ? -1 : indices.front(); |
| } |
| |
| int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { |
| std::vector<int> indices = GetElementIndicesWithPredicate( |
| [&op](const NodeDef& node) { return node.op() == op; }, |
| function.node_def()); |
| |
| return indices.empty() ? -1 : indices.front(); |
| } |
| |
| void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, |
| NodeDef* node) { |
| string name = string(prefix); |
| int id = function->node_def_size(); |
| while (ContainsFunctionNodeWithName(name, *function)) { |
| name = strings::StrCat(prefix, "/_", id); |
| ++id; |
| } |
| node->set_name(std::move(name)); |
| } |
| |
| } // end namespace function_utils |
| } // end namespace grappler |
| } // end namespace tensorflow |