| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" |
| |
| #include <cctype> |
| #include <utility> |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/container/inlined_vector.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/types/optional.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/Casting.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/Function.h" // from @llvm-project |
| #include "mlir/IR/Identifier.h" // from @llvm-project |
| #include "mlir/IR/Location.h" // from @llvm-project |
| #include "mlir/IR/Module.h" // from @llvm-project |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "mlir/IR/Types.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Pass/PassManager.h" // from @llvm-project |
| #include "mlir/Support/DebugStringHelper.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" |
| #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/core/framework/graph.pb.h" |
| #include "tensorflow/core/framework/graph_to_functiondef.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/framework/versions.pb.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/graph/tensor_id.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/status.h" |
| |
| namespace tensorflow { |
| using llvm::dyn_cast; |
| using llvm::isa; |
| using mlir::BlockArgument; |
| using mlir::Dialect; |
| using mlir::Operation; |
| using mlir::Value; |
| using stream_executor::port::StatusOr; |
| |
| namespace { |
| |
| constexpr char kInvalidExecutorGraphMsg[] = |
| "Functions must be of a single Graph with single op Islands: "; |
| |
| constexpr char kDeviceAttr[] = "tf.device"; |
| constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; |
| |
| // Returns true if `c` is a legal character in a TensorFlow node name at the |
| // given position; `first_char` indicates whether `c` is the first character |
| // of the name. |
| bool IsLegalChar(char c, bool first_char) { |
| if (isalpha(static_cast<unsigned char>(c))) return true; |
| if (isdigit(static_cast<unsigned char>(c))) return true; |
| if (c == '.') return true; |
| if (c == '_') return true; |
| |
| // First character of a node name can only be a letter, digit, dot or |
| // underscore. |
| if (first_char) return false; |
| |
| if (c == '/') return true; |
| if (c == '-') return true; |
| |
| return false; |
| } |
| |
| // Converts characters in `name` that are considered illegal in a TensorFlow |
| // node name to '.'. |
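| // For example (illustrative inputs): |
| //   LegalizeNodeName("foo:bar") returns "foo.bar" (':' is illegal). |
| //   LegalizeNodeName("/foo") returns ".foo" ('/' is illegal only as the |
| //   first character of a name). |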
| std::string LegalizeNodeName(llvm::StringRef name) { |
| assert(!name.empty() && "expected non-empty name"); |
| |
| std::string legalized_name; |
| bool first = true; |
| for (auto c : name) { |
| if (IsLegalChar(c, first)) { |
| legalized_name += c; |
| } else { |
| legalized_name += '.'; |
| } |
| first = false; |
| } |
| |
| return legalized_name; |
| } |
| |
| // OpOrArgLocNameMapper that legalizes the returned name. |
| class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper { |
| private: |
| std::string GetName(OpOrVal op_or_val) override { |
| return LegalizeNodeName(OpOrArgLocNameMapper::GetName(op_or_val)); |
| } |
| }; |
| |
| // Checks that each function in the module consists of a single block whose |
| // only op is a tf_executor.graph, and that each tf_executor.island in that |
| // graph wraps a single op. |
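| // For example (illustrative), a function of the expected form: |
| // |
| //   func @main() { |
| //     tf_executor.graph { |
| //       %control = tf_executor.island { |
| //         "tf.SomeOp"() : () -> () |
| //         tf_executor.yield |
| //       } |
| //       tf_executor.fetch |
| //     } |
| //     return |
| //   } |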
| Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) { |
| Status status = Status::OK(); |
| module.walk([&](mlir::FuncOp function) { |
| if (!llvm::hasSingleElement(function)) { |
| status = errors::FailedPrecondition( |
| kInvalidExecutorGraphMsg, |
| "only single block functions are supported."); |
| return mlir::WalkResult::interrupt(); |
| } |
| |
| auto block = function.front().without_terminator(); |
| auto graph = llvm::dyn_cast<mlir::tf_executor::GraphOp>(block.begin()); |
| if (!graph) { |
| status = errors::FailedPrecondition( |
| kInvalidExecutorGraphMsg, |
| "first op in function is not a tf_executor.graph."); |
| return mlir::WalkResult::interrupt(); |
| } |
| |
| if (!llvm::hasSingleElement(block)) { |
| status = errors::FailedPrecondition( |
| kInvalidExecutorGraphMsg, |
| "function does not only contain a single tf_executor.graph."); |
| return mlir::WalkResult::interrupt(); |
| } |
| |
| for (Operation& op : graph.GetBody()) { |
| auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op); |
| if (!island) continue; |
| |
| if (!island.WrapsSingleOp()) { |
| status = errors::FailedPrecondition( |
| kInvalidExecutorGraphMsg, |
| "tf_executor.island must perfectly wrap a single op."); |
| return mlir::WalkResult::interrupt(); |
| } |
| } |
| |
| return mlir::WalkResult::advance(); |
| }); |
| |
| return status; |
| } |
| |
| // Returns the first inner op if `op` is a tf_executor.island; otherwise, |
| // returns `op` itself. |
| Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) { |
| auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op); |
| if (island) return &island.GetBody().front(); |
| return op; |
| } |
| |
| // Stateful helper class to export a function into a Graph. |
| class Exporter { |
| public: |
| // Converts the given module to a Graph. The module should contain exactly |
| // one entry function, identified by the name "main". The entry function is |
| // converted to the base of the graph, and the remaining functions are |
| // converted to library functions of that graph. |
| static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs, |
| std::unique_ptr<Graph>* graph, |
| FunctionLibraryDefinition* flib_def, |
| absl::flat_hash_set<Node*>* control_ret_nodes); |
| |
| // Converts a given FuncOp to a FunctionDef and adds it to the function |
| // definition library. |
| static Status ConvertLibFunction(const GraphExportConfig& configs, |
| const Dialect* tf_dialect, |
| mlir::FuncOp function, |
| FunctionDefLibrary* flib); |
| // Converts the given FuncOp to a Graph. The arguments and returns of the |
| // function are added to the graph with the special op names kArgOp and |
| // kRetOp. Later on, this graph can be converted to a function definition |
| // and added to another graph. |
| static StatusOr<std::unique_ptr<Graph>> Convert( |
| const GraphExportConfig& configs, const Dialect* tf_dialect, |
| mlir::FuncOp function, FunctionDefLibrary* flib, |
| absl::flat_hash_set<Node*>* control_ret_nodes); |
| |
| private: |
| explicit Exporter(Graph* graph, const Dialect* tf_dialect) |
| : graph_(graph), tf_dialect_(tf_dialect) {} |
| |
| Status AddArgumentNode(BlockArgument arg, unsigned index, |
| llvm::StringRef name); |
| Status AddFetchNode(mlir::FuncOp function, mlir::tf_executor::FetchOp fetch, |
| llvm::ArrayRef<llvm::StringRef> names); |
| Status AddInstructionNode(Operation* inst); |
| Status AddEdge(Operation* inst); |
| |
| StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument arg, |
| unsigned index, |
| llvm::StringRef name); |
| StatusOr<std::unique_ptr<NodeDef>> GetReturnNode(mlir::FuncOp function, |
| Value operand, |
| unsigned index, |
| llvm::StringRef name); |
| Status GetControlRetNodes(mlir::tf_executor::FetchOp fetch, |
| absl::flat_hash_set<Node*>* control_ret_nodes); |
| // Adds an edge from `src` to `dst_node`. If it is not a control edge, |
| // `dst_index` identifies the operand of `dst_node` that the edge attaches to. |
| Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index); |
| |
| Graph* graph_; |
| LegalizedOpOrValLocNameMapper op_to_name_; |
| absl::flat_hash_map<Operation*, Node*> nodes_; |
| llvm::DenseMap<BlockArgument, Node*> args_; |
| // A single return operation can return multiple results, and each of them |
| // is converted to one node in the graph. |
| typedef absl::InlinedVector<Node*, 4> NodeVector; |
| absl::flat_hash_map<Operation*, NodeVector> returns_; |
| const mlir::Dialect* tf_dialect_; |
| }; |
| |
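| // For example (illustrative), a tensor<f32> argument at index 0 named "x" |
| // maps roughly to the following "_Arg" node (plus an "_output_shapes" attr |
| // and any other argument attributes): |
| // |
| //   node { |
| //     name: "x" |
| //     op: "_Arg" |
| //     attr { key: "T" value { type: DT_FLOAT } } |
| //     attr { key: "index" value { i: 0 } } |
| //   } |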
| StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode( |
| BlockArgument arg, unsigned index, llvm::StringRef name) { |
| auto func = arg.getParentRegion()->getParentOfType<mlir::FuncOp>(); |
| |
| auto node_def = absl::make_unique<NodeDef>(); |
| if (!name.empty()) |
| node_def->set_name(name.str()); |
| else |
| node_def->set_name( |
| std::string(op_to_name_.GetUniqueName(func.getName().str()))); |
| |
| node_def->set_op(FunctionLibraryDefinition::kArgOp); |
| |
| TF_RETURN_IF_ERROR(SetShapeAttribute("_output_shapes", |
| arg.getType().cast<mlir::ShapedType>(), |
| node_def->mutable_attr())); |
| |
| DataType dtype; |
| TF_RETURN_IF_ERROR(ConvertToDataType( |
| arg.getType().cast<mlir::TensorType>().getElementType(), &dtype)); |
| AttrValue type_attr; |
| type_attr.set_type(dtype); |
| (*node_def->mutable_attr())["T"] = type_attr; |
| |
| AttrValue index_attr; |
| index_attr.set_i(index); |
| (*node_def->mutable_attr())["index"] = index_attr; |
| |
| if (auto device_attr = |
| func.getArgAttrOfType<mlir::StringAttr>(index, kDeviceAttr)) |
| *node_def->mutable_device() = device_attr.getValue().str(); |
| |
| llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs = |
| func.getArgAttrs(index); |
| absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr}; |
| TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore, |
| node_def->mutable_attr())); |
| |
| return node_def; |
| } |
| |
| // TODO(b/160014479): Support exporting function result attributes as optional |
| // attributes. |
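| // For example (illustrative), a tensor<f32> result at index 1 maps roughly |
| // to: |
| // |
| //   node { |
| //     op: "_Retval" |
| //     attr { key: "T" value { type: DT_FLOAT } } |
| //     attr { key: "index" value { i: 1 } } |
| //   } |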
| StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode( |
| mlir::FuncOp function, Value operand, unsigned index, |
| llvm::StringRef name) { |
| auto node_def = absl::make_unique<NodeDef>(); |
| if (!name.empty()) |
| node_def->set_name(name.str()); |
| else |
| node_def->set_name( |
| std::string(op_to_name_.GetUniqueName(function.getName().str()))); |
| |
| node_def->set_op(FunctionLibraryDefinition::kRetOp); |
| DataType dtype; |
| TF_RETURN_IF_ERROR(ConvertToDataType( |
| operand.getType().cast<mlir::TensorType>().getElementType(), &dtype)); |
| AttrValue type_attr; |
| type_attr.set_type(dtype); |
| (*node_def->mutable_attr())["T"] = type_attr; |
| AttrValue index_attr; |
| index_attr.set_i(index); |
| (*node_def->mutable_attr())["index"] = index_attr; |
| return node_def; |
| } |
| |
| Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node, |
| unsigned dst_index) { |
| if (auto input_result = src.dyn_cast<mlir::OpResult>()) { |
| auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner()); |
| // Replaces the input node with its paired NextIteration sink if it is a |
| // NextIteration source. |
| if (auto next_iter_source = |
| llvm::dyn_cast<mlir::tf_executor::NextIterationSourceOp>( |
| input_inst)) |
| input_inst = next_iter_source.GetSink(); |
| |
| auto node_it = nodes_.find(input_inst); |
| TF_RET_CHECK(node_it != nodes_.end()) |
| << "Use of OpResult encountered before def!"; |
| if (input_result.getType().isa<mlir::tf_executor::ControlType>()) { |
| graph_->AddControlEdge(node_it->second, dst_node); |
| } else { |
| graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node, |
| dst_index); |
| } |
| return Status::OK(); |
| } |
| |
| auto input_arg = src.cast<BlockArgument>(); |
| auto input_node_it = args_.find(input_arg); |
| TF_RET_CHECK(input_node_it != args_.end()) |
| << "Use of BlockArgument encountered before def!"; |
| // An argument node has a single output, so the source index is always 0. |
| graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index); |
| return Status::OK(); |
| } |
| |
| Status Exporter::AddEdge(Operation* inst) { |
| // For tf_executor.fetch, add only its data edges. Control edges are captured |
| // later. |
| if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) { |
| for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) { |
| Value operand = operand_and_idx.value(); |
| if (operand.getType().isa<mlir::tf_executor::ControlType>()) break; |
| |
| auto* dst_node = returns_[fetch][operand_and_idx.index()]; |
| TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| // For tf_executor.NextIteration.Sink, skip the token operand; the data |
| // operand produces an edge at destination index 0, and the i-th control |
| // operand an edge at index i + 1. |
| if (auto next_iter_sink = |
| llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(inst)) { |
| auto* dst_node = nodes_[inst]; |
| TF_RETURN_IF_ERROR( |
| AddEdgeBetweenNodes(next_iter_sink.input(), dst_node, 0)); |
| for (auto control_and_idx : llvm::enumerate(next_iter_sink.controlInputs())) |
| TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(control_and_idx.value(), dst_node, |
| control_and_idx.index() + 1)); |
| |
| return Status::OK(); |
| } |
| |
| // For tf_executor.NextIteration.Source, op can be skipped as it is assumed |
| // there are no operands. |
| if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) { |
| assert(inst->getNumOperands() == 0); |
| return Status::OK(); |
| } |
| |
| Operation* op = GetIslandInnerOpOrSelf(inst); |
| auto* dst_node = nodes_[op]; |
| int operand_offset = 0; |
| // For tf_executor.island, add data edges from its wrapped op before control |
| // edges. |
| if (auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) { |
| for (auto operand_and_idx : llvm::enumerate(op->getOperands())) |
| TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand_and_idx.value(), dst_node, |
| operand_and_idx.index())); |
| |
| operand_offset = op->getNumOperands(); |
| } |
| |
| // For all other ops (including tf_executor.island), add remaining edges. |
| for (auto operand_and_idx : llvm::enumerate(inst->getOperands())) |
| TF_RETURN_IF_ERROR( |
| AddEdgeBetweenNodes(operand_and_idx.value(), dst_node, |
| operand_and_idx.index() + operand_offset)); |
| |
| return Status::OK(); |
| } |
| |
| Status Exporter::AddInstructionNode(Operation* inst) { |
| std::unique_ptr<NodeDef> node_def; |
| auto name = op_to_name_.GetUniqueName(inst); |
| // Convert registered TF ops to NodeDef. Only registered ops are handled to |
| // ensure that PopulateDerivedAttrs adds the correct attributes. |
| TF_ASSIGN_OR_RETURN(node_def, |
| ConvertTFDialectOpToNodeDef( |
| inst, name, /*ignore_unregistered_attrs=*/false)); |
| |
| Status status; |
| Node* node = graph_->AddNode(*node_def, &status); |
| TF_RETURN_IF_ERROR(status); |
| DCHECK(node != nullptr); |
| nodes_[inst] = node; |
| return Status::OK(); |
| } |
| |
| bool IsEntryFunctionArg(BlockArgument arg) { |
| return arg.getParentRegion()->getParentOfType<mlir::FuncOp>().getName() == |
| "main"; |
| } |
| |
| // Creates an argument node from a block argument. If a name is supplied, |
| // that name will be used instead of generating a unique name. |
| Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index, |
| llvm::StringRef name) { |
| TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name)); |
| Status status; |
| Node* node = graph_->AddNode(*node_def, &status); |
| TF_RETURN_IF_ERROR(status); |
| args_[arg] = node; |
| return Status::OK(); |
| } |
| |
| // Creates a return node per data operand of a FetchOp. If `names` is |
| // supplied, those names will be used per node in order instead of |
| // generating unique names. |
| Status Exporter::AddFetchNode(mlir::FuncOp function, |
| mlir::tf_executor::FetchOp fetch, |
| llvm::ArrayRef<llvm::StringRef> names) { |
| Status status; |
| auto& return_nodes = returns_[fetch]; |
| for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) { |
| if (operand_and_idx.value().getType().isa<mlir::tf_executor::ControlType>()) |
| break; |
| |
| TF_ASSIGN_OR_RETURN( |
| auto node_def, |
| GetReturnNode(function, operand_and_idx.value(), |
| operand_and_idx.index(), |
| names.empty() ? "" : names[operand_and_idx.index()])); |
| Node* node = graph_->AddNode(*node_def, &status); |
| TF_RETURN_IF_ERROR(status); |
| return_nodes.push_back(node); |
| } |
| return Status::OK(); |
| } |
| |
| // Collects control ret Nodes based on tf_executor.graph's associated |
| // tf_executor.fetch control inputs. |
| Status Exporter::GetControlRetNodes( |
| mlir::tf_executor::FetchOp fetch, |
| absl::flat_hash_set<Node*>* control_ret_nodes) { |
| for (Value fetch_operand : fetch.getOperands()) { |
| if (fetch_operand.getType().isa<mlir::tf_executor::ControlType>()) { |
| Operation* defining_op = |
| GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp()); |
| auto node_it = nodes_.find(defining_op); |
| TF_RET_CHECK(node_it != nodes_.end()); |
| control_ret_nodes->insert(node_it->second); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| StatusOr<std::unique_ptr<Graph>> Exporter::Convert( |
| const GraphExportConfig& configs, const Dialect* tf_dialect, |
| mlir::FuncOp function, FunctionDefLibrary* flib, |
| absl::flat_hash_set<Node*>* control_ret_nodes) { |
| mlir::Block& block = function.front(); |
| |
| // Extract input & output names if set. |
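| // For example (illustrative), a function carrying the attribute |
| //   tf.entry_function = {inputs = "x,y", outputs = "z"} |
| // yields input_names = {"x", "y"} and output_names = {"z"}. |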
| llvm::SmallVector<llvm::StringRef, 2> input_names; |
| llvm::SmallVector<llvm::StringRef, 2> output_names; |
| auto dict_attr = |
| function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function"); |
| if (dict_attr) { |
| TF_RET_CHECK(dict_attr.get("inputs").isa<mlir::StringAttr>()) |
| << "inputs missing in entry function attribute"; |
| TF_RET_CHECK(dict_attr.get("outputs").isa<mlir::StringAttr>()) |
| << "outputs missing in entry function attribute"; |
| dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split( |
| input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); |
| dict_attr.get("outputs").cast<mlir::StringAttr>().getValue().split( |
| output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); |
| } |
| |
| auto graph = absl::make_unique<Graph>(OpRegistry::Global()); |
| |
| // Extract version info. |
| VersionDef versions; |
| auto module = function.getParentOfType<mlir::ModuleOp>(); |
| if (mlir::succeeded(ExtractTfVersions(module, &versions))) { |
| graph->set_versions(versions); |
| } |
| |
| // The function library has to be added first so that custom operations |
| // defined in it can be added to the graph. |
| TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib)); |
| Exporter exporter(graph.get(), tf_dialect); |
| |
| auto graph_op = llvm::cast<mlir::tf_executor::GraphOp>(block.front()); |
| |
| // Set input and output names, and reserve them (incrementing their use |
| // counters) so that unique name generation does not reuse them. |
| if (!output_names.empty()) { |
| const int num_data_results = graph_op.getNumResults(); |
| const int64 output_names_size = output_names.size(); |
| TF_RET_CHECK(output_names_size == num_data_results) |
| << "output names (" << output_names.size() |
| << ") != terminator operands (" << num_data_results << ")"; |
| for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) { |
| // Skip control rets. |
| const int64 index = it.index(); |
| if (index >= num_data_results) break; |
| // TODO(jpienaar): If there is a result index specified, ensure only one |
| // and that it matches the result index of the op. |
| std::string orig_name(output_names[index]); |
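| // e.g. (illustrative) ParseTensorName("foo:1") yields node name "foo" and |
| // output index 1. |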
| auto tensor_id = ParseTensorName(orig_name); |
| auto name = LegalizeNodeName( |
| llvm::StringRef(tensor_id.node().data(), tensor_id.node().size())); |
| |
| // Ensure name does not get reused. |
| (void)exporter.op_to_name_.GetUniqueName(name); |
| } |
| } |
| |
| if (!input_names.empty()) { |
| TF_RET_CHECK(input_names.size() == block.getNumArguments()); |
| for (const auto& it : llvm::enumerate(function.getArguments())) { |
| // TODO(lyandy): Update when changing feed/fetch import. |
| std::string orig_name(input_names[it.index()]); |
| std::string name = LegalizeNodeName(orig_name); |
| auto tensor_id = ParseTensorName(name); |
| TF_RET_CHECK(tensor_id.index() == 0) |
| << "input port designation not supported"; |
| // Only assign the argument's user the input name if the main graph did not |
| // have its _Arg nodes lifted into the function's arguments. |
| // Ensure the name does not get reused. |
| (void)exporter.op_to_name_.GetUniqueName(name); |
| } |
| } |
| |
| // Adds nodes for basic block (function) arguments. |
| for (auto it : llvm::enumerate(block.getArguments())) { |
| int index = it.index(); |
| auto arg = it.value(); |
| mlir::Type type = arg.getType(); |
| if (!type.isa<mlir::TensorType>()) { |
| return errors::InvalidArgument( |
| "FuncOps arguments must have tensor types. Found ", |
| mlir::debugString(type), " in function ", function.getName().str()); |
| } |
| |
| TF_RETURN_IF_ERROR(exporter.AddArgumentNode( |
| arg, index, !input_names.empty() ? input_names[index] : "")); |
| } |
| |
| auto convert_called_function = [&](llvm::StringRef name) { |
| auto func = |
| function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>( |
| name); |
| if (func != nullptr) { |
| TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib)); |
| TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib)); |
| } |
| return Status::OK(); |
| }; |
| |
| // Adds nodes for operations. |
| for (Operation& inst : graph_op.GetBody()) { |
| for (auto type : inst.getResultTypes()) |
| if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType, |
| mlir::tf_executor::TokenType>()) |
| return errors::InvalidArgument( |
| "Values must be of tensor type, TensorFlow control type, or " |
| "TensorFlow token type. Found ", |
| mlir::debugString(type)); |
| |
| if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) { |
| // Skip tf_executor.NextIteration.Source as associated |
| // tf_executor.NextIteration.Sink will be used instead. |
| continue; |
| } else if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) { |
| TF_RETURN_IF_ERROR(exporter.AddFetchNode(function, fetch, output_names)); |
| } else if (auto island = |
| llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) { |
| Operation& inner_op = island.GetBody().front(); |
| auto op_name = GetTensorFlowOpName(inner_op.getName().getStringRef()); |
| if (op_name.ok()) { |
| // If it is a TF Control dialect specific op, look up the custom operation |
| // in the module and convert it first, then add it to the function |
| // definition library. |
| // TODO(prakalps): If two functions have cyclic dependence, this will |
| // introduce an infinite loop. |
| TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str())); |
| } |
| |
| if (IsLegacyCallInstruction(&inner_op)) { |
| TF_RETURN_IF_ERROR(convert_called_function( |
| inner_op.getAttrOfType<mlir::SymbolRefAttr>("f") |
| .getLeafReference())); |
| } |
| |
| TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inner_op)); |
| } else { |
| TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inst)); |
| } |
| } |
| // Adds edges between the argument, operation and return nodes. |
| for (Operation& inst : graph_op.GetBody()) { |
| TF_RETURN_IF_ERROR(exporter.AddEdge(&inst)); |
| } |
| // Fixes the edges between the inserted nodes and the special "_SOURCE" and |
| // "_SINK" nodes. |
| FixupSourceAndSinkEdges(graph.get()); |
| |
| TF_RETURN_IF_ERROR( |
| exporter.GetControlRetNodes(graph_op.GetFetch(), control_ret_nodes)); |
| |
| return graph; |
| } |
| |
| Status Exporter::ConvertLibFunction(const GraphExportConfig& configs, |
| const Dialect* tf_dialect, |
| mlir::FuncOp function, |
| FunctionDefLibrary* flib) { |
| // First look for the function in the current function library. If found, |
| // nothing needs to be done. |
| OpRegistry empty_registry; |
| FunctionLibraryDefinition flib_def(&empty_registry, *flib); |
| auto function_name = function.getName().str(); |
| if (flib_def.Find(function_name)) return Status::OK(); |
| |
| // TODO(fengliuai): use a small flib_def to reduce overhead |
| absl::flat_hash_set<Node*> control_ret_nodes; |
| TF_ASSIGN_OR_RETURN(auto sub_graph, |
| Exporter::Convert(configs, tf_dialect, function, flib, |
| &control_ret_nodes)); |
| const auto control_ret = [&](const Node* n) -> absl::optional<string> { |
| return control_ret_nodes.contains(n) |
| ? absl::make_optional<string>(n->name()) |
| : absl::nullopt; |
| }; |
| FunctionDef func_def; |
| TF_RETURN_IF_ERROR( |
| GraphToFunctionDef(*sub_graph, function_name, control_ret, &func_def)); |
| |
| // The node defs in the FunctionDef may contain debug info that was added by |
| // GraphToFunctionDef. Remove it when debug info export is disabled to avoid |
| // failing the roundtrip test. |
| if (!configs.export_debug_info) { |
| for (auto& node_def : *func_def.mutable_node_def()) { |
| node_def.clear_experimental_debug_info(); |
| } |
| } |
| |
| // Checks for the gradient attribute. If present, converts the gradient |
| // function and populates the GradientDef. |
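| // For example (illustrative), a function "f" whose gradient function is |
| // "f_grad" adds: gradient { function_name: "f" gradient_func: "f_grad" } |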
| auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName(); |
| if (auto attr = |
| function.getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) { |
| auto grad_func = |
| function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>( |
| attr.getValue()); |
| TF_RETURN_IF_ERROR( |
| ConvertLibFunction(configs, tf_dialect, grad_func, flib)); |
| GradientDef grad; |
| grad.set_function_name(function_name); |
| grad.set_gradient_func(grad_func.getName().str()); |
| *flib->add_gradient() = grad; |
| } |
| |
| auto stateful_string = mlir::TF::TensorFlowDialect::GetStatefulAttrName(); |
| if (auto attr = function.getAttrOfType<mlir::UnitAttr>(stateful_string)) { |
| func_def.mutable_signature()->set_is_stateful(true); |
| } |
| |
| // Ignore the gradient and is_stateful attributes on the function, as they |
| // have been handled above. |
| absl::flat_hash_set<absl::string_view> attrs_to_ignore = { |
| grad_string.data(), stateful_string.data()}; |
| llvm::SmallVector<mlir::NamedAttribute, 8> func_attrs( |
| function.getDialectAttrs()); |
| TF_RETURN_IF_ERROR( |
| ConvertAttributes(func_attrs, attrs_to_ignore, func_def.mutable_attr())); |
| |
| for (int i = 0, e = function.getNumArguments(); i < e; ++i) { |
| if (auto resource_arg_unique_id_attr = |
| function.getArgAttrOfType<mlir::IntegerAttr>( |
| i, kResourceArgUniqueIdAttr)) { |
| (*func_def.mutable_resource_arg_unique_id())[i] = |
| resource_arg_unique_id_attr.getInt(); |
| } |
| |
| llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs = |
| function.getArgAttrs(i); |
| if (func_arg_i_attrs.empty()) continue; |
| absl::flat_hash_set<absl::string_view> attrs_to_ignore = { |
| kDeviceAttr, kResourceArgUniqueIdAttr}; |
| FunctionDef::ArgAttrs func_def_arg_i_attrs; |
| TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore, |
| func_def_arg_i_attrs.mutable_attr())); |
| if (func_def_arg_i_attrs.attr().empty()) continue; |
| (*func_def.mutable_arg_attr())[i] = std::move(func_def_arg_i_attrs); |
| } |
| |
| (*flib->add_function()) = std::move(func_def); |
| return Status::OK(); |
| } |
| |
| Status Exporter::Convert(mlir::ModuleOp module, |
| const GraphExportConfig& configs, |
| std::unique_ptr<Graph>* graph, |
| FunctionLibraryDefinition* flib_def, |
| absl::flat_hash_set<Node*>* control_ret_nodes) { |
| mlir::Identifier entry_func_id = |
| mlir::Identifier::get("main", module.getContext()); |
| absl::optional<mlir::FuncOp> entry_func; |
| FunctionDefLibrary flib; |
| auto tf_dialect = module.getContext()->getRegisteredDialect("tf"); |
| for (auto function : module.getOps<mlir::FuncOp>()) { |
| if (function.isExternal()) |
| return errors::FailedPrecondition("External functions not supported"); |
| |
| if (function.getName() == entry_func_id) { |
| entry_func.emplace(function); |
| } else { |
| TF_RETURN_IF_ERROR( |
| ConvertLibFunction(configs, tf_dialect, function, &flib)); |
| } |
| } |
| |
| if (!entry_func.has_value()) |
| return errors::FailedPrecondition("entry function `main` must be present"); |
| |
| // Updates the graph and the function library definition. |
| TF_ASSIGN_OR_RETURN( |
| *graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib, |
| control_ret_nodes)); |
| for (auto& func_def : flib.function()) { |
| TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def)); |
| } |
| for (auto& grad_def : flib.gradient()) { |
| TF_RETURN_IF_ERROR(flib_def->AddGradientDef(grad_def)); |
| } |
| return Status::OK(); |
| } |
| } // namespace |
| |
| Status ConvertMlirToGraph(mlir::ModuleOp module, |
| const GraphExportConfig& configs, |
| std::unique_ptr<Graph>* graph, |
| FunctionLibraryDefinition* flib_def, |
| absl::flat_hash_set<Node*>* control_ret_nodes) { |
| TF_RETURN_IF_ERROR(HasSingleGraphSingleOpIslandsFunctions(module)); |
| return Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes); |
| } |
| |
| Status ConvertMlirToGraph(mlir::ModuleOp module, |
| const GraphExportConfig& configs, |
| std::unique_ptr<Graph>* graph, |
| FunctionLibraryDefinition* flib_def) { |
| absl::flat_hash_set<Node*> control_ret_nodes; |
| return ConvertMlirToGraph(module, configs, graph, flib_def, |
| &control_ret_nodes); |
| } |
| |
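| // A minimal usage sketch (illustrative; assumes an MLIR `module` imported |
| // or built elsewhere): |
| // |
| //   GraphExportConfig configs; |
| //   auto graphdef_or = ConvertMlirToGraphdef(module, configs); |
| //   if (graphdef_or.ok()) { |
| //     const GraphDef& graphdef = *graphdef_or.ValueOrDie(); |
| //     // ... use graphdef ... |
| //   } |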
| StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef( |
| mlir::ModuleOp module, const GraphExportConfig& configs) { |
| FunctionLibraryDefinition flib_def(OpRegistry::Global(), |
| FunctionDefLibrary()); |
| auto graph = absl::make_unique<Graph>(flib_def); |
| TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def)); |
| auto graphdef = absl::make_unique<GraphDef>(); |
| graph->ToGraphDef(graphdef.get()); |
| if (!configs.export_library) graphdef->clear_library(); |
| if (!configs.export_shapes) { |
| for (auto& node_def : *graphdef->mutable_node()) { |
| node_def.mutable_attr()->erase("shape"); |
| } |
| } |
| if (!configs.export_debug_info) { |
| for (auto& node_def : *graphdef->mutable_node()) { |
| node_def.clear_experimental_debug_info(); |
| } |
| } |
| return graphdef; |
| } |
| |
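| // A minimal usage sketch (illustrative; assumes `func` is a FuncOp taken |
| // from a module in the expected executor dialect form): |
| // |
| //   FunctionDef function_def; |
| //   TF_RETURN_IF_ERROR(ConvertMlirFunctionToFunctionLibraryDef( |
| //       func, GraphExportConfig(), &function_def)); |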
| stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( |
| mlir::FuncOp func, const GraphExportConfig& configs, |
| FunctionDef* function_def) { |
| Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf"); |
| FunctionDefLibrary flib; |
| TF_RETURN_IF_ERROR( |
| Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib)); |
| for (auto& func_def : flib.function()) { |
| if (func_def.signature().name() == func.getName()) { |
| *function_def = func_def; |
| return Status::OK(); |
| } |
| } |
| return errors::InvalidArgument( |
| "Function couldn't be found in the FunctionDefLibrary after converting " |
| "from MLIR"); |
| } |
| |
| } // namespace tensorflow |