| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" |
| |
| #include <iterator> |
| #include <string> |
| #include <tuple> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/container/inlined_vector.h" |
| #include "absl/strings/escaping.h" |
| #include "absl/strings/numbers.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/str_join.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/strings/strip.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/DenseSet.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/Twine.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "mlir/Analysis/Verifier.h" // TF:llvm-project |
| #include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project |
| #include "mlir/IR/Attributes.h" // TF:llvm-project |
| #include "mlir/IR/Builders.h" // TF:llvm-project |
| #include "mlir/IR/Function.h" // TF:llvm-project |
| #include "mlir/IR/Identifier.h" // TF:llvm-project |
| #include "mlir/IR/Location.h" // TF:llvm-project |
| #include "mlir/IR/MLIRContext.h" // TF:llvm-project |
| #include "mlir/IR/Module.h" // TF:llvm-project |
| #include "mlir/IR/OpDefinition.h" // TF:llvm-project |
| #include "mlir/IR/StandardTypes.h" // TF:llvm-project |
| #include "mlir/IR/Types.h" // TF:llvm-project |
| #include "tensorflow/compiler/jit/shape_inference_helpers.h" |
| #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
| #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" |
| #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/shape_refiner.h" |
| #include "tensorflow/core/framework/attr_value.pb.h" |
| #include "tensorflow/core/framework/function.pb.h" |
| #include "tensorflow/core/framework/graph.pb.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/resource_var.h" |
| #include "tensorflow/core/framework/shape_inference.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/framework/versions.pb.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/graph/graph_constructor.h" |
| #include "tensorflow/core/graph/node_builder.h" |
| #include "tensorflow/core/graph/tensor_id.h" |
| #include "tensorflow/core/grappler/utils/transitive_fanin.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/platform/protobuf.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/protobuf/graph_debug_info.pb.h" |
| #include "tensorflow/core/protobuf/meta_graph.pb.h" |
| #include "tensorflow/core/protobuf/saved_object_graph.pb.h" |
| #include "tensorflow/core/protobuf/struct.pb.h" |
| #include "tensorflow/core/protobuf/trackable_object_graph.pb.h" |
| #include "tensorflow/stream_executor/lib/statusor.h" |
| |
| static inline absl::string_view StringRefToView(llvm::StringRef ref) { |
| return {ref.data(), ref.size()}; |
| } |
| |
| namespace tensorflow { |
| using stream_executor::port::StatusOr; |
| |
| namespace { |
| |
// Name of the node attribute that marks a function call for which call-site
// shape inference should be disabled. NOTE(review): the consumers of this
// attribute are outside this chunk — verify semantics at the use sites.
const char* disable_call_shape_inference_attribute_name =
    "_disable_call_shape_inference";
| |
| // This class is used to generate new MLIR function name strings that are both |
| // unique in the TF function library `flib_` and unique among the name strings |
| // generated by the class object during its lifetime. |
| // |
| // In theory, this class is not necessary because we should simply take |
| // the TF function name and use it as MLIR function name. However, for some |
| // unknown reasons (callout for investigation in b/142268695), keeping the |
| // function names unchanged in an MLIR roundtrip causes test failures. |
// TODO(b/142268695) Re-evaluate whether we need this class vs. directly using
// the TF function name as the MLIR function name after b/142268695 is root
// caused.
class NameUniquifier : public OpOrArgNameMapper {
 public:
  explicit NameUniquifier(const FunctionLibraryDefinition& flib)
      : flib_(flib) {}

 private:
  // A candidate name is usable only if the TF function library does not
  // already define a function with that name.
  bool IsUnique(llvm::StringRef name) override {
    return !flib_.Contains(std::string(name));
  }

  // Unimplemented: callers are expected to supply base names explicitly
  // rather than have this mapper derive one from an op or value.
  std::string GetName(OpOrVal op_or_val) override {
    DCHECK(false) << "Unimplemented";
    return "";
  }

  // TF function library consulted for uniqueness checks; not owned and must
  // outlive this object.
  const FunctionLibraryDefinition& flib_;
};
| |
| // Populates the tf.versions attribute on a module, given a corresponding |
| // graph VersionDef proto. |
| void PopulateTfVersions(mlir::ModuleOp module, |
| const VersionDef& graph_versions) { |
| mlir::Builder b(module.getContext()); |
| auto producer = b.getNamedAttr( |
| "producer", b.getI32IntegerAttr(graph_versions.producer())); |
| auto min_consumer = b.getNamedAttr( |
| "min_consumer", b.getI32IntegerAttr(graph_versions.min_consumer())); |
| auto bad_consumers = b.getNamedAttr( |
| "bad_consumers", b.getI32ArrayAttr(llvm::ArrayRef<int32_t>( |
| graph_versions.bad_consumers().begin(), |
| graph_versions.bad_consumers().end()))); |
| module.setAttr("tf.versions", |
| b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>( |
| {producer, min_consumer, bad_consumers}))); |
| } |
| |
// Stateful helper class to import a TensorFlow model into an MLIR Module.
//
// This is the base class that contains common utilities shared between the
// GraphDef importer and SavedModel importer.
//
// A subclass is expected to call `PrepareConvert` first to perform necessary
// preparation over the graph and also certain internal bookkeeping data.
// Afterwards the other protected methods can be called.
class ImporterBase {
 protected:
  // Does not take ownership of any pointer/reference argument; all referenced
  // objects must outlive this importer.
  explicit ImporterBase(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier,
      llvm::StringRef function_name_for_debug_info = "")
      : builder_(module.getContext()),
        module_(module),
        context_(module.getContext()),
        tf_name_to_mlir_name_(tf_name_to_mlir_name),
        graph_flib_(flib),
        specs_(specs),
        debug_info_(debug_info),
        function_name_for_debug_info_(function_name_for_debug_info),
        function_name_uniquifier_(function_name_uniquifier) {}

  // Returns the inferred function signature of the given function body. Input
  // types are unranked tensor of the respective datatype in the function and
  // result types are inferred by the shape_refiner_. Result types need not be
  // unranked tensors and could be ranked tensors in cases where result type
  // depends on an op with static output shape like tf.Const.
  StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);

  // Extracts arg and ret nodes from FunctionBody.
  // `resource_arg_unique_ids` will be filled with the unique IDs of resource
  // variables, as a list of {index, ID} pairs.
  void GetArgsAndRetsFromFunctionBody(
      const FunctionBody& fbody,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes,
      absl::InlinedVector<Node*, 4>* control_ret_nodes,
      absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
          resource_arg_unique_ids);

  // Prepares converting the graph to an MLIR module. This step removes the
  // backedges of the graph, orders the nodes and infers the shapes.
  Status PrepareConvert(const Graph& graph);

  // Converts the prepared graph to a Function and adds it to the module. A
  // set of nodes from the graph are given to be converted to the arguments
  // and returns of the function.
  Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type,
                 const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
                 const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
                 const absl::InlinedVector<Node*, 4>& control_ret_nodes,
                 llvm::ArrayRef<mlir::NamedAttribute> attrs,
                 const absl::InlinedVector<std::pair<int64_t, int64_t>, 4>&
                     resource_arg_unique_ids);

  // Finds out the function definition for the given function name from the
  // graph and converts it to a function of the module. This method is called
  // on demand because the graph flib_def does not provide an iterator
  // interface.
  Status ConvertLibFunction(llvm::StringRef func_name);

  // Returns the list of nodes in the graph. Nodes are presented in the
  // reverse order of a post-order depth-first visit starting from the graph's
  // source nodes.
  llvm::ArrayRef<Node*> GetOrderedNodes() const { return ordered_nodes_; }

  // Returns the inferred input type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::TensorType> InferInputType(const Node& node, int idx,
                                            mlir::Builder builder);

  // Returns the inferred output type at index `idx` of the `node` in the
  // context.
  StatusOr<mlir::TensorType> InferOutputType(const Node& node, int idx,
                                             mlir::Builder builder);

 private:
  // Most types with subtypes have only one subtype.
  using ElementSubtypes = llvm::SmallVector<mlir::TensorType, 1>;

  // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all
  // data type and shape information is maintained by the shape_refiner_.
  Status AddNodesToShapeRefiner();

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // given element type, and returns an MLIR tensor type.
  StatusOr<mlir::TensorType> ConvertDataTypeAndShape(
      DataType dtype, const shape_inference::ShapeHandle& handle,
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred shape referred to by 'handle' in 'context', with
  // given element type, and returns an MLIR tensor type.
  StatusOr<mlir::TensorType> ConvertElementTypeAndShape(
      mlir::Type element_type, const shape_inference::ShapeHandle& handle,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the inferred subtypes for an element type to corresponding MLIR
  // types in 'context'.
  StatusOr<ElementSubtypes> ConvertSubtypes(
      const std::vector<shape_inference::ShapeAndType>* handle_subtypes,
      shape_inference::InferenceContext* context, mlir::Builder builder);

  // Converts the tensor proto into an MLIR elements attribute.
  StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& value) {
    return ::tensorflow::ConvertTensorProto(value, &builder_);
  }

  // Converts func name in graphdef to mlir::SymbolRefAttribute.
  StatusOr<mlir::FlatSymbolRefAttr> ConvertFunctionCallName(
      const std::string& func_name);

  // Converts the given non-function-call AttrValue to an MLIR Attribute.
  StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value);

  // Converts the given function-call AttrValue to MLIR Attributes and pushes
  // them to the given attributes list. For example, if there is a kFunc
  // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it
  // to a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
  // {base_name.k2 : rfc}}.
  Status ConvertFunctionCallAttribute(
      const std::string& base_name, const AttrValue& value,
      llvm::SmallVector<mlir::NamedAttribute, 4>* attributes);

  // Helper to create either a tf_executor operation or a TF operation wrapped
  // in an island. When convert_to_legacy_call is true, converts the operation
  // representing a call to a library function with a name represented in
  // node_type_name to LegacyCallOp.
  mlir::Operation* createOperation(
      const Node& node, llvm::StringRef node_type_name,
      const mlir::OperationState& result,
      const llvm::SmallVectorImpl<mlir::Value>& control_operands,
      bool convert_to_legacy_call = false);

  // Converts one NodeDef from the input GraphDef into an Operation and
  // inserts it into the MLIR module using builder_.
  Status ConvertNode(const Node& node);

  // If the input graph represents a while-loop, the edges pointing from a
  // "NextIteration" node to a "Merge" node add cyclic dependencies and make
  // the topological sorting impossible. We need to remove these edges from
  // the input graph to infer shapes and construct a Function. For each
  // "NextIteration" node, two operations, "NextIteration.source" and
  // "NextIteration.sink", are added to the MLIR module.
  using BackEdge = BackEdgeHelper::BackEdge;

  // Removes backedges from the input graph. The removed edges are added back
  // to OpBuilder after the remaining graph is converted to the Function.
  Status RemoveBackedges(const Graph& graph);

  // Restores backedges removed during shape inference to the final Function.
  Status AddBackedges();

  // Restores a single backedge in the Function by adding a replicated
  // operation before the dst operation.
  Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                     int dst_input);

  // Adds the input arguments and return operation to the function. The
  // arguments are added as basic block argument. Also the argument types and
  // the id of the nodes from the input graph needs to be specified.
  Status ConvertFunctionArgAndRets(
      mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op,
      llvm::ArrayRef<mlir::Type> arg_types,
      const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
      const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
      const absl::InlinedVector<Node*, 4>& control_ret_nodes);

  // Gets the location information of the given node. It uses the
  // "original_node_name" in the NodeDef to get the corresponding file
  // location (FileLineColLoc) from the input DebugInfo and returns a
  // CallSiteLoc. If there are multiple "original_node_names", a FusedLoc is
  // returned. If the node name couldn't be found in the input DebugInfo, a
  // NameLoc is used as the location.
  mlir::Location GetLocation(const NodeDef& node);

  // Gets the location information string for the given node.
  std::string GetLocationStr(const Node& node, bool includeNodeName = false);

  // Inserts a placeholder node in the graph to replace a feed output tensor,
  // and returns the new placeholder node and a boolean indicating if the
  // original input node was removed from the graph. Uses of the feed output
  // tensor are replaced with this placeholder node. If the feed output tensor
  // is of a single output node, the control dependencies are forwarded to
  // the placeholder node, and the original node will be removed.
  // Note: This modifies the graph, and so any list of ordered nodes needs to
  // be reconstructed.
  StatusOr<std::pair<Node*, bool>> CreatePlaceholderNodeForFeed(
      const TensorShapeProto& shape, DataType dtype, Node* node, int index,
      const std::unordered_map<string, Node*>& node_name_map);

  // Gets the input and output nodes corresponding to the specified input and
  // output nodes in specs_. If there are no input or output nodes specified,
  // nodes will be empty.
  Status GetInputOutputNodes(
      const std::unordered_map<string, Node*>& node_name_map,
      std::unordered_set<const Node*>* nodes);

  // The input graph with backedges removed. The removed backedges are stored
  // in the back_edge_helper.
  BackEdgeHelper back_edge_helper_;
  // A map between node and output index, for each backedge.
  absl::flat_hash_map<const Node*, int> back_edge_node_output_;
  // Maps the destination node of each removed backedge to the edge itself.
  absl::flat_hash_map<const Node*, BackEdge> back_edge_dst_inputs_;
  // A map between sink and source operation of NextIteration
  absl::flat_hash_map<mlir::Operation*, mlir::Operation*>
      next_iteration_sink_source_;

  // All nodes and version information about the (copied) imported graph.
  std::unique_ptr<Graph> graph_;
  // Nodes in reverse post-order; populated by RemoveBackedges.
  std::vector<Node*> ordered_nodes_;

  // Maps from a Node ID to a MLIR value.
  using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;

  mlir::OpBuilder builder_;
  mlir::ModuleOp module_;
  mlir::MLIRContext* context_;
  // Maps TF function names to their MLIR function names; not owned.
  std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
  const FunctionLibraryDefinition& graph_flib_;
  const GraphImportConfig& specs_;
  const GraphDebugInfo& debug_info_;
  llvm::StringRef function_name_for_debug_info_;
  NodeValueMap node_values_;
  std::unique_ptr<ShapeRefiner> shape_refiner_;
  // Generates module-unique MLIR function names; not owned.
  NameUniquifier* function_name_uniquifier_;

 protected:
  // Maps feed as TensorId to new Placeholder node name.
  absl::flat_hash_map<TensorId, absl::string_view> remapped_feeds_;
};
| |
| // Returns true if the node with given name has a non primary output that is |
| // used by some other node as an input. Returns false if no outputs are in use |
| // or only the first output is in use. |
| bool HasNonPrimaryOutputInUse(const GraphDef& graph_def, |
| const std::string& node) { |
| for (const auto& node_def : graph_def.node()) { |
| for (const auto& input : node_def.input()) { |
| if (absl::StartsWith(input, node + ":") && input != node + ":0") { |
| return true; |
| } |
| } |
| } |
| return false; |
| } |
| |
| // Updates the given LegacyFedInput node with Placeholder node if it is one of |
| // the inputs. Returns an error if non primary output of the LegacyFedInput node |
| // is in use and therefore can not be replaced by the Placeholder node that only |
| // has a single output. |
| Status UpdateLegacyFedInputNode(const GraphDef& graph_def, |
| const GraphImportConfig::InputArrays& inputs, |
| NodeDef* node) { |
| const std::string& node_name = node->name(); |
| auto it = inputs.find(node_name); |
| |
| // Node is not an input. |
| if (it == inputs.end()) return Status::OK(); |
| |
| if (HasNonPrimaryOutputInUse(graph_def, node_name)) { |
| return errors::InvalidArgument( |
| "LegacyFedInput node ", node->name(), |
| " has non primary output in use and can not be replaced with " |
| "Placeholder node"); |
| } |
| |
| DataType dtype = it->second.imported_dtype; |
| // Uses the existing output type if it isn't specified by the user. |
| if (dtype == DT_INVALID) { |
| dtype = node->attr().at("output_types").list().type(0); |
| } |
| // Update op name, drop inputs and set attributes required by the Placeholder |
| // op. |
| *node->mutable_op() = "Placeholder"; |
| node->clear_attr(); |
| node->clear_input(); |
| AddNodeAttr("dtype", dtype, node); |
| AddNodeAttr("shape", it->second.shape, node); |
| return Status::OK(); |
| } |
| |
| // Preprocesses GraphDef before it can be converted to Graph by, |
| // - Adding the default attributes to each node def if they are missing from |
| // the GraphDef. |
| // - Replacing LegacyFedInput nodes with Placeholder nodes if |
| // convert_legacy_fed_inputs option is enabled. |
| Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) { |
| for (auto& node_def : *graph_def->mutable_node()) { |
| // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One |
| // solution could be have a tool to let users upgrade old serialized graphs. |
| if (specs && specs->convert_legacy_fed_inputs && |
| node_def.op() == "LegacyFedInput") { |
| TF_RETURN_IF_ERROR( |
| UpdateLegacyFedInputNode(*graph_def, specs->inputs, &node_def)); |
| } |
| |
| const tensorflow::OpRegistrationData* op_reg_data = |
| tensorflow::OpRegistry::Global()->LookUp(node_def.op()); |
| if (!op_reg_data) { |
| // This is likely a function call node, so we should continue. |
| continue; |
| } |
| ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def); |
| } |
| return Status::OK(); |
| } |
| |
// Mapping from node name to feed (index and ArrayInfo). Node name must outlive
// this map.
//
// The inner map is keyed by the fed output index; its values point into the
// `GraphImportConfig::InputArrays` entries that back this structure, so
// nothing here owns the strings or ArrayInfos.
using FeedsByNode = absl::flat_hash_map<
    absl::string_view,
    absl::flat_hash_map<int, const std::pair<std::string, ArrayInfo>*>>;
| |
| // Creates from a `GraphImportConfig::InputArrays` a mapping from a feeds output |
| // tensor name to index and ArrayInfo. Keys and values are backed by |
| // `GraphImportConfig::InputArrays`. |
| StatusOr<FeedsByNode> GetFeedsByNode( |
| const GraphImportConfig::InputArrays& inputs) { |
| FeedsByNode feeds_by_node; |
| feeds_by_node.reserve(inputs.size()); |
| |
| for (const auto& input : inputs) { |
| TensorId tensor = ParseTensorName(input.first); |
| if (tensor.index() < 0) |
| return errors::FailedPrecondition( |
| "Feed output tensor must be a data output '", tensor.ToString(), "'"); |
| |
| auto& node = feeds_by_node[tensor.node()]; |
| if (!node.insert({tensor.index(), &input}).second) |
| return errors::FailedPrecondition( |
| "Multiple feeds for the same output tensor '", tensor.ToString(), |
| "'"); |
| } |
| |
| return feeds_by_node; |
| } |
| |
| // Creates a unique name for a node that will be replacing a feed output tensor. |
| std::string GetUniqueNodeName( |
| absl::string_view node_name, int index, |
| const std::unordered_map<string, Node*>& node_name_map) { |
| std::string new_node_name_base = absl::StrCat(node_name, "_", index); |
| int count = 0; |
| std::string new_node_name = new_node_name_base; |
| while (node_name_map.find(new_node_name) != node_name_map.end()) { |
| new_node_name = absl::StrCat(new_node_name_base, "_", count++); |
| } |
| return new_node_name; |
| } |
| |
// Clones `graph` into graph_, strips its backedges, and computes a stable
// reverse post-order node ordering for the remaining acyclic graph.
Status ImporterBase::RemoveBackedges(const Graph& graph) {
  // TODO(fengliuai): Converting to GraphDef and back is the easiest way to
  // clone a graph.
  // TODO(fengliuai): clone the graph without going to graph_def first.
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  graph_ = absl::make_unique<Graph>(graph.flib_def());
  GraphConstructorOptions opts;
  // Keep the GraphDef as-is: allow internal ops and do not re-add default
  // attributes during reconstruction.
  opts.allow_internal_ops = true;
  opts.add_default_attributes = false;
  TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
      opts, std::move(graph_def), graph_.get()));

  // Remove all the backedges. So the nodes can be added to the shape refiner.
  TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get()));
  VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size())
          << " backedges.";

  // Creates a map for quickly identifying whether a node output is a backedge.
  for (const auto& edge : back_edge_helper_.RemovedEdges()) {
    // Each source node may contribute at most one backedge output.
    if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() &&
        back_edge_node_output_[edge.src] != edge.src_output) {
      return errors::FailedPrecondition(
          "More than one of the src node outputs are backedges!");
    }
    back_edge_node_output_[edge.src] = edge.src_output;
    // We expect a merge to receive a single backedge (multiple NextIteration
    // nodes feeding into the same merge is unexpected here).
    DCHECK(!back_edge_dst_inputs_.contains(edge.dst));
    back_edge_dst_inputs_[edge.dst] = edge;
  }

  // Obtains a RPO ordering, using node names as a tiebreak for stable sorting.
  GetReversePostOrder(
      *graph_, &ordered_nodes_,
      [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); });

  return Status::OK();
}
| |
// Creates a Placeholder node for the feed output tensor `node`:`index` and
// reroutes its consumers to it. Returns the placeholder and whether the
// original node was removed from the graph (see declaration for details).
StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
    const TensorShapeProto& shape, DataType dtype, Node* node, int index,
    const std::unordered_map<string, Node*>& node_name_map) {
  DCHECK_LT(index, node->num_outputs());
  // The original node can only be replaced outright (keeping its name) when
  // the placeholder covers the node's single output.
  const bool update_inplace = node->num_outputs() == 1 && index == 0;
  std::string new_node_name =
      update_inplace ? node->name()
                     : GetUniqueNodeName(node->name(), index, node_name_map);

  Node* placeholder_node;
  NodeBuilder builder(new_node_name, "Placeholder");
  builder.Attr("shape", shape);
  builder.Attr("dtype", dtype);
  TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));

  // Update edges from original feed with Placeholder node. Collect the edges
  // first so the graph is not mutated while iterating over out_edges().
  std::vector<const Edge*> data_edges;
  std::vector<const Edge*> control_edges;
  for (const tensorflow::Edge* edge : node->out_edges()) {
    if (edge->src_output() == index) {
      data_edges.push_back(edge);
    } else if (update_inplace && edge->IsControlEdge()) {
      // Control dependencies are only forwarded when the original node is
      // about to be removed.
      control_edges.push_back(edge);
    }
  }

  // Reroute data consumers of the fed output to the placeholder's output 0.
  for (const auto* edge : data_edges) {
    TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(),
                                          edge->dst_input()));
  }

  // Move control dependencies over to the placeholder.
  for (const auto* edge : control_edges) {
    graph_->AddControlEdge(placeholder_node, edge->dst());
    graph_->RemoveControlEdge(edge);
  }

  if (update_inplace) {
    graph_->RemoveNode(node);
  }

  return std::pair<Node*, bool>(placeholder_node, update_inplace);
}
| |
| Status ImporterBase::GetInputOutputNodes( |
| const std::unordered_map<string, Node*>& node_name_map, |
| std::unordered_set<const Node*>* nodes) { |
| auto add_node = [&](absl::string_view name) { |
| auto it = node_name_map.find(std::string(name)); |
| if (it == node_name_map.end()) { |
| return errors::FailedPrecondition( |
| absl::StrCat("Graph does not contain node: ", name)); |
| } |
| nodes->insert(it->second); |
| return Status::OK(); |
| }; |
| |
| // Remap feeds and fetches to newly created Placeholder nodes. |
| for (const auto& input : specs_.inputs) { |
| TensorId tensor = ParseTensorName(input.first); |
| auto remapped_it = remapped_feeds_.find(tensor); |
| if (remapped_it != remapped_feeds_.end()) { |
| TF_RETURN_IF_ERROR(add_node(remapped_it->second)); |
| } else { |
| TF_RETURN_IF_ERROR(add_node(tensor.node())); |
| } |
| } |
| |
| for (const auto& output : specs_.outputs) { |
| TensorId tensor = ParseTensorName(output); |
| auto remapped_it = remapped_feeds_.find(tensor); |
| if (remapped_it != remapped_feeds_.end()) { |
| TF_RETURN_IF_ERROR(add_node(remapped_it->second)); |
| } else { |
| TF_RETURN_IF_ERROR(add_node(tensor.node())); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
// TODO(fengliuai): Replace the iterative algorithm with a one-pass propagation
| Status ImporterBase::AddNodesToShapeRefiner() { |
| shape_refiner_ = absl::make_unique<ShapeRefiner>(graph_->versions(), |
| graph_->op_registry()); |
| // Some operations (for example "TPUExecute") don't have shape inference |
| // function defined, so we should set this to false for adding nodes with |
| // these types of operations. |
| shape_refiner_->set_require_shape_inference_fns(false); |
| shape_refiner_->set_function_library_for_shape_inference(&graph_flib_); |
| |
| TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs)); |
| auto node_name_map = graph_->BuildNodeNameIndex(); |
| |
| // First add all nodes to the refiner. |
| for (Node* node : ordered_nodes_) { |
| // We need to use a TensorFlow node to teach the shape refiner that user |
| // specifies certain data type and shape for the inputs in the `specs_`. |
| // This node shouldn't have any inputs, only have one output and its |
| // output type/shape is only determined by its "named" attributes. (The |
| // attributes should have fixed names so we can use the info from `specs_` |
| // to set the value of them.) `Placeholder` satisfies these constraints. |
| // |
| // Therefore, if the input node isn't a `Placeholder`, we create one and use |
| // it to replace the original input node, so the shape refiner can |
| // successfully propagate the user's input type and shape to the rest of the |
| // graph. |
| bool node_added_to_shape_refiner = false; |
| auto it = feeds_by_node.find(node->name()); |
| if (it != feeds_by_node.end()) { |
| auto op_name = node->op_def().name(); |
| if (op_name != "Placeholder" && op_name != "LegacyFedInput" && |
| op_name != FunctionLibraryDefinition::kArgOp) { |
| for (const auto& output_tensor : it->second) { |
| const int index = output_tensor.first; |
| const ArrayInfo& array_info = output_tensor.second->second; |
| |
| DataType dtype = array_info.imported_dtype; |
| // Uses the existing output type if it isn't specified by the user. |
| if (dtype == DT_INVALID) { |
| dtype = node->output_type(0); |
| } |
| |
| TF_ASSIGN_OR_RETURN( |
| auto placeholder_node_and_removed, |
| CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index, |
| node_name_map)); |
| |
| Node* placeholder_node = placeholder_node_and_removed.first; |
| if (placeholder_node_and_removed.second) { |
| // Original node has been removed from the graph. |
| node = placeholder_node; |
| node_added_to_shape_refiner = true; |
| } |
| remapped_feeds_[{it->first, index}] = placeholder_node->name(); |
| node_name_map[placeholder_node->name()] = placeholder_node; |
| // Add the new placeholder node to the shape refiner. |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| shape_refiner_->AddNode(placeholder_node), |
| GetLocationStr(*placeholder_node)); |
| } |
| } else { |
| auto index_it = it->second.find(0); |
| if (index_it == it->second.end()) { |
| return errors::FailedPrecondition( |
| "Missing feed output tensor at index 0 for node '", node->name(), |
| "'"); |
| } |
| node->AddAttr("shape", index_it->second->second.shape); |
| DataType dtype = index_it->second->second.imported_dtype; |
| // Uses the existing output type if it isn't specified by the user. |
| if (dtype == DT_INVALID) { |
| dtype = node->output_type(0); |
| } |
| node->AddAttr("dtype", dtype); |
| } |
| } |
| if (!node_added_to_shape_refiner) { |
| // Add the node to the shape refiner if the node hasn't been removed. |
| TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node), |
| GetLocationStr(*node)); |
| } |
| |
| auto set_shape_from_list_attr = [&](const AttrValue* attr) { |
| auto& list = attr->list(); |
| for (auto shape : llvm::enumerate(list.shape())) { |
| auto* node_context = shape_refiner_->GetContext(node); |
| shape_inference::ShapeHandle handle; |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| node_context->MakeShapeFromShapeProto(shape.value(), &handle), |
| GetLocationStr(*node)); |
| node_context->set_output(shape.index(), handle); |
| } |
| return Status::OK(); |
| }; |
| |
| // We currently have no other way to get shapes from ReadVariableOp's. |
| // Some graphs seem to have _output_shapes attributes on them, so use that |
| // if possible. |
| // TODO(silvasean): Ideally, we would do this in a separate shape inference |
| // pass to avoid adding complexity to the importer. But right now, we don't |
| // have an MLIR-native shape inference pass, so we need to do this while we |
| // still have the Graph around, i.e. here, in the importer. |
| if (node->op_def().name() == "ReadVariableOp") { |
| // TODO(silvasean): In some graphs, this seems to be annotated on every |
| // node. Why and by whom? |
| // TODO(b/140588338): We should ideally incorporate that information for |
| // all nodes, but right now, this can result in e.g. an Identity node with |
| // signature such as |
| // `(tensor<?x?xf32>) -> tensor<?x9216xf32>` which fails the verifier |
| // (which checks for exact type equality; _output_shapes results in |
| // us shoehorning in the more-precise type on the output). |
| if (const AttrValue* attr = node->attrs().Find("_output_shapes")) |
| TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr)); |
| } |
| |
| // If it is the argument node, the shape handle is set explicitly, so it |
| // can be propagated to the body nodes of the function. |
| if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) { |
| auto* node_context = shape_refiner_->GetContext(node); |
| DCHECK(node_context != nullptr); |
| if (const AttrValue* attr = node->attrs().Find("shape")) { |
| shape_inference::ShapeHandle handle; |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| node_context->MakeShapeFromShapeProto(attr->shape(), &handle), |
| GetLocationStr(*node)); |
| node_context->set_output(0, handle); |
| } else if (const AttrValue* attr = node->attrs().Find("_output_shapes")) { |
| TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr)); |
| } else { |
| node_context->set_output(0, node_context->UnknownShape()); |
| } |
| } |
| } |
| |
| // Since we might have inserted and removed nodes from the graph, fix |
| // source/sink edges and reconstruct the RPO ordering of nodes |
| FixupSourceAndSinkEdges(graph_.get()); |
| |
| // Prune nodes in the graph that are not reachable from the output. |
| if (specs_.prune_unused_nodes) { |
| std::unordered_set<const Node*> prune_start; |
| TF_RETURN_IF_ERROR(GetInputOutputNodes(node_name_map, &prune_start)); |
| if (!prune_start.empty()) { |
| if (PruneForReverseReachability(graph_.get(), prune_start)) { |
| VLOG(1) << "Pruned unused nodes in graphdef"; |
| } else { |
| VLOG(1) << "No unused nodes in graphdef to prune"; |
| } |
| } else { |
| VLOG(1) << "No output nodes specified, skipping pruning"; |
| } |
| } else { |
| VLOG(1) << "Pruning unused nodes in graphdef is disabled"; |
| } |
| |
| // Re-initialize ordered_nodes_ since we might have modified the graph. |
| GetReversePostOrder( |
| *graph_, &ordered_nodes_, |
| [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); }); |
| |
| VLOG(1) << "Inferring graph shapes to fixpoint"; |
| |
| // The "changed" information from UpdateNode can give false positives, so we |
| // create a dedicated method to verify the shapes are not changed before and |
| // after the shape refine. |
| auto same_inferred_shape = [](shape_inference::InferenceContext* c, |
| shape_inference::ShapeHandle s0, |
| shape_inference::ShapeHandle s1) -> bool { |
| if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) { |
| return true; |
| } |
| if (c->Rank(s0) != c->Rank(s1)) { |
| return false; |
| } |
| for (int i = 0; i < c->Rank(s0); ++i) { |
| if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) { |
| int64 val0 = c->Value(c->Dim(s0, i)); |
| int64 val1 = c->Value(c->Dim(s1, i)); |
| // Negative value is treated as unknown so all negative values indicate |
| // the same dimension. |
| if (val0 >= 0 && val1 >= 0 && val0 != val1) return false; |
| } |
| } |
| return true; |
| }; |
| |
| bool changed = true; |
| int i = 0; |
| const int kMaxIterationCount = 2; |
| while (changed && i != kMaxIterationCount) { |
| changed = false; |
| for (const Node* node : ordered_nodes_) { |
| auto* shape_context = shape_refiner_->GetContext(node); |
| DCHECK(shape_context != nullptr); |
| absl::InlinedVector<shape_inference::ShapeHandle, 4> existing; |
| existing.reserve(shape_context->num_outputs()); |
| for (int o = 0; o < shape_context->num_outputs(); ++o) { |
| existing.push_back(shape_context->output(o)); |
| } |
| bool inferred = false; |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred), |
| GetLocationStr(*node)); |
| for (int o = 0; o < shape_context->num_outputs(); ++o) { |
| if (!same_inferred_shape(shape_context, shape_context->output(o), |
| existing[o])) { |
| changed = true; |
| break; |
| } |
| } |
| } |
| ++i; |
| } |
| if (i >= kMaxIterationCount) { |
| LOG(WARNING) << "Graph shapes did not converge to a fixpoint within " |
| << kMaxIterationCount |
| << " iterations. Graph shapes may be conservative."; |
| } |
| VLOG(1) << "Graph shapes were inferred with " << (i - 1) |
| << " extra rounds of analysis to reach a fixpoint."; |
| return Status::OK(); |
| } |
| |
| StatusOr<mlir::TensorType> ImporterBase::InferInputType(const Node& node, |
| int idx, |
| mlir::Builder builder) { |
| ExtendedInferenceContext* shape_context = |
| shape_refiner_->GetExtendedContext(&node); |
| DataType dtype = shape_context->input_type(idx); |
| auto* context = shape_context->get_context(); |
| return ConvertDataTypeAndShape(dtype, context->input(idx), |
| context->input_handle_shapes_and_types(idx), |
| context, builder); |
| } |
| |
| StatusOr<mlir::TensorType> ImporterBase::InferOutputType( |
| const Node& node, int idx, mlir::Builder builder) { |
| ExtendedInferenceContext* shape_context = |
| shape_refiner_->GetExtendedContext(&node); |
| DataType dtype = shape_context->output_type(idx); |
| auto* context = shape_context->get_context(); |
| return ConvertDataTypeAndShape(dtype, context->output(idx), |
| context->output_handle_shapes_and_types(idx), |
| context, builder); |
| } |
| |
| StatusOr<mlir::TensorType> ImporterBase::ConvertDataTypeAndShape( |
| DataType dtype, const shape_inference::ShapeHandle& handle, |
| const std::vector<shape_inference::ShapeAndType>* handle_subtypes, |
| shape_inference::InferenceContext* context, mlir::Builder builder) { |
| TF_ASSIGN_OR_RETURN(auto subtypes, |
| ConvertSubtypes(handle_subtypes, context, builder)); |
| |
| mlir::Type element_type; |
| if (dtype == DT_VARIANT) |
| element_type = mlir::TF::VariantType::get(subtypes, context_); |
| else if (dtype == DT_RESOURCE) |
| element_type = mlir::TF::ResourceType::get(subtypes, context_); |
| else |
| TF_RETURN_IF_ERROR( |
| ::tensorflow::ConvertDataType(dtype, builder, &element_type)); |
| |
| return ConvertElementTypeAndShape(element_type, handle, context, builder); |
| } |
| |
| StatusOr<mlir::TensorType> ImporterBase::ConvertElementTypeAndShape( |
| mlir::Type element_type, const shape_inference::ShapeHandle& handle, |
| shape_inference::InferenceContext* context, mlir::Builder builder) { |
| if (!context->RankKnown(handle)) { |
| return mlir::UnrankedTensorType::get(element_type); |
| } |
| |
| // Sentinel for an unknown dimension size. getTensorType interprets any |
| // negative value as an unknown dimension. |
| // TODO(jmolloy): Ideally this shouldn't be a local sentinel. |
| const int64_t kUnknownDim = -1; |
| |
| absl::InlinedVector<int64_t, 4> dimensions; |
| int32 rank = context->Rank(handle); |
| dimensions.reserve(rank); |
| for (int i = 0; i < rank; ++i) { |
| auto dim_handle = context->Dim(handle, i); |
| if (!context->ValueKnown(dim_handle)) |
| dimensions.push_back(kUnknownDim); |
| else |
| dimensions.push_back(context->Value(dim_handle)); |
| } |
| |
| return mlir::RankedTensorType::get( |
| llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type); |
| } |
| |
| StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes( |
| const std::vector<shape_inference::ShapeAndType>* handle_subtypes, |
| shape_inference::InferenceContext* context, mlir::Builder builder) { |
| ElementSubtypes subtypes; |
| if (!handle_subtypes) return subtypes; |
| |
| subtypes.reserve(handle_subtypes->size()); |
| for (const auto& subtype : *handle_subtypes) { |
| mlir::Type element_type; |
| TF_RETURN_IF_ERROR( |
| ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type)); |
| TF_ASSIGN_OR_RETURN(mlir::TensorType type, |
| ConvertElementTypeAndShape(element_type, subtype.shape, |
| context, builder)); |
| subtypes.push_back(type); |
| } |
| return subtypes; |
| } |
| |
| Status ImporterBase::ConvertFunctionCallAttribute( |
| const std::string& base_name, const AttrValue& value, |
| llvm::SmallVector<mlir::NamedAttribute, 4>* attributes) { |
| TF_ASSIGN_OR_RETURN(auto func_attr, |
| ConvertFunctionCallName(value.func().name())); |
| attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); |
| |
| for (const auto& it : value.func().attr()) { |
| auto name = absl::StrCat(base_name, ".", it.first); |
| TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second)); |
| attributes->push_back(builder_.getNamedAttr(name, value)); |
| } |
| return Status::OK(); |
| } |
| |
| StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName( |
| const std::string& func_name) { |
| TF_RETURN_IF_ERROR(ConvertLibFunction(func_name)); |
| auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name]; |
| auto func = module_.lookupSymbol<mlir::FuncOp>(mlir_func_name); |
| return builder_.getSymbolRefAttr(func); |
| } |
| |
// Converts a TensorFlow AttrValue proto into the corresponding MLIR
// attribute. Scalars map to the natural MLIR attribute kinds; shapes and
// dtypes inside lists are mangled into string attributes; lists become an
// ArrayAttr. kFunc is rejected here because function-valued attributes are
// handled by ConvertFunctionCallAttribute instead.
StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
    const AttrValue& value) {
  switch (value.value_case()) {
    case AttrValue::kI:
      return builder_.getI64IntegerAttr(value.i());
    case AttrValue::kS:
      return builder_.getStringAttr(value.s());
    case AttrValue::kF:
      // TensorFlow float attrs are single precision.
      return builder_.getFloatAttr(builder_.getF32Type(), value.f());
    case AttrValue::kB:
      return builder_.getBoolAttr(value.b());
    case AttrValue::kType: {
      mlir::Type type;
      TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder_, &type));
      return mlir::TypeAttr::get(type);
    }
    case AttrValue::kShape:
      // Shapes are carried as mangled strings rather than a native attr.
      return builder_.getStringAttr(mangling_util::MangleShape(value.shape()));
    case AttrValue::kTensor:
      return ConvertTensorProto(value.tensor());
    case AttrValue::kList: {
      // Collect elements from whichever repeated fields are populated.
      absl::InlinedVector<mlir::Attribute, 8> attrs;
      for (const auto& item : value.list().i())
        attrs.push_back(builder_.getI64IntegerAttr(item));
      for (const auto& item : value.list().s())
        attrs.push_back(builder_.getStringAttr(item));
      for (const auto& item : value.list().f())
        attrs.push_back(builder_.getFloatAttr(builder_.getF32Type(), item));
      for (const auto& item : value.list().b())
        attrs.push_back(builder_.getBoolAttr(item));
      for (const auto& item : value.list().type()) {
        // Dtypes in lists are mangled into strings, mirroring kShape above.
        attrs.push_back(builder_.getStringAttr(
            mangling_util::MangleDataType(static_cast<DataType>(item))));
      }
      for (const auto& item : value.list().shape()) {
        attrs.push_back(
            builder_.getStringAttr(mangling_util::MangleShape(item)));
      }
      for (const auto& item : value.list().tensor()) {
        TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item));
        attrs.push_back(attr);
      }
      for (const auto& item : value.list().func()) {
        // Only plain function references are supported inside lists; a func
        // with its own attributes has no flattened representation here.
        TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
        if (item.attr_size() != 0)
          return errors::Unimplemented(
              "func attributes with non-zero attr.size()");
        attrs.push_back(attr);
      }
      return builder_.getArrayAttr(
          llvm::makeArrayRef(attrs.begin(), attrs.end()));
    }
    case AttrValue::kFunc:
      return errors::Unknown("kFunc type should be handled separately!");
    case AttrValue::VALUE_NOT_SET:
      return builder_.getUnitAttr();
    // kPlaceholder is not implemented.
    default:
      return errors::Unimplemented(
          absl::StrCat("Attribute ", value.DebugString()));
  }
}
| |
| void ImporterBase::GetArgsAndRetsFromFunctionBody( |
| const FunctionBody& fbody, absl::InlinedVector<OutputTensor, 4>* arg_nodes, |
| absl::InlinedVector<OutputTensor, 4>* ret_nodes, |
| absl::InlinedVector<Node*, 4>* control_ret_nodes, |
| absl::InlinedVector<std::pair<int64_t, int64_t>, 4>* |
| resource_arg_unique_ids) { |
| arg_nodes->reserve(fbody.arg_nodes.size()); |
| ret_nodes->reserve(fbody.ret_nodes.size()); |
| for (auto arg : fbody.arg_nodes) { |
| arg_nodes->emplace_back(arg, 0); |
| } |
| for (auto ret : fbody.ret_nodes) { |
| ret_nodes->emplace_back(ret, 0); |
| } |
| for (const auto& entry : fbody.fdef.resource_arg_unique_id()) { |
| resource_arg_unique_ids->push_back(entry); |
| } |
| *control_ret_nodes = fbody.control_ret_nodes; |
| } |
| |
// Imports the TensorFlow library function `func_name` into the module as an
// MLIR function, recording the TF-name-to-MLIR-name mapping. No-op if the
// function was converted already. Recursively imports any associated custom
// gradient function, and uses a child importer so the function body gets its
// own shape-refinement and conversion state.
Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
  // If the library function has been converted already, nothing needs to be
  // done.
  if (tf_name_to_mlir_name_->find(std::string(func_name)) !=
      tf_name_to_mlir_name_->end())
    return Status::OK();

  // Register the mapping before conversion; this also makes recursive
  // references to `func_name` terminate via the early-return above.
  std::string mlir_func_name(
      function_name_uniquifier_->GetUniqueName(func_name));
  (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name;

  const auto& func_lib = graph_flib_;
  const auto* func_def = func_lib.Find(std::string(func_name));
  if (func_def == nullptr) {
    return errors::FailedPrecondition(
        absl::StrCat("Failed to find function '", StringRefToView(func_name),
                     "'. The imported TensorFlow GraphDef is ill-formed."));
  }

  // Converts the function definition to a graph.
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(
      FunctionDefToBodyHelper(*func_def, AttrSlice(), &func_lib, &fbody));

  // Converts the argument and return types to mlir types.
  absl::InlinedVector<mlir::NamedAttribute, 8> attributes;
  attributes.reserve(func_def->attr_size());
  for (const auto& name_and_value : func_def->attr()) {
    // This is a function definition attribute, so it shouldn't contain
    // kFunc attribute and it is treated as normal one.
    TF_ASSIGN_OR_RETURN(auto attr,
                        ConvertAttributeValue(name_and_value.second));
    std::string attr_name =
        mangling_util::MangleAttributeName(name_and_value.first);
    attributes.push_back(builder_.getNamedAttr(attr_name, attr));
  }

  // Checks opdef stateful attribute and import that as Function Attribute
  if (func_def->signature().is_stateful()) {
    auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
    attributes.push_back(
        builder_.getNamedAttr(stateful_str, builder_.getUnitAttr()));
  }

  // Checks for an associated custom gradient function. Adds it to the attribute
  // list of this function.
  auto grad_func_name = func_lib.FindGradient(std::string(func_name));
  if (!grad_func_name.empty()) {
    // The gradient must be imported first so its MLIR symbol exists.
    TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
    auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
    auto grad_func = module_.lookupSymbol<mlir::FuncOp>(mlir_grad_func_name);
    auto gradient_attr = builder_.getSymbolRefAttr(grad_func);
    auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
    attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
  }

  // Converts the graph to an MLIR function and adds it to the module.
  // We populate the NodeSpec so that all the _Arg ops get their shape
  // added correctly.
  GraphImportConfig specs;
  for (const auto& name_and_value : func_def->attr()) {
    if (name_and_value.first == "_input_shapes") {
      auto& list = name_and_value.second.list();
      auto& signature = func_def->signature();
      // _input_shapes must provide one shape per signature input.
      if (list.shape_size() != signature.input_arg_size()) {
        return errors::FailedPrecondition(
            "Number of input arguments must be equal to the length of "
            "_input_shapes attribute in function '",
            StringRefToView(func_name), "'.");
      }
      for (int i = 0; i < list.shape_size(); i++) {
        auto& input_arg = signature.input_arg(i);
        auto& array_info = specs.inputs[input_arg.name()];
        array_info.imported_dtype = input_arg.type();
        array_info.shape = list.shape(i);
      }
    }
  }

  // The child importer shares the name maps and uniquifier with this one but
  // runs its own prepare/convert pipeline over the function body graph.
  ImporterBase child_importer(graph_flib_, debug_info_, specs, module_,
                              tf_name_to_mlir_name_, function_name_uniquifier_,
                              func_name);
  TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph));

  TF_ASSIGN_OR_RETURN(auto func_type,
                      child_importer.InferLibFunctionType(*fbody));

  absl::InlinedVector<OutputTensor, 4> arg_nodes;
  absl::InlinedVector<OutputTensor, 4> ret_nodes;
  absl::InlinedVector<Node*, 4> control_ret_nodes;
  absl::InlinedVector<std::pair<int64_t, int64_t>, 4> resource_arg_unique_ids;
  GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes,
                                 &control_ret_nodes, &resource_arg_unique_ids);

  TF_RETURN_IF_ERROR(child_importer.Convert(
      mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
      llvm::makeArrayRef(attributes.begin(), attributes.end()),
      resource_arg_unique_ids));
  return Status::OK();
}
| |
| Status ImporterBase::PrepareConvert(const Graph& graph) { |
| TF_RETURN_IF_ERROR(RemoveBackedges(graph)); |
| TF_RETURN_IF_ERROR(AddNodesToShapeRefiner()); |
| return Status::OK(); |
| } |
| |
// Converts the prepared graph into an MLIR function named `func_name` with
// type `func_type` and attributes `attrs`, appending it to the module. The
// converted nodes are wrapped in a tf_executor.graph op, and
// `resource_arg_unique_ids` pairs are attached as
// "tf.resource_arg_unique_id" argument attributes (arg index -> id).
Status ImporterBase::Convert(
    llvm::StringRef func_name, mlir::FunctionType func_type,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes,
    llvm::ArrayRef<mlir::NamedAttribute> attrs,
    const absl::InlinedVector<std::pair<int64_t, int64_t>, 4>&
        resource_arg_unique_ids) {
  // TODO(b/122040776): Uses debug info for FunctionDef.
  auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
                                       func_name, func_type, attrs);

  module_.push_back(function);
  // Seeds the builder with an initial block.
  function.addEntryBlock();
  builder_ = mlir::OpBuilder(function.getBody());

  // Create the graph operation in which we will convert the individual nodes.
  auto graph = builder_.create<mlir::tf_executor::GraphOp>(
      function.getLoc(), func_type.getResults());
  builder_.createBlock(&graph.body());

  // Nodes are visited in the precomputed reverse-post-order.
  for (const Node* node : ordered_nodes_) {
    TF_RETURN_IF_ERROR(ConvertNode(*node));
  }

  // Adds the backedges back to the function by creating the source and sink
  // pairs.
  TF_RETURN_IF_ERROR(AddBackedges());

  TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph,
                                               func_type.getInputs(), arg_nodes,
                                               ret_nodes, control_ret_nodes));
  for (const auto& entry : resource_arg_unique_ids) {
    function.setArgAttr(entry.first, "tf.resource_arg_unique_id",
                        builder_.getI64IntegerAttr(entry.second));
  }
  return Status::OK();
}
| |
// Rewrites the converted graph so that feed (argument) nodes are replaced by
// block arguments of `func` and fetch (return) nodes feed the
// tf_executor.fetch / std.return terminators. The islands created for feed
// nodes are erased after their uses are redirected.
Status ImporterBase::ConvertFunctionArgAndRets(
    mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op,
    llvm::ArrayRef<mlir::Type> arg_types,
    const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
    const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
    const absl::InlinedVector<Node*, 4>& control_ret_nodes) {
  auto* bb = &func.front();
  llvm::SmallDenseMap<std::pair<Node*, int>, mlir::Value, 4>
      arg_nodes_to_values;
  for (int i = 0, e = arg_types.size(); i < e; ++i) {
    auto& arg_node = arg_nodes[i];
    // The lookup can't fail here: otherwise some nodes in the function haven't
    // be converted to mlir operations and don't have a mapping.
    mlir::Operation* island = node_values_.find(arg_node.node->id())->second;

    auto bb_arg = bb->getArgument(i);
    mlir::Value arg_def = bb_arg;

    // A feed island must have exactly one data result plus the control
    // result; anything else means the node had multiple outputs.
    if (island->getNumResults() != 2)
      return errors::InvalidArgument(
          "Only feed output tensors of single output nodes are supported");

    // Collect mapping of OutputTensor to associated block arg.
    arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def);
    island->getResult(0).replaceAllUsesWith(arg_def);
    // Erase control outputs from feed.
    auto control_uses = island->getResult(1).getUses();
    for (auto& control_use : llvm::make_early_inc_range(control_uses))
      control_use.getOwner()->eraseOperand(control_use.getOperandNumber());

    if (!arg_node.node->requested_device().empty())
      func.setArgAttr(
          i, "tf.device",
          builder_.getStringAttr(arg_node.node->requested_device()));

    island->dropAllReferences();
    island->erase();
  }

  llvm::SmallVector<mlir::Value, 8> inst_to_return;
  for (const auto& ret : ret_nodes) {
    auto* inst = node_values_[ret.node->id()];
    auto op = absl::string_view(ret.node->type_string());
    if (op == FunctionLibraryDefinition::kRetOp ||
        op == FunctionLibraryDefinition::kDeviceRetOp) {
      // Lookup the instruction inside the island
      auto island_op = llvm::cast<mlir::tf_executor::IslandOp>(inst);
      mlir::Operation* inner_op = &island_op.GetBody().front();
      // Remove kRetOp or kDeviceRetOp operation and return its operand.
      // kRetOp and kDeviceRetOp should have just one operand unless they have
      // control dependencies.
      if (inner_op->getNumOperands() != 1)
        return errors::Unimplemented("Return node with multiple inputs.");
      inst_to_return.push_back(inner_op->getOperand(0));
      inst->dropAllReferences();
      inst->erase();
    } else {
      // Lookup and use block arg if fetch is a feed.
      auto it = arg_nodes_to_values.find({ret.node, ret.index});
      if (it != arg_nodes_to_values.end())
        inst_to_return.push_back(it->second);
      else
        inst_to_return.push_back(inst->getResult(ret.index));
    }
  }

  // Control returns contribute their (last) control result to the fetch.
  for (Node* control_ret : control_ret_nodes) {
    auto* inst = node_values_[control_ret->id()];
    inst_to_return.push_back(*std::prev(inst->result_end()));
  }

  // Terminate the function by adding a Fetch operation to terminate the graph
  // and a return operation to return the Graph results.
  builder_.setInsertionPointToEnd(&graph_op.body().front());
  builder_.create<mlir::tf_executor::FetchOp>(graph_op.getLoc(),
                                              inst_to_return);
  builder_.setInsertionPointToEnd(bb);
  builder_.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_),
                                  graph_op.getResults());
  return Status::OK();
}
| |
// Builds an mlir::Location for `node_def` from the graph debug info. Nodes
// with a recorded stack trace become a NameLoc wrapping a FileLineColLoc
// chain (CallSiteLoc); nodes carrying experimental_debug_info original node
// names become a FusedLoc of the original nodes' locations plus the node's
// own location.
mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
  // TODO(b/142400497): What is the semantic contract for locations?
  const auto& debug_info = debug_info_.traces();

  // Create a location for node `name` in function `function_name`.
  auto create_location = [&](llvm::StringRef name,
                             llvm::StringRef function_name) -> mlir::Location {
    // Use the catenation of function and node names as the lookup key into the
    // debug info. This matches the way that the key is formed on the python
    // side.
    //
    // We also use this as the name for the NameLoc for ops in function, since
    // otherwise our names could collide across functions.
    // For ops in the main graph, we omit the "@function_name" (which, would be
    // just "@" since function_name would be empty) because some code seems to
    // depend on the name being this way for correctness.
    std::string debug_info_key = (name + "@" + function_name).str();
    std::string name_for_name_loc =
        function_name.empty() ? name.str() : (name + "@" + function_name).str();
    auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_);
    const auto& location_it = debug_info.find(debug_info_key);
    if (location_it == debug_info.end()) {
      // No debug info for this node: a bare NameLoc is the best we can do.
      return mlir::NameLoc::get(name_loc_id, context_);
    }

    // Convert the stack trace to a chain of mlir::CallSiteLocs.
    const auto& trace = location_it->second;
    llvm::SmallVector<mlir::Location, 4> locations;
    locations.reserve(trace.file_line_cols_size());
    for (const auto& location : trace.file_line_cols()) {
      const auto& file = debug_info_.files(location.file_index());
      auto file_name = mlir::Identifier::get(file, context_);
      auto file_line_loc = mlir::FileLineColLoc::get(file_name, location.line(),
                                                     location.col(), context_);
      locations.push_back(file_line_loc);
    }

    // If there are no locations in the stack trace, fall back to just a
    // NameLoc with no child.
    if (locations.empty()) return mlir::NameLoc::get(name_loc_id, context_);

    // Use the front FileLineColLoc to generate a NameLoc.
    mlir::Location node_name_loc =
        mlir::NameLoc::get(name_loc_id, locations.front());

    // If there are more locations then generate a stack trace, otherwise just
    // return the name loc.
    auto callsite_locs = llvm::makeArrayRef(locations).drop_front();
    return callsite_locs.empty()
               ? node_name_loc
               : mlir::CallSiteLoc::get(node_name_loc, callsite_locs);
  };

  // For NextIteration nodes, location is used to pair source and sink nodes.
  // Hence, we use node name as location to keep it unique.
  // TODO(prakalps): In future the plan is to use tokens to pair source/sink
  // nodes. Then NextIteration nodes would not need to be handled separately.
  if (node_def.op() == "NextIteration")
    return create_location(node_def.name(), function_name_for_debug_info_);

  auto original_nodes =
      node_def.experimental_debug_info().original_node_names();
  auto original_funcs =
      node_def.experimental_debug_info().original_func_names();

  if (original_nodes.empty()) {
    return create_location(node_def.name(), function_name_for_debug_info_);
  } else {
    // If the original nodes are defined, then we use them to get a list of
    // call sites, and then fuse them to a single fused location, with the name
    // of the node_def.
    llvm::SmallVector<mlir::Location, 4> node_locations;
    node_locations.reserve(original_nodes.size() + 1);

    // store the names in the experimental_debug_info
    for (int i = 0, e = original_nodes.size(); i != e; ++i) {
      auto node_name = original_nodes[i];
      // Not every original node has a recorded function name; default to "".
      auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
      node_locations.push_back(create_location(node_name, func_name));
    }
    // store the name of the node_def
    node_locations.push_back(
        create_location(node_def.name(), function_name_for_debug_info_));
    return mlir::FusedLoc::get(node_locations, context_);
  }
}
| |
| std::string ImporterBase::GetLocationStr(const Node& node, |
| bool includeNodeName) { |
| const auto location = GetLocation(node.def()); |
| std::string s; |
| llvm::raw_string_ostream ss(s); |
| location.print(ss); |
| ss.flush(); |
| // Removes the node name prefix if it exists. |
| if (!s.empty() && s[0] == '\"' && s.find_first_of(node.name()) == 1) { |
| return s.replace(0, node.name().size() + 3, ""); |
| } |
| return s; |
| } |
| |
// Creates the tf_executor-dialect operation materializing `node`.
// Control-flow nodes (Switch/_SwitchN, Merge, NextIteration, LoopCond,
// Enter, Exit, ControlTrigger) map to dedicated tf_executor ops; every other
// node is wrapped in a tf_executor.island holding the converted inner op
// (a tf.LegacyCall when `convert_to_legacy_call` is set). Returns the
// outermost created operation.
mlir::Operation* ImporterBase::createOperation(
    const Node& node, llvm::StringRef node_type_name,
    const mlir::OperationState& result,
    const llvm::SmallVectorImpl<mlir::Value>& control_operands,
    bool convert_to_legacy_call) {
  // For the tf.executor specific operations (not wrapped in an island), we
  // have an extra returned value for the control result, and we concatenate
  // control and non-control operands.
  mlir::SmallVector<mlir::Type, 4> types(result.types);
  types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext()));
  mlir::SmallVector<mlir::Value, 4> operands(result.operands);
  operands.append(control_operands.begin(), control_operands.end());

  auto loc = result.location;
  // Dispatch based on the name and create the appropriate operation.
  if (node.IsSwitch()) {
    // Switch and _SwitchN both are in switch class, differentiate based on
    // op name.
    if (node.op_def().name() == "_SwitchN") {
      return builder_.create<mlir::tf_executor::SwitchNOp>(loc, types, operands,
                                                           result.attributes);
    }
    return builder_.create<mlir::tf_executor::SwitchOp>(loc, types, operands,
                                                        result.attributes);
  }
  if (node.IsMerge()) {
    return builder_.create<mlir::tf_executor::MergeOp>(loc, types, operands,
                                                       result.attributes);
  }
  if (node.IsNextIteration()) {
    // NextIteration is a bit special, we create a pair of operations that are
    // linked together through a token returned by the source.
    // We make use of a separate builder to insert the source at the top of
    // the block.
    mlir::OpBuilder builder_at_begin(builder_.getBlock(),
                                     builder_.getBlock()->begin());
    auto source_op =
        builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
            loc, operands[0].getType(), result.attributes);
    return builder_.create<mlir::tf_executor::NextIterationSinkOp>(
        loc, source_op.token(), operands, result.attributes);
  }
  if (node.IsLoopCond()) {
    return builder_.create<mlir::tf_executor::LoopCondOp>(loc, types, operands,
                                                          result.attributes);
  }
  if (node.IsEnter()) {
    return builder_.create<mlir::tf_executor::EnterOp>(loc, types, operands,
                                                       result.attributes);
  }
  if (node.IsExit()) {
    return builder_.create<mlir::tf_executor::ExitOp>(loc, types, operands,
                                                      result.attributes);
  }
  if (node.IsControlTrigger()) {
    return builder_.create<mlir::tf_executor::ControlTriggerOp>(
        loc, operands, result.attributes);
  }
  // Regular TensorFlow operation are wrapped in a tf_executor.island.
  auto island = builder_.create<mlir::tf_executor::IslandOp>(
      result.location, types, control_operands,
      mlir::ArrayRef<mlir::NamedAttribute>{});
  island.body().push_back(new mlir::Block);
  mlir::OpBuilder island_builder(&island.GetBody());

  // Create the operation inside the island now.
  mlir::Operation* inner_op;
  if (convert_to_legacy_call) {
    // Carry over the node's shape-inference opt-out flag, if present.
    bool disable_call_shape_inference = false;
    for (const auto& name_and_value : node.attrs()) {
      const auto& attr_name = name_and_value.first;
      const AttrValue& attr_value = name_and_value.second;
      if (strcmp(attr_name.c_str(),
                 disable_call_shape_inference_attribute_name) == 0 &&
          attr_value.value_case() == AttrValue::kB) {
        disable_call_shape_inference = attr_value.b();
      }
    }

    mlir::BoolAttr attribute =
        builder_.getBoolAttr(disable_call_shape_inference);
    inner_op = island_builder.create<mlir::TF::LegacyCallOp>(
        result.location, result.types, result.operands,
        island_builder.getSymbolRefAttr(node_type_name), attribute);
  } else {
    inner_op = island_builder.createOperation(result);
  }

  if (inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
    // The op has multiple variadic outputs.
    // Calculate result segment sizes using the OpDef.
    NameRangeMap output_ranges;
    // This will fail only if the OpDef is syntactically invalid.
    // TODO(jpienaar): Convert this CHECK into a properly propagated error.
    TF_CHECK_OK(
        NameRangesForNode(node, node.op_def(), nullptr, &output_ranges));
    std::vector<mlir::Attribute> values;
    values.reserve(node.op_def().output_arg_size());
    for (const auto& output_arg : node.op_def().output_arg()) {
      auto range = output_ranges[output_arg.name()];
      values.push_back(
          island_builder.getI32IntegerAttr(range.second - range.first));
    }

    // Add derived "result_segment_sizes" attr to the created operation.
    // TODO(b/146937733): Don't use <void> here.
    llvm::StringRef attr_name = mlir::OpTrait::AttrSizedResultSegments<
        void>::getResultSegmentSizeAttr();
    auto attr_type = mlir::VectorType::get(node.op_def().output_arg_size(),
                                           builder_.getIntegerType(32));
    auto attr_value = mlir::DenseElementsAttr::get(attr_type, values);
    inner_op->setAttr(attr_name, attr_value);
  }

  // Add the terminator for the island
  island_builder.create<mlir::tf_executor::YieldOp>(result.location,
                                                    inner_op->getResults());
  return island.getOperation();
}
| |
// Converts a single TF graph `node` into an MLIR operation (wrapped in a
// tf_executor.island by createOperation) and records it in `node_values_`.
// Assumes nodes are visited in reverse post order so all data/control inputs
// have already been converted.
Status ImporterBase::ConvertNode(const Node& node) {
  if (!node.IsOp()) {
    // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by
    // Graph and don't exist in GraphDef.
    return Status::OK();
  }

  // If it is a custom OP, its definition should be found in the library. We
  // create the MLIR function and insert it to the module if it doesn't exist.
  std::string node_type_name = node.type_string();
  const auto* func_def = graph_flib_.Find(node_type_name);
  bool convert_to_legacy_call = false;
  if (func_def) {
    TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
    node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
    convert_to_legacy_call = true;
  }

  // Prefixes an op name with the TF dialect namespace, e.g. "Add" -> "tf.Add".
  auto get_full_op_name = [&](const std::string& op_name) {
    const char* kTfPrefix = "tf.";
    return kTfPrefix + op_name;
  };

  std::string op_name = get_full_op_name(node_type_name);
  if (back_edge_node_output_.contains(&node)) {
    // Nodes feeding a removed backedge are imported as the ".sink" variant.
    op_name = op_name + ".sink";
  }

  const auto& node_def = node.def();
  mlir::OperationState result(GetLocation(node_def), op_name);
  for (int i = 0; i < node.num_outputs(); ++i) {
    // The backedge has been removed, so we shouldn't count the corresponding
    // output from the src node when converting to an operation.
    if (back_edge_node_output_.contains(&node) &&
        back_edge_node_output_[&node] == i) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_));
    result.types.push_back(type);
  }

  // Surprisingly input edges can be nondeterministically ordered. This
  // particularly seems to be the case for the control edges between _SOURCE
  // and _SINK that the Graph constructor inserts. Copy the input edges and
  // sort the edges, but only the control edges, not data edges!
  // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes.
  // They'll break roundtripping anyway unless we strip them when converting
  // back to graphdef.
  absl::InlinedVector<const Edge*, 8> in_edges(node.in_edges().size());
  absl::c_copy(node.in_edges(), in_edges.begin());
  absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
    // Data edges first (stable order preserved), then control edges sorted by
    // destination input index.
    if (e1->IsControlEdge() && !e2->IsControlEdge()) return false;
    if (!e1->IsControlEdge() && e2->IsControlEdge()) return true;
    return e1->dst_input() < e2->dst_input();
  });

  result.operands.reserve(in_edges.size());

  // Collect the control operands separately, they will be held by the island.
  mlir::SmallVector<mlir::Value, 8> control_operands;

  for (const auto* input_edge : in_edges) {
    const Node& input_node = *input_edge->src();
    if (input_node.IsSource()) {
      if (in_edges.size() != 1) {
        return errors::FailedPrecondition(
            "The node has other inputs besides the _Source node");
      }
      // We don't import the _SOURCE node.
      continue;
    }
    if (input_node.IsArg() && input_edge->IsControlEdge()) {
      // Currently we have not reached consensus as to what TF function
      // semantics are (b/133509504). Here we assume that all arguments to a
      // function should be available before we start execution of any internal
      // node. This makes the control dependencies between function arguments
      // and internal nodes redundant, and so we do not import them. The TF
      // inliner however assumes no such dependency between function args and
      // internal nodes exists, unless explicitly stated. Since we drop control
      // dependencies here, it leads to loss of information. If the function is
      // inlined later, the inliner would not know of these explicit control
      // dependencies present in the original graph.
      continue;
    }
    if (node_values_.find(input_node.id()) == node_values_.end())
      return errors::FailedPrecondition(
          "Graph not traversed in reverse post order; use seen before def!");
    mlir::Operation* inst = node_values_[input_node.id()];
    if (input_edge->IsControlEdge())
      // The island's last result is its control token.
      control_operands.push_back(inst->getResult(inst->getNumResults() - 1));
    else
      result.operands.push_back(inst->getResult(input_edge->src_output()));
  }

  using FuncPairType = std::pair<const std::string*, const AttrValue*>;
  std::vector<FuncPairType> funcs;
  result.attributes.reserve(node.attrs().size() + 2);
  for (const auto& name_and_value : node.attrs()) {
    const auto& attr_name = name_and_value.first;
    const AttrValue& attr_value = name_and_value.second;
    // LegacyCall can only represent _disable_call_shape_inference attribute.
    // If a call has other attributes, can't convert it to LegacyCall.
    if (convert_to_legacy_call &&
        (strcmp(attr_name.c_str(),
                disable_call_shape_inference_attribute_name) ||
         attr_value.value_case() != AttrValue::kB)) {
      convert_to_legacy_call = false;
    }
    if (attr_value.value_case() == AttrValue::kFunc) {
      // Attribute iteration order is not defined for protocol buffer Map.
      // Process function attributes separately in the lexicographical order to
      // have deterministic order of functions in the constructed IR.
      funcs.emplace_back(&attr_name, &attr_value);
    } else {
      TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value));
      result.attributes.push_back(builder_.getNamedAttr(attr_name, attr));
    }
  }

  // Sort function-valued attributes by name so IR construction order is
  // deterministic.
  auto comparator = [](const FuncPairType& a, const FuncPairType& b) {
    return *a.first < *b.first;
  };
  std::sort(funcs.begin(), funcs.end(), comparator);
  for (const auto& func : funcs) {
    TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second,
                                                    &result.attributes));
  }

  result.attributes.push_back(builder_.getNamedAttr(
      "device", builder_.getStringAttr(std::string(node_def.device()))));

  // Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add
  // the differentiating attribute.
  if (node.IsIfNode()) {
    result.name = mlir::OperationName(get_full_op_name("If"), context_);
    mlir::BoolAttr val = builder_.getBoolAttr(node_type_name == "StatelessIf");
    result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
  }

  // Map While and StatelessWhile op in TensorFlow to the common While op in
  // MLIR and add the differentiating attribute.
  if (node.IsWhileNode()) {
    result.name = mlir::OperationName(get_full_op_name("While"), context_);
    mlir::BoolAttr val =
        builder_.getBoolAttr(node_type_name == "StatelessWhile");
    result.attributes.push_back(builder_.getNamedAttr("is_stateless", val));
  }

  // Register the mapping between the TF node and the newly created operation.
  node_values_[node.id()] = createOperation(
      node, node_type_name, result, control_operands, convert_to_legacy_call);
  return Status::OK();
}
| |
| // Add the backedges to the CFG. Given a backedge, we replace the original |
| // source and destination operations by two new operations. Most of the |
| // fields of the replacements are copied from the original operations. |
| // However, |
| // - for the src operation, one output is inserted to the front of the output |
| // list. The type of the output is set to the type of the non-control result |
| // of the dst operation, and |
| // - for the dst operation, one operand is inserted to the front of the |
| // operand list. This operand is using the first result of the src |
| // operation. |
| // TODO(fengliuai): Preserve the order of the results and operands if |
| // necessary. |
| Status ImporterBase::AddBackedges() { |
| for (auto it : back_edge_dst_inputs_) { |
| BackEdge& edge = it.second; |
| if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) { |
| return errors::FailedPrecondition( |
| "Invalid backedge; should be from NextIteration to Merge!"); |
| } |
| auto* sink = node_values_[edge.src->id()]; |
| auto* dst = node_values_[edge.dst->id()]; |
| TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input)); |
| } |
| return Status::OK(); |
| } |
| |
// Rebuilds `dst` (a Merge) so that the first result of the NextIteration
// source feeding `sink` becomes its operand at position `dst_input`; all
// other operands shift right by one. The old `dst` op is erased after its
// uses are redirected to the replacement.
Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                                 int dst_input) {
  // Get the NextIteration.Source operation from the token operand of the sink.
  mlir::Operation* source = sink->getOperand(0).getDefiningOp();

  // Adds the "source" to the operands of the dst by creating a new dst
  // operation.
  mlir::OperationState state(dst->getLoc(), dst->getName());
  auto num_operands = dst->getNumOperands();
  state.operands.reserve(num_operands + 1);
  for (int input = 0, e = num_operands + 1; input != e; ++input) {
    if (input < dst_input) {
      // Operands before the insertion point are copied as-is.
      state.operands.push_back(dst->getOperand(input));
    } else if (input == dst_input) {
      // The backedge value is spliced in at the requested position.
      state.operands.push_back(source->getResult(0));
    } else {
      // Operands after the insertion point come from one slot earlier in the
      // original op.
      state.operands.push_back(dst->getOperand(input - 1));
    }
  }
  // Attributes and result types are carried over unchanged.
  state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end());
  state.types.assign(dst->getResultTypes().begin(),
                     dst->getResultTypes().end());
  builder_.setInsertionPoint(dst);
  auto* new_dst = builder_.createOperation(state);

  // Replaces the output uses of the old operation by the corresponding
  // result of the new operation, and deletes the old operation.
  for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
    auto new_output = new_dst->getResult(i);
    dst->getResult(i).replaceAllUsesWith(new_output);
  }
  dst->dropAllReferences();
  dst->erase();
  return Status::OK();
}
| |
| StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType( |
| const FunctionBody& fbody) { |
| mlir::Builder builder(context_); |
| |
| // The FunctionBody contains a graph with a single-output _Arg node for each |
| // function argument and a single-input _Retval node for each function return |
| // value. |
| // |
| // We already populated the ShapeRefiner with all the information about the |
| // shapes of these graph edges, so we just query it to build the corresponding |
| // MLIR function type signature. |
| |
| llvm::SmallVector<mlir::Type, 4> arg_types; |
| arg_types.reserve(fbody.arg_types.size()); |
| for (auto arg : fbody.arg_nodes) { |
| // Find node in the graph using the node id instead of using `arg` directly |
| // because the graph has been cloned. |
| auto* node = graph_->FindNodeId(arg->id()); |
| TF_ASSIGN_OR_RETURN(auto type, InferOutputType(*node, /*idx=*/0, builder)); |
| arg_types.push_back(type); |
| } |
| |
| llvm::SmallVector<mlir::Type, 4> ret_types; |
| ret_types.reserve(fbody.ret_types.size()); |
| for (auto ret : fbody.ret_nodes) { |
| // Find node in the graph using the node id instead of using `ret` directly |
| // because the graph has been cloned. |
| auto* node = graph_->FindNodeId(ret->id()); |
| TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder)); |
| ret_types.push_back(type); |
| } |
| |
| return builder.getFunctionType(arg_types, ret_types); |
| } |
| |
// Stateful helper class to import a TensorFlow model expressed in GraphDef into
// an MLIR Module.
//
// The nodes defined in the graph are converted to a function called
// 'func_name'. All library function definitions are converted to MLIR functions
// in the module.
class GraphDefImporter : public ImporterBase {
 public:
  // Main entry point: converts the given graph to an MLIR Module.
  static StatusOr<mlir::OwningModuleRef> Convert(
      mlir::MLIRContext* context, const Graph& graph,
      const GraphDebugInfo& debug_info,
      const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
      llvm::StringRef func_name);

 private:
  // Instances are only created by Convert(); all conversion state is held by
  // the ImporterBase base class.
  explicit GraphDefImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}

  // Returns the function signature of the main function of converted MLIR
  // module, the input nodes and output nodes. The type and shape information
  // for the function arguments are read from `specs`, but the type and shape
  // information for the function returns are inferred by the shape refiner in
  // ImporterBase.
  StatusOr<mlir::FunctionType> InferMainFunctionType(
      const GraphImportConfig& specs, mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes);

  // Returns the function signature of the main function, alongside input and
  // output nodes, for function graphs. Arguments and return values are
  // determined by node op type (_Arg / _Retval). Type and shape information of
  // the function are inferred by the shape refiner in ImporterBase.
  // `resource_arg_unique_ids` will be filled with the unique IDs of resource
  // variables, as a list of {index, ID} pairs.
  StatusOr<mlir::FunctionType> GetArgsRetsAndTypesFromFunctionGraph(
      mlir::MLIRContext* context,
      absl::InlinedVector<OutputTensor, 4>* arg_nodes,
      absl::InlinedVector<OutputTensor, 4>* ret_nodes,
      absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
          resource_arg_unique_ids);

  // Finds the function's control ret nodes based on supplied node names in
  // `control_outputs`. If `control_outputs` are not unique or a control ret
  // node is missing, an error will be returned.
  Status GetControlRetsFromFunctionGraph(
      llvm::ArrayRef<std::string> control_outputs,
      absl::InlinedVector<Node*, 4>* control_ret_nodes);
};
| |
// Converts `graph` into an MLIR module containing one function named
// `func_name` plus all reachable library functions. The main function's
// signature is taken either from the graph's _Arg/_Retval nodes (when
// specs.graph_as_function is set) or from the user-specified feeds/fetches.
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
    mlir::MLIRContext* context, const Graph& graph,
    const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
    const GraphImportConfig& specs, llvm::StringRef func_name) {
  mlir::OwningModuleRef module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  NameUniquifier function_name_uniquifier(flib_def);

  GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
                            &tf_name_to_mlir_name, &function_name_uniquifier);

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));

  mlir::FunctionType func_type;
  absl::InlinedVector<OutputTensor, 4> arg_nodes;
  absl::InlinedVector<OutputTensor, 4> ret_nodes;
  absl::InlinedVector<Node*, 4> control_ret_nodes;
  absl::InlinedVector<std::pair<int64_t, int64_t>, 4> resource_arg_unique_ids;
  llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
  if (specs.graph_as_function) {
    // In this mode the signature comes from the graph's own _Arg/_Retval
    // nodes, so user-specified feeds/fetches and pruning are not applicable.
    if (specs.prune_unused_nodes || !specs.inputs.empty() ||
        !specs.outputs.empty())
      return errors::InvalidArgument(
          "Pruning of graph is currently unsupported when the main graph is "
          "converted to a function.");

    TF_ASSIGN_OR_RETURN(
        func_type,
        importer.GetArgsRetsAndTypesFromFunctionGraph(
            context, &arg_nodes, &ret_nodes, &resource_arg_unique_ids));

    TF_RETURN_IF_ERROR(importer.GetControlRetsFromFunctionGraph(
        specs.control_outputs, &control_ret_nodes));

    if (!arg_nodes.empty() || !ret_nodes.empty() ||
        !control_ret_nodes.empty()) {
      // Record the input/output/control-output node names in a
      // "tf.entry_function" attribute so the exporter can round-trip them.
      mlir::Builder b(context);
      std::string s;
      llvm::raw_string_ostream ss(s);
      auto node_name = [&](const OutputTensor& tensor) {
        ss << tensor.node->name();
      };
      // Note: ss.str() flushes the stream into `s`; s.clear() then resets the
      // backing string so the stream can be reused for the next attribute.
      mlir::interleave(arg_nodes, ss, node_name, ",");
      auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
      s.clear();
      mlir::interleave(ret_nodes, ss, node_name, ",");
      auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
      s.clear();
      mlir::interleave(specs.control_outputs, ss, ",");
      auto control_outputs =
          b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));

      attrs.push_back(b.getNamedAttr(
          "tf.entry_function",
          b.getDictionaryAttr({inputs, outputs, control_outputs})));
    }
  } else {
    // Collects the argument and return nodes by looking up the node names
    // specified by the user.
    TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType(
                                       specs, context, &arg_nodes, &ret_nodes));

    // TODO(prakalps): Refactor to keep attribute strings (tf.entry_function,
    // tf.versions) shared by importer and exporter in a centralized place.
    // Record the input and output mapping.
    if (!specs.inputs.empty() || !specs.outputs.empty()) {
      mlir::Builder b(context);
      std::string s;
      llvm::raw_string_ostream ss(s);
      mlir::interleave(
          specs.inputs, ss,
          [&](const std::pair<std::string, ArrayInfo>& v) { ss << v.first; },
          ",");
      auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
      s.clear();
      mlir::interleave(specs.outputs, ss, ",");
      auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));

      attrs.push_back(b.getNamedAttr("tf.entry_function",
                                     b.getDictionaryAttr({inputs, outputs})));
    }
  }

  // Record version info.
  PopulateTfVersions(module.get(), graph.versions());

  TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
      func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs,
      resource_arg_unique_ids));
  return module;
}
| |
| StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType( |
| const GraphImportConfig& specs, mlir::MLIRContext* context, |
| absl::InlinedVector<OutputTensor, 4>* arg_nodes, |
| absl::InlinedVector<OutputTensor, 4>* ret_nodes) { |
| // Find all the input nodes and output nodes. |
| // Feeds have been remapped to single output nodes (Placeholder), so an exact |
| // name match is sufficient. |
| absl::flat_hash_map<absl::string_view, int> inputs; |
| for (auto input_and_idx : llvm::enumerate(specs.inputs)) { |
| TensorId tensor = ParseTensorName(input_and_idx.value().first); |
| auto remapped_it = remapped_feeds_.find(tensor); |
| if (remapped_it != remapped_feeds_.end()) { |
| inputs.insert({remapped_it->second, input_and_idx.index()}); |
| } else { |
| inputs.insert({tensor.node(), input_and_idx.index()}); |
| } |
| } |
| |
| absl::flat_hash_set<absl::string_view> output_node_names; |
| std::vector<TensorId> outputs; |
| output_node_names.reserve(specs.outputs.size()); |
| for (const auto& output : specs.outputs) { |
| TensorId tensor = ParseTensorName(output); |
| auto remapped_it = remapped_feeds_.find(tensor); |
| if (remapped_it != remapped_feeds_.end()) { |
| output_node_names.insert(remapped_it->second); |
| outputs.push_back({remapped_it->second, 0}); |
| } else { |
| output_node_names.insert(tensor.node()); |
| outputs.push_back(tensor); |
| } |
| } |
| |
| if (!inputs.empty() || !outputs.empty()) { |
| arg_nodes->resize(inputs.size()); |
| ret_nodes->resize(outputs.size()); |
| |
| for (Node* n : GetOrderedNodes()) { |
| // Handle inputs/arguments. |
| auto input_it = inputs.find(n->name()); |
| if (input_it != inputs.end()) { |
| (*arg_nodes)[input_it->second] = {n, 0}; |
| } |
| |
| // Handle outputs/returns. |
| if (output_node_names.contains(n->name())) { |
| for (int i = 0, e = outputs.size(); i != e; ++i) { |
| TensorId tensor = outputs[i]; |
| if (n->name() != tensor.node()) continue; |
| (*ret_nodes)[i] = {n, tensor.index()}; |
| } |
| } |
| } |
| } |
| |
| // Starts to construct the function type. |
| mlir::Builder builder(context); |
| llvm::SmallVector<mlir::Type, 4> arg_types; |
| arg_types.reserve(specs.inputs.size()); |
| int i = 0; |
| for (auto it : specs.inputs) { |
| Node* arg_node = arg_nodes->at(i).node; |
| if (arg_node == nullptr) { |
| return errors::InvalidArgument("Input ", it.first, |
| " was not found in graph"); |
| } |
| mlir::Type element_type; |
| const auto& node_info = it.second; |
| DataType imported_dtype = node_info.imported_dtype; |
| // Uses the existing output type of the arg node if the data type of the |
| // the node isn't specified through the import configuration. |
| if (imported_dtype == DT_INVALID) { |
| imported_dtype = arg_node->output_type(0); |
| if (imported_dtype == DT_INVALID) { |
| return errors::InvalidArgument("Input ", i, "has invalid data type"); |
| } |
| } |
| TF_RETURN_IF_ERROR( |
| ::tensorflow::ConvertDataType(imported_dtype, builder, &element_type)); |
| llvm::SmallVector<int64_t, 4> shape; |
| TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); |
| arg_types.push_back(mlir::RankedTensorType::get(shape, element_type)); |
| i++; |
| } |
| |
| llvm::SmallVector<mlir::Type, 4> ret_types; |
| ret_types.reserve(specs.outputs.size()); |
| for (int i = 0, e = specs.outputs.size(); i != e; ++i) { |
| if (ret_nodes->at(i).node == nullptr) { |
| return errors::InvalidArgument("Output ", specs.outputs[i], |
| " was not found in graph"); |
| } |
| } |
| for (const auto& ret : *ret_nodes) { |
| if (ret.node->num_outputs() <= ret.index) { |
| return errors::InvalidArgument("Invalid output index ", ret.index, |
| " specified for node: ", ret.node->name()); |
| } |
| TF_ASSIGN_OR_RETURN(auto type, |
| InferOutputType(*ret.node, ret.index, builder)); |
| ret_types.push_back(type); |
| } |
| |
| return builder.getFunctionType(arg_types, ret_types); |
| } |
| |
| StatusOr<mlir::FunctionType> |
| GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( |
| mlir::MLIRContext* context, absl::InlinedVector<OutputTensor, 4>* arg_nodes, |
| absl::InlinedVector<OutputTensor, 4>* ret_nodes, |
| absl::InlinedVector<std::pair<int64_t, int64_t>, 4>* |
| resource_arg_unique_ids) { |
| auto add_node = [](Node* node, absl::InlinedVector<OutputTensor, 4>* nodes) { |
| auto* attr = node->attrs().Find("index"); |
| if (!attr) |
| return errors::InvalidArgument(node->type_string(), " node '", |
| node->name(), |
| "' is missing attribute 'index'"); |
| |
| auto index = attr->i(); |
| if (nodes->size() < index + 1) nodes->resize(index + 1); |
| |
| if ((*nodes)[index].node != nullptr) |
| return errors::InvalidArgument(node->type_string(), " node '", |
| node->name(), "' has attribute 'index' ", |
| index, " that conflicts with node '", |
| (*nodes)[index].node->name(), "'"); |
| (*nodes)[index] = {node, 0}; |
| |
| return Status::OK(); |
| }; |
| |
| // Collect arg and ret nodes from graph. |
| for (auto* node : GetOrderedNodes()) |
| if (node->IsArg()) |
| TF_RETURN_IF_ERROR(add_node(node, arg_nodes)); |
| else if (node->IsRetval()) |
| TF_RETURN_IF_ERROR(add_node(node, ret_nodes)); |
| |
| // Collect arg and ret types and create function type. |
| mlir::Builder builder(context); |
| llvm::SmallVector<mlir::Type, 4> arg_types; |
| arg_types.reserve(arg_nodes->size()); |
| for (auto arg_node_and_idx : llvm::enumerate(*arg_nodes)) { |
| auto& arg_node = arg_node_and_idx.value(); |
| if (arg_node.node == nullptr) |
| return errors::InvalidArgument("Graph missing _Arg at index ", |
| arg_node_and_idx.index()); |
| |
| TF_ASSIGN_OR_RETURN(auto type, |
| InferOutputType(*arg_node.node, /*idx=*/0, builder)); |
| arg_types.push_back(type); |
| tensorflow::int64 resource_arg_unique_id; |
| if (TryGetNodeAttr(arg_node.node->attrs(), "_resource_arg_unique_id", |
| &resource_arg_unique_id)) { |
| resource_arg_unique_ids->emplace_back(arg_node_and_idx.index(), |
| resource_arg_unique_id); |
| } |
| } |
| |
| llvm::SmallVector<mlir::Type, 4> ret_types; |
| ret_types.reserve(ret_nodes->size()); |
| for (auto ret_node_and_idx : llvm::enumerate(*ret_nodes)) { |
| auto& ret_node = ret_node_and_idx.value(); |
| if (ret_node.node == nullptr) |
| return errors::InvalidArgument("Graph missing _Retval at index ", |
| ret_node_and_idx.index()); |
| |
| TF_ASSIGN_OR_RETURN(auto type, |
| InferInputType(*ret_node.node, /*idx=*/0, builder)); |
| ret_types.push_back(type); |
| } |
| |
| return builder.getFunctionType(arg_types, ret_types); |
| } |
| |
| Status GraphDefImporter::GetControlRetsFromFunctionGraph( |
| llvm::ArrayRef<std::string> control_outputs, |
| absl::InlinedVector<Node*, 4>* control_ret_nodes) { |
| if (control_outputs.empty()) return Status::OK(); |
| |
| llvm::SmallDenseMap<llvm::StringRef, int32_t> controls_to_idx; |
| for (auto control_and_idx : llvm::enumerate(control_outputs)) |
| controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()}); |
| |
| if (controls_to_idx.size() != control_outputs.size()) |
| return errors::InvalidArgument("Control outputs must be unique"); |
| |
| control_ret_nodes->resize(controls_to_idx.size()); |
| |
| for (auto* node : GetOrderedNodes()) { |
| auto it = controls_to_idx.find(node->name()); |
| if (it != controls_to_idx.end()) (*control_ret_nodes)[it->second] = node; |
| } |
| |
| for (auto node_and_name : llvm::zip(*control_ret_nodes, control_outputs)) |
| if (std::get<0>(node_and_name) == nullptr) |
| return errors::InvalidArgument( |
| "Control output '", std::get<1>(node_and_name), "' is missing"); |
| |
| return Status::OK(); |
| } |
| |
// Stateful helper class to import a TensorFlow model expressed in SavedModel
// into an MLIR Module.
class SavedModelImporter : public ImporterBase {
 public:
  // Main entry point: converts all functions in the given meta graph to an MLIR
  // Module. Only objects reachable under `exported_names` are exported; an
  // empty span exports everything.
  static StatusOr<mlir::OwningModuleRef> Convert(
      SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
      absl::Span<std::string> exported_names, bool add_default_attributes);

 private:
  // Instances are only created by Convert(); all conversion state is held by
  // the ImporterBase base class.
  explicit SavedModelImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}
};
| |
// Determines the names used to reference objects in the SavedObjectGraph.
class ObjectNames {
 public:
  // Traverses `object_graph` eagerly; after construction all name queries are
  // cheap lookups. An empty `exported_names` span means "export everything".
  explicit ObjectNames(const SavedObjectGraph& object_graph,
                       absl::Span<std::string> exported_names);

  // Gets the names that external users of the SavedModel can use to refer to
  // this node. Empty if the node is not exported.
  llvm::ArrayRef<llvm::StringRef> GetExportedNames(int node_id) const;

  // Gets the name in the module symbol table for this node.
  // This name is only used for internal IR references.
  llvm::StringRef GetSymbolTableName(int node_id) const;

 private:
  // In the absence of any other information, use this name as the symbol table
  // name for this node.
  std::string GetDefaultSymbolTableName(int node_id) const;
  // Determines if a name is exported.
  bool IsExported(const std::string& name);
  // Main object graph traversal function.
  void RecursivelyVisitObjectGraph(int node_id);
  // Gets a stable StringRef from a std::string.
  llvm::StringRef SaveString(const std::string& s) const;

  // The object graph we are traversing.
  const SavedObjectGraph& object_graph_;
  // The set of names to export. Empty means "export all".
  std::unordered_set<std::string> names_to_export_;

  // When we recursively follow the object graph tree structure from the root,
  // we track its path in the object graph by pushing and popping from here
  // during traversal.
  llvm::SmallVector<std::string, 8> path_segments_;
  // The set of node_id's that are on the current DFS stack.
  // For cyclic object graphs, this prevents infinite recursion.
  std::unordered_set<int> on_stack_nodes_;

  // Key: node_id.
  // Value: all object names that node_id appears as.
  // Each object name corresponds to a unique path from the root of the object
  // graph.
  // The common intuitive case is when there is only one name for a given
  // object, which corresponds to the object graph being a tree.
  //
  // But, there cases where the object graph is a general graph. For
  // example, this happens commonly in Keras models, where `foo.bar` is
  // also reachable via the name `keras_api.foo.bar`.
  // Cycles are possible too.
  absl::flat_hash_map<int, std::vector<std::string>> object_names_;

  // Key: node_id
  // Value: all names that this object is exported as
  absl::flat_hash_map<int, llvm::SmallVector<llvm::StringRef, 1>>
      exported_names_;
  // Key: node_id
  // Value: pretty symbol table name to use for internal references to this
  // object.
  absl::flat_hash_map<int, llvm::StringRef> pretty_symbol_table_name_;

  // Stable strings we can take StringRef's into. Used only by the SaveString
  // method.
  mutable std::unordered_set<std::string> saved_strings_;
};
| |
| ObjectNames::ObjectNames(const SavedObjectGraph& object_graph, |
| absl::Span<std::string> exported_names) |
| : object_graph_(object_graph), |
| names_to_export_(exported_names.begin(), exported_names.end()) { |
| // Visit all reachable nodes from the root of the object graph. |
| // This builds up object_names_ to contain all names like `foo.bar` that a |
| // particular node in the graph can be reached from. |
| RecursivelyVisitObjectGraph(/*node_id=*/0); |
| |
| // Populate the exported_names_ map. |
| // TODO(silvasean): Diagnose typos in exported names? |
| for (auto& kv : object_names_) { |
| // Make object names map independent of our particular choice of object |
| // graph traversal. |
| std::sort(kv.second.begin(), kv.second.end(), |
| [](absl::string_view a, absl::string_view b) { |
| // The sort order here influences the "pretty name" we assign |
| // below. We want the most debuggable name to be first. |
| // |
| // Debuggability heuristics: |
| // 1. Names that end in digits are likely to be internal aliases |
| // to the "real" names. |
| // 2. Longer names are more likely to be internal aliases. |
| // |
| // Example set of object names created by Keras for the weight |
| // matrix of a fully connected layer on a trivial FC mnist |
| // model: |
| // - `model.layer-1.kernel` (this is the "best" name) |
| // - `model.keras_api.layers.1.kernel` |
| // - `model.variables.0` |
| // - `model.keras_api.layers.1.keras_api.trainable_variables.0` |
| // - ... 10 more long aliases ending in digits ... |
| return std::make_tuple(isdigit(a.back()), a.size(), a) < |
| std::make_tuple(isdigit(b.back()), b.size(), b); |
| }); |
| for (const std::string& name : kv.second) { |
| if (IsExported(name)) { |
| exported_names_[kv.first].push_back(SaveString(name)); |
| } |
| } |
| } |
| // Create "pretty" symbol table names for nodes where that is applicable. |
| // We could make all symbol table names use the default, which is basically |
| // just the node id. But for debugging purposes, it's nicer if we can mix in |
| // a recognizable object name if we have the information to do so. |
| for (auto& kv : object_names_) { |
| int node_id = kv.first; |
| std::string internal_name = |
| absl::StrCat(GetDefaultSymbolTableName(node_id), "__"); |
| // If the object has an exported name, we prefer that since it is probably |
| // the most recognizable. Otherwise, we grab some non-exported name of the |
| // object. |
| if (exported_names_.find(node_id) != exported_names_.end()) { |
| internal_name += exported_names_[node_id][0].str(); |
| } else { |
| internal_name += object_names_[node_id][0]; |
| } |
| pretty_symbol_table_name_[node_id] = SaveString(internal_name); |
| } |
| } |
| |
| llvm::ArrayRef<llvm::StringRef> ObjectNames::GetExportedNames( |
| int node_id) const { |
| auto it = exported_names_.find(node_id); |
| if (it != exported_names_.end()) { |
| return it->second; |
| } |
| return {}; |
| } |
| |
| llvm::StringRef ObjectNames::GetSymbolTableName(int node_id) const { |
| auto it = pretty_symbol_table_name_.find(node_id); |
| if (it != pretty_symbol_table_name_.end()) { |
| return it->second; |
| } |
| return SaveString(GetDefaultSymbolTableName(node_id)); |
| } |
| |
| std::string ObjectNames::GetDefaultSymbolTableName(int node_id) const { |
| return absl::StrCat("__sm_node", node_id); |
| } |
| |
| bool ObjectNames::IsExported(const std::string& name) { |
| if (names_to_export_.empty()) { |
| return true; |
| } |
| return names_to_export_.find(name) != names_to_export_.end(); |
| } |
| |
| void ObjectNames::RecursivelyVisitObjectGraph(int node_id) { |
| const SavedObject& object = object_graph_.nodes(node_id); |
| |
| switch (object.kind_case()) { |
| case SavedObject::kConstant: |
| case SavedObject::kFunction: |
| case SavedObject::kVariable: { |
| object_names_[node_id].push_back(absl::StrJoin(path_segments_, ".")); |
| break; |
| } |
| default: |
| break; |
| } |
| |
| for (const auto& child_ref : object.children()) { |
| bool on_stack = !on_stack_nodes_.insert(child_ref.node_id()).second; |
| if (on_stack) { |
| // This is a backedge. Don't traverse it. |
| continue; |
| } |
| |
| path_segments_.push_back(child_ref.local_name()); |
| RecursivelyVisitObjectGraph(child_ref.node_id()); |
| path_segments_.pop_back(); |
| |
| on_stack_nodes_.erase(child_ref.node_id()); |
| } |
| } |
| |
| llvm::StringRef ObjectNames::SaveString(const std::string& s) const { |
| return llvm::StringRef(*saved_strings_.insert(s).first); |
| } |
| |
| // Extracts a TensorProto for a Const op from a GraphDef, given an op_name. |
| // Returns nullptr on not found or other mismatch. |
| // This returns a pointer to the actual node within the graph_def so as to |
| // avoid expensive copies. |
| const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def, |
| const std::string& op_name) { |
| const NodeDef* match_node = nullptr; |
| for (const auto& node : graph_def.node()) { |
| if (node.name() == op_name) { |
| match_node = &node; |
| } |
| } |
| |
| if (!match_node) { |
| return nullptr; |
| } |
| |
| auto value_it = match_node->attr().find("value"); |
| if (value_it == match_node->attr().end()) { |
| return nullptr; |
| } |
| |
| if (!value_it->second.has_tensor()) { |
| return nullptr; |
| } |
| |
| return &value_it->second.tensor(); |
| } |
| |
| const TrackableObjectGraph::TrackableObject::SerializedTensor* |
| FindSerializedTensorInTrackable( |
| const TrackableObjectGraph::TrackableObject& trackable_object, |
| StringPiece name) { |
| for (const auto& maybe_serialized_tensor : trackable_object.attributes()) { |
| if (maybe_serialized_tensor.name() == name) { |
| return &maybe_serialized_tensor; |
| } |
| } |
| return nullptr; |
| } |
| |
| Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph, |
| const ObjectNames& object_names) { |
| for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) { |
| const SavedObject& object = object_graph.nodes(node_id); |
| if (object_names.GetExportedNames(node_id).empty()) { |
| continue; |
| } |
| if (object.kind_case() == SavedObject::kFunction) { |
| // We only allow a single input signature to each SavedFunction. |
| // This assumption means we have a 1:1 correspondence between |
| // tf.function <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef |
| // This makes defining the ABI easier (or even well-defined at all). |
| // TODO(silvasean): How to detect a function that doesn't have an |
| // explicitly user-provided input signature, but happens to have been |
| // traced exactly once? |
| if (object.function().concrete_functions_size() != 1) { |
| llvm::SmallVector<std::string, 4> names; |
| for (llvm::StringRef s : object_names.GetExportedNames(node_id)) { |
| names.push_back("'" + s.str() + "'"); |
| } |
| return errors::InvalidArgument( |
| "Exported function with exported name(s) ", |
| absl::StrJoin(names, ", "), |
| " with multiple concrete functions. Add " |
| "@tf.function(input_signature=[...]) on this function, or use a " |
| "narrower list of exported names that excludes this function."); |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
// Recursively traverses a StructuredValue, linearizing all the leaves.
//
// This currently only handles the subset of StructuredValue that is needed for
// signatures.
//
// Given a StructuredValue with structure [{"x": leaf0}], the "index path"
// needed to reach leaf0 is `[0, "x"]`, as it would be if you were operating on
// a Python object (`obj[0]["x"] is leaf0`). Each leaf corresponds to a
// linearized function argument or return on a FunctionDef, and hence to an
// mlir::FuncOp argument / return.
//
// This must match the linearization that happens in `tf.nest.flatten`.
// In particular, dict values should be linearized in sorted key order.
//
// The linearized index paths can be returned back to a structured
// representation (e.g. to emit C structs matching a signature) with a simple
// algorithm that recurses on each run of index paths with identical first
// elements.
class StructuredValueLinearizer {
 public:
  // Linearization happens eagerly at construction time; errors are deferred
  // and reported by GetLeafIndexPaths.
  StructuredValueLinearizer(const StructuredValue& value,
                            mlir::MLIRContext* context);

  // Returns the list of index paths to each leaf of the StructuredValue,
  // in a linearized order matching `tf.nest.flatten`.
  //
  // If an error occurred during the linearization process, an error message
  // with `error_context` prepended will be included in the returned status.
  StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> GetLeafIndexPaths(
      llvm::StringRef error_context) const;

 private:
  // Main function that recursively traverses the StructuredValue.
  void RecursivelyFindLeaves(const StructuredValue& value);

  // Used to create the StringAttr/IntegerAttr path elements and the ArrayAttr
  // index paths.
  mlir::Builder builder_;
  // The current index path. We push/pop this during recursive traversal of the
  // StructuredValue.
  llvm::SmallVector<mlir::Attribute, 4> current_index_path_;
  // The list of leaf index paths we have discovered so far.
  llvm::SmallVector<mlir::ArrayAttr, 4> leaf_index_paths_;
  // If non-empty, an error message to report.
  std::string error_message_;
};
| |
// Eagerly linearizes `value`. Any failure during traversal is recorded in
// `error_message_` and surfaced later by GetLeafIndexPaths.
StructuredValueLinearizer::StructuredValueLinearizer(
    const StructuredValue& value, mlir::MLIRContext* context)
    : builder_(context) {
  RecursivelyFindLeaves(value);
}
| |
| StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> |
| StructuredValueLinearizer::GetLeafIndexPaths( |
| llvm::StringRef error_context) const { |
| if (error_message_.empty()) { |
| return llvm::makeArrayRef(leaf_index_paths_); |
| } |
| return errors::InvalidArgument( |
| error_context.str(), error_message_, |
| "This likely means that you have @tf.function " |
| "on an exported function instead of " |
| "@tf.function(input_signature=[...]). Consider annotating an " |
| "input_signature or narrowing your set of " |
| "exported names to not include this function."); |
| } |
| |
// Recursive worker: walks `value`, maintaining `current_index_path_` as the
// path from the root, and appends one ArrayAttr per TensorSpec leaf found.
// Unhandled kinds record a message in `error_message_` rather than failing
// immediately; the error is surfaced by GetLeafIndexPaths.
void StructuredValueLinearizer::RecursivelyFindLeaves(
    const StructuredValue& value) {
  switch (value.kind_case()) {
    case StructuredValue::kDictValue: {
      // Dict values must be linearized in sorted order of keys.
      // protobuf map iteration order is unspecified, so collect pointers to
      // the entries and sort them by key before recursing.
      const DictValue& dict = value.dict_value();
      using FieldTy = protobuf::MapPair<std::string, StructuredValue>;
      llvm::SmallVector<const FieldTy*, 4> fields;
      for (auto& field : dict.fields()) {
        fields.push_back(&field);
      }
      llvm::sort(fields, [](const FieldTy* a, const FieldTy* b) {
        return a->first < b->first;
      });
      for (auto& field : fields) {
        current_index_path_.push_back(builder_.getStringAttr(field->first));
        RecursivelyFindLeaves(field->second);
        current_index_path_.pop_back();
      }
      return;
    }
    case StructuredValue::kTupleValue: {
      const TupleValue& tuple = value.tuple_value();
      for (int i = 0, e = tuple.values_size(); i < e; i++) {
        current_index_path_.push_back(builder_.getI64IntegerAttr(i));
        RecursivelyFindLeaves(tuple.values(i));
        current_index_path_.pop_back();
      }
      return;
    }
    // We don't differentiate between tuples and lists.
    case StructuredValue::kListValue: {
      const ListValue& list = value.list_value();
      for (int i = 0, e = list.values_size(); i < e; i++) {
        current_index_path_.push_back(builder_.getI64IntegerAttr(i));
        RecursivelyFindLeaves(list.values(i));
        current_index_path_.pop_back();
      }
      return;
    }
    case StructuredValue::kTensorSpecValue: {
      // Base case: record the current path stack as the index path needed to
      // get to this leaf.
      leaf_index_paths_.push_back(builder_.getArrayAttr(current_index_path_));
      return;
    }
    case StructuredValue::kNoneValue: {
      // Base case: do nothing.
      // This arises, for example, as the top-level object of an output
      // signature when there are no return values.
      return;
    }
    default: {
      // Record a human-readable description of where the unhandled kind was
      // found; traversal of siblings continues.
      llvm::raw_string_ostream os(error_message_);
      // TODO(silvasean): Use an enumerant name string instead of a number.
      os << "Unhandled structured value kind " << value.kind_case()
         << " at index path: <value>";
      for (auto path_element : current_index_path_) {
        os << ".";
        if (auto integer = path_element.dyn_cast<mlir::IntegerAttr>()) {
          os << integer.getValue();
        } else {
          auto str = path_element.cast<mlir::StringAttr>();
          os << str.getValue();
        }
      }
      os << "\n";
    }
  }
}
| |
// For exported functions with mutable bound inputs, rewrite the function
// signature to annotate resource subtypes on the types.
//
// The raw imported functions have `tensor<*x!tf.resource>` as the type for
// mutable bound inputs. Here we turn that into
// `tensor<!tf.resource<tensor<...>>>`.
void SetResourceSubtypes(mlir::ModuleOp module) {
  mlir::SymbolTable symbol_table(module);
  for (auto func : module.getOps<mlir::FuncOp>()) {
    // Only exported functions are rewritten.
    if (!mlir::tf_saved_model::IsExported(func)) continue;
    mlir::OpBuilder builder(func.getBody());
    llvm::SmallVector<mlir::Type, 4> new_input_types;
    for (int i = 0, e = func.getNumArguments(); i < e; i++) {
      auto arg = func.front().getArgument(i);
      // Look up the global tensor this argument is bound to (if any) via the
      // tf_saved_model.bound_input attribute.
      auto global_tensor =
          mlir::tf_saved_model::LookupBoundInput(func, i, symbol_table);
      if (global_tensor && global_tensor.is_mutable()) {
        auto old_type = arg.getType();
        // 0-D tensor of !tf.resource carrying the global tensor's type as the
        // resource subtype.
        auto new_type = mlir::RankedTensorType::get(
            {}, mlir::TF::ResourceType::get(
                    {global_tensor.type().cast<mlir::TensorType>()},
                    module.getContext()));
        arg.setType(new_type);
        // Insert a cast back to the original type so existing uses of the
        // argument keep their expected type.
        auto arg_with_original_type = builder.create<mlir::TF::CastOp>(
            global_tensor.getLoc(), old_type, arg,
            /*Truncate=*/builder.getBoolAttr(false));
        arg.replaceAllUsesWith(arg_with_original_type);
        // The RAUW replaces the arg with itself, so we need to set it back.
        arg_with_original_type.setOperand(arg);
      }
      new_input_types.push_back(arg.getType());
    }
    // Update the function type to reflect the (possibly) changed arg types.
    func.setType(mlir::FunctionType::get(
        new_input_types, func.getType().getResults(), module.getContext()));
  }
}
| |
| // Reorder the ops in the module to make testing easier and less dependent |
| // on implementation details such as the order of functions in the |
| // FunctionDefLibrary. |
| // |
| // The order this ensures is: |
| // 1. GlobalTensorOp's |
| // 2. FuncOps's. |
| // |
| // Within each of 1. and 2., ops are sorted by exported name (if |
| // available, and only the first exported name is considered), followed by |
| // non-exported ops. |
| void SortSavedModelModule(mlir::ModuleOp module) { |
| struct NamedGlobalTensor { |
| llvm::StringRef name; |
| mlir::tf_saved_model::GlobalTensorOp global_tensor; |
| }; |
| llvm::SmallVector<NamedGlobalTensor, 8> named_global_tensors; |
| for (auto global_tensor : |
| module.getOps<mlir::tf_saved_model::GlobalTensorOp>()) { |
| auto exported_names = mlir::tf_saved_model::GetExportedNames(global_tensor); |
| // We use stable_sort, so duplicate empty names are fine here. |
| named_global_tensors.push_back( |
| {exported_names.empty() ? "" : exported_names.front(), global_tensor}); |
| } |
| llvm::stable_sort(named_global_tensors, |
| [](const NamedGlobalTensor& a, const NamedGlobalTensor& b) { |
| return std::make_tuple(a.name.empty(), a.name) < |
| std::make_tuple(b.name.empty(), b.name); |
| }); |
| |
| struct NamedFunc { |
| llvm::StringRef name; |
| mlir::FuncOp func; |
| }; |
| llvm::SmallVector<NamedFunc, 8> named_funcs; |
| for (auto func : module.getOps<mlir::FuncOp>()) { |
| auto exported_names = mlir::tf_saved_model::GetExportedNames(func); |
| named_funcs.push_back( |
| {exported_names.empty() ? "" : exported_names.front(), func}); |
| } |
| llvm::stable_sort(named_funcs, [](const NamedFunc& a, const NamedFunc& b) { |
| return std::make_tuple(a.name.empty(), a.name) < |
| std::make_tuple(b.name.empty(), b.name); |
| }); |
| |
| // Move onto the front of the module in reverse of the final desired order. |
| for (auto named_func : llvm::reverse(named_funcs)) { |
| named_func.func.getOperation()->moveBefore(&module.getBody()->front()); |
| } |
| for (auto named_global_tensor : llvm::reverse(named_global_tensors)) { |
| named_global_tensor.global_tensor.getOperation()->moveBefore( |
| &module.getBody()->front()); |
| } |
| } |
| |
// Decorates the raw imported `module` with tf_saved_model dialect constructs:
// - exported functions get `tf_saved_model.exported_names`, per-arg/result
//   `tf_saved_model.index_path` attrs, and `tf_saved_model.bound_input` attrs;
// - variables and constants become GlobalTensorOp's, with variable values
//   read from the checkpoint and constant values taken from the GraphDef.
Status CreateSavedModelIR(
    const ObjectNames& object_names, mlir::ModuleOp module,
    const SavedObjectGraph& object_graph,
    const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
    SavedModelV2Bundle* saved_model) {
  mlir::OpBuilder builder(module.getBodyRegion());
  mlir::SymbolTable symbol_table(module);

  // Create a side data-structure, indexed by the object_graph node_id to
  // a TrackableObject that is restorable.
  absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*>
      restored_objects;
  TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore(
      [&](int saved_node_id,
          const TrackableObjectGraph::TrackableObject& trackable_object) {
        restored_objects.insert(
            std::make_pair(saved_node_id, &trackable_object));
        return Status::OK();
      }));

  for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
    const SavedObject& object = object_graph.nodes(node_id);
    // For correctness, we cannot import functions that don't have exported
    // names, since they don't necessarily have a well-defined ABI (diagnosed
    // earlier).
    //
    // For variables/constants, pruning them is purely an optimization,
    // and more complicated since it requires use-def analysis of which
    // functions use which variables/constants, so we don't do anything
    // special for them here as part of our initial IR construction.
    if (object.kind_case() == SavedObject::kFunction) {
      if (object_names.GetExportedNames(node_id).empty()) {
        continue;
      }
      // Error prefix naming the function, used in all diagnostics below.
      std::string error_context =
          "While importing SavedModel function '" +
          object_names.GetExportedNames(node_id)[0].str() + "': ";
      const SavedFunction& function = object.function();
      // The FunctionDef was already imported as an MLIR function; look it up
      // by its (uniquified) MLIR name.
      auto orig_func = symbol_table.lookup<mlir::FuncOp>(
          tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
      mlir::FuncOp func = orig_func;
      // If there are potentially references to this func from within the
      // module, create a wrapper around it and decorate the wrapper with the
      // tf_saved_model attributes instead.
      if (!mlir::SymbolTable::symbolKnownUseEmpty(orig_func.getName(),
                                                  &module.getBodyRegion())) {
        func = orig_func.cloneWithoutRegions();
        module.insert(module.getBody()->begin(), func);
        func.addEntryBlock();
        func.setName("__sm_exported_" + orig_func.getName().str());
        llvm::SmallVector<mlir::Value, 4> args_as_values;
        for (auto block_argument : func.getArguments()) {
          args_as_values.push_back(block_argument);
        }
        // The wrapper's body is just a call to the original function.
        mlir::OpBuilder body_builder(&func.getBody());
        auto call = body_builder.create<mlir::TF::StatefulPartitionedCallOp>(
            func.getLoc(), orig_func.getType().getResults(), args_as_values,
            builder.getSymbolRefAttr(orig_func.getName()),
            /*config=*/builder.getStringAttr(""),
            /*config_proto=*/builder.getStringAttr(""),
            /*executor_type=*/builder.getStringAttr(""));
        body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults());
      }
      func.setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
      const SavedConcreteFunction& concrete_function =
          object_graph.concrete_functions().at(function.concrete_functions(0));

      // We do not handle the other element of this tuple, which corresponds to
      // Python kwonlyargs, since currently TensorFlow prohibits this in
      // combination with input_signature:
      // https://github.com/tensorflow/tensorflow/blob/8cb8627abb5ef83a6fba34f8fd0e4ee430562eb1/tensorflow/python/eager/function.py#L2027-L2030
      // Our SavedModel import requires input_signature on the tf.function, so
      // we never need to handle the kwonlyargs.
      auto positional_arg_structure =
          concrete_function.canonicalized_input_signature()
              .tuple_value()
              .values(0);
      StructuredValueLinearizer input_linearizer(positional_arg_structure,
                                                 builder.getContext());

      // The trailing `bound_inputs_size` arguments are bound (captured)
      // inputs; the leading ones correspond to the input signature.
      int bound_input_base =
          func.getNumArguments() - concrete_function.bound_inputs_size();
      TF_ASSIGN_OR_RETURN(auto input_index_paths,
                          input_linearizer.GetLeafIndexPaths(
                              error_context + "in input signature: "));
      if (bound_input_base != input_index_paths.size()) {
        return errors::InvalidArgument(
            error_context,
            "Argument mismatch between concrete function input signature "
            "vs underlying FunctionDef for concrete function '",
            function.concrete_functions(0), "' (", input_index_paths.size(),
            " vs ", bound_input_base, ")");
      }
      // Annotate each signature argument with its structured index path.
      for (auto index_path : llvm::enumerate(input_index_paths)) {
        func.setArgAttr(index_path.index(), "tf_saved_model.index_path",
                        index_path.value());
      }

      // Annotate each bound-input argument with a symbol reference to the
      // global tensor it is bound to.
      for (auto& bound_input :
           llvm::enumerate(concrete_function.bound_inputs())) {
        int arg_index = bound_input_base + bound_input.index();
        auto symbol_ref = builder.getSymbolRefAttr(
            object_names.GetSymbolTableName(bound_input.value()));
        func.setArgAttr(arg_index, "tf_saved_model.bound_input", symbol_ref);
      }

      StructuredValueLinearizer output_linearizer(
          concrete_function.output_signature(), builder.getContext());
      TF_ASSIGN_OR_RETURN(auto output_index_paths,
                          output_linearizer.GetLeafIndexPaths(
                              error_context + "in output signature: "));
      if (func.getNumResults() != output_index_paths.size()) {
        return errors::InvalidArgument(
            error_context,
            "Result mismatch between concrete function output signature "
            "vs underlying FunctionDef for concrete function '",
            function.concrete_functions(0), "' (", output_index_paths.size(),
            " vs ", func.getNumResults(), ")");
      }
      // Annotate each result with its structured index path.
      for (auto index_path : llvm::enumerate(output_index_paths)) {
        func.setResultAttr(index_path.index(), "tf_saved_model.index_path",
                           index_path.value());
      }
    } else if (object.kind_case() == SavedObject::kVariable) {
      const SavedVariable& variable = object.variable();
      // Find the trackable in the side data structure.
      auto variable_trackable_it = restored_objects.find(node_id);
      if (variable_trackable_it == restored_objects.end()) {
        return errors::FailedPrecondition("Could not restore saved variable: ",
                                          variable.name());
      }
      const auto* serialized_tensor_attr = FindSerializedTensorInTrackable(
          *variable_trackable_it->second, "VARIABLE_VALUE");
      if (!serialized_tensor_attr) {
        return errors::FailedPrecondition(
            "Could not find serialized tensor for saved variable: ",
            variable.name());
      }
      const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key();

      // Load it from the reader.
      Tensor value;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          saved_model->variable_reader()->Lookup(checkpoint_key, &value),
          "Could not read checkpoint key from variables bundle: ",
          checkpoint_key);
      TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
      // A variable can have a partially known type, such as tensor<?x27x?xf32>,
      // even if the initializer is a specific static shape.
      TF_ASSIGN_OR_RETURN(
          auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(),
                                             &builder));
      auto op = builder.create<mlir::tf_saved_model::GlobalTensorOp>(
          builder.getUnknownLoc(),
          builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
          value_attr,
          /*type=*/mlir::TypeAttr::get(type),
          /*is_mutable=*/builder.getUnitAttr());
      op.setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
    } else if (object.kind_case() == SavedObject::kConstant) {
      const SavedConstant& constant = object.constant();
      // Constant values live in the GraphDef, not the checkpoint.
      const TensorProto* value = ExtractConstTensorFromGraph(
          saved_model->meta_graph_def().graph_def(), constant.operation());
      if (!value) {
        return errors::FailedPrecondition(
            "Unable to find const node referenced in object graph: ",
            constant.operation());
      }
      TF_ASSIGN_OR_RETURN(auto value_attr,
                          ConvertTensorProto(*value, &builder));
      auto op = builder.create<mlir::tf_saved_model::GlobalTensorOp>(
          builder.getUnknownLoc(),
          builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
          value_attr,
          /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
          /*is_mutable=*/nullptr);
      op.setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
    }
  }
  SetResourceSubtypes(module);
  module.setAttr("tf_saved_model.semantics", builder.getUnitAttr());
  SortSavedModelModule(module);
  return Status::OK();
}
| |
// Entry point for SavedModel (V2) import: converts the bundle's GraphDef
// functions to MLIR, computes object names from the object graph, diagnoses
// multi-signature functions, and builds the tf_saved_model dialect IR.
StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert(
    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
    absl::Span<std::string> exported_names, bool add_default_attributes) {
  // Fall back to an empty GraphDebugInfo when the bundle has none.
  GraphDebugInfo dummy_debug_info;
  const GraphDebugInfo& debug_info =
      saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;

  GraphImportConfig specs;
  mlir::OwningModuleRef module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;

  const auto& graphdef = saved_model->meta_graph_def().graph_def();
  PopulateTfVersions(module.get(), graphdef.versions());

  GraphConstructorOptions options;
  options.allow_internal_ops = true;
  options.add_default_attributes = add_default_attributes;
  Graph graph(OpRegistry::Global());

  GraphDef preprocessed_graphdef(graphdef);
  if (add_default_attributes) {
    TF_RETURN_IF_ERROR(PreprocessGraphDef(nullptr, &preprocessed_graphdef));
  }

  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));

  NameUniquifier function_name_uniquifier(graph.flib_def());
  SavedModelImporter importer(graph.flib_def(), debug_info, specs, module.get(),
                              &tf_name_to_mlir_name, &function_name_uniquifier);

  // Import every function in the library; tf_name_to_mlir_name records the
  // TF-name -> MLIR-name mapping used later by CreateSavedModelIR.
  auto fn_names = graph.flib_def().ListFunctionNames();
  for (const auto& fn_name : fn_names) {
    TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
  }

  if (!saved_model->meta_graph_def().has_object_graph_def()) {
    return errors::InvalidArgument(
        "SavedModel does not have an object graph. Please use TF2.");
  }
  auto& object_graph = saved_model->meta_graph_def().object_graph_def();
  ObjectNames object_names(object_graph, exported_names);

  // Clean up a couple func's that always seem to be present when importing a
  // SavedModel. This is not strictly needed, as there is a separate pass that
  // will clean them up, but this makes staring at the raw IR of minimal
  // examples quite a bit nicer.
  for (auto func : llvm::make_early_inc_range(module->getOps<mlir::FuncOp>())) {
    if (func.getName().startswith("__inference__traced_save_") ||
        func.getName().startswith("__inference__traced_restore_") ||
        func.getName().startswith("__inference_signature_wrapper_")) {
      func.erase();
    }
  }

  // Diagnose SavedFunction's with multiple input signatures.
  TF_RETURN_IF_ERROR(
      DiagnoseMultipleConcreteFunctions(object_graph, object_names));

  // Construct the SavedModel IR.
  TF_RETURN_IF_ERROR(CreateSavedModelIR(object_names, module.get(),
                                        object_graph, tf_name_to_mlir_name,
                                        saved_model));
  assert(mlir::succeeded(mlir::verify(module.get())));

  return module;
}
| |
// A helper class to import a TensorFlow model expressed in SavedModel V1 into
// an MLIR Module in SavedModel dialect.
class SavedModelV1Importer {
 public:
  // Main entry point: converts all functions (specified by SignatureDefs) in
  // the given meta graph to an MLIR Module.
  static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle,
                                                 mlir::MLIRContext* context) {
    SavedModelV1Importer importer(bundle, context);

    return importer.ConvertSignatures();
  }

 private:
  SavedModelV1Importer(const SavedModelBundle& bundle,
                       mlir::MLIRContext* context)
      : bundle_(bundle),
        module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}

  // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
  // for each signature.
  StatusOr<mlir::OwningModuleRef> ConvertSignatures();
  // Converts a single SignatureDef (identified by `sig_def_key`) to an MLIR
  // function spliced into `module_`.
  Status ConvertSignature(
      const GraphDef& graphdef, const std::string& sig_def_key,
      const std::map<std::string, TensorInfo>& inputs_sorted,
      const std::map<std::string, TensorInfo>& outputs_sorted,
      const GraphDebugInfo& debug_info,
      const FunctionLibraryDefinition& flib_def);

  // Creates GlobalTensorOp for each variable and moves each VarHandle op to
  // the enclosing function's arguments.
  Status LiftVariables();
  // Moves the result of the VarHandleOp to the enclosing function's argument
  // list and erases this VarHandleOp.
  void LiftVariable(mlir::TF::VarHandleOp op);

  // Reads all variables from the SavedModel through session and creates
  // GlobalTensorOp for these variables.
  Status ReadVariablesFromSession(
      const llvm::SmallVectorImpl<mlir::TF::VarHandleOp>& ops);

  // Translates SignatureDef input TensorInfos into GraphImportConfig inputs.
  GraphImportConfig::InputArrays ParseInputArrays(
      const std::map<std::string, TensorInfo>& inputs);

  // Translates SignatureDef output TensorInfos into output array names.
  std::vector<std::string> ParseOutputArrays(
      const std::map<std::string, TensorInfo>& outputs);

  // The bundle being imported; not owned.
  const SavedModelBundle& bundle_;
  // The module being built; ownership is released to the caller by
  // ConvertSignatures.
  mlir::OwningModuleRef module_;
};
| |
// Converts every user-visible SignatureDef in the bundle to an MLIR function,
// lifts variables into GlobalTensorOps/bound inputs, tags the module with
// tf_saved_model semantics, and returns the finished module.
StatusOr<mlir::OwningModuleRef> SavedModelV1Importer::ConvertSignatures() {
  const auto& signatures = bundle_.GetSignatures();
  const auto& graphdef = bundle_.meta_graph_def.graph_def();
  PopulateTfVersions(module_.get(), graphdef.versions());

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), graphdef.library());

  // debug_info might not be loaded with loader_lite.
  GraphDebugInfo debug_info;
  if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;

  for (const auto& key_and_signature_def : signatures) {
    const std::string& sig_def_key = key_and_signature_def.first;
    const SignatureDef& signature_def = key_and_signature_def.second;

    // It is safe to skip "__saved_model_init_op" since it is an internal
    // signature that is not user-accessible.
    if (sig_def_key == "__saved_model_init_op") {
      continue;
    }

    // protobuf::Map doesn't provide stable iteration order so use std::map
    std::map<std::string, TensorInfo> inputs_sorted(
        signature_def.inputs().begin(), signature_def.inputs().end());
    std::map<std::string, TensorInfo> outputs_sorted(
        signature_def.outputs().begin(), signature_def.outputs().end());

    TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, inputs_sorted,
                                        outputs_sorted, debug_info, flib_def));
  }
  TF_RETURN_IF_ERROR(LiftVariables());

  mlir::OpBuilder builder(module_->getBodyRegion());
  module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
  SortSavedModelModule(*module_);

  return std::move(module_);
}
| |
// Converts one SignatureDef: prunes the GraphDef to the signature's fanin,
// imports it as a sub-module, attaches tf_saved_model attributes to the
// resulting function, and splices its ops into the top-level module.
Status SavedModelV1Importer::ConvertSignature(
    const GraphDef& graphdef, const std::string& sig_def_key,
    const std::map<std::string, TensorInfo>& inputs_sorted,
    const std::map<std::string, TensorInfo>& outputs_sorted,
    const GraphDebugInfo& debug_info,
    const FunctionLibraryDefinition& flib_def) {
  GraphImportConfig specs;
  specs.inputs = ParseInputArrays(inputs_sorted);
  specs.outputs = ParseOutputArrays(outputs_sorted);

  // Remove unused nodes and create sub-graphdef.
  GraphDef sub_graph_def;
  TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
      graphdef, &sub_graph_def,
      /*terminal_nodes=*/{specs.outputs.begin(), specs.outputs.end()}));

  // Convert sub-graphdef to sub-graph.
  GraphConstructorOptions options;
  options.allow_internal_ops = true;
  options.add_default_attributes = true;
  Graph sub_graph(OpRegistry::Global());

  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(options, sub_graph_def, &sub_graph));

  // Convert sub-graph to MLIR module.
  TF_ASSIGN_OR_RETURN(
      auto sub_module,
      GraphDefImporter::Convert(module_->getContext(), sub_graph, debug_info,
                                flib_def, specs, sig_def_key));
  mlir::OpBuilder builder(sub_module->getBodyRegion());

  // Find the FuncOp which corresponds to current SignatureDef.
  mlir::SymbolTable symbol_table(*sub_module);
  auto func_op = symbol_table.lookup<mlir::FuncOp>(sig_def_key);
  TF_RET_CHECK(func_op)
      << "Graphdef importer should have created a function named "
      << sig_def_key << ".";

  // Use unique SignatureDef key as exported name.
  func_op.setAttr("tf_saved_model.exported_names",
                  builder.getStrArrayAttr({sig_def_key}));

  // Transfer input and output parameter names to index_path attributes.
  // (inputs_sorted/outputs_sorted are std::map, so enumeration order matches
  // the sorted parameter names.)
  for (auto input_and_idx : llvm::enumerate(inputs_sorted)) {
    func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path",
                       builder.getStrArrayAttr({input_and_idx.value().first}));
  }
  for (auto output_and_idx : llvm::enumerate(outputs_sorted)) {
    func_op.setResultAttr(
        output_and_idx.index(), "tf_saved_model.index_path",
        builder.getStrArrayAttr({output_and_idx.value().first}));
  }

  // Move the converted functions to top level MLIR module.
  // Everything except the sub-module's terminator is spliced in before the
  // top-level module's terminator.
  auto* block = module_->getBody();
  auto* sub_block = sub_module->getBody();
  block->getOperations().splice(
      mlir::Block::iterator(block->getTerminator()), sub_block->getOperations(),
      sub_block->begin(), mlir::Block::iterator(sub_block->getTerminator()));

  return Status::OK();
}
| |
| Status SavedModelV1Importer::LiftVariables() { |
| llvm::SmallVector<mlir::TF::VarHandleOp, 4> ops; |
| |
| bool contains_ref_variable = false; |
| |
| module_->walk([&ops, &contains_ref_variable](mlir::Operation* op) { |
| if (auto var_handle_op = llvm::dyn_cast<mlir::TF::VarHandleOp>(op)) |
| ops.push_back(var_handle_op); |
| else if (op->getName().getStringRef() == "tf.VariableV2") |
| contains_ref_variable = true; |
| }); |
| |
| if (contains_ref_variable) |
| return errors::InvalidArgument( |
| "Ref variable created by VariableV2 is not supported."); |
| |
| if (ops.empty()) return Status::OK(); |
| |
| TF_RETURN_IF_ERROR(ReadVariablesFromSession(ops)); |
| |
| for (auto op : ops) LiftVariable(op); |
| |
| return Status::OK(); |
| } |
| |
// Rewrites one VarHandleOp into a function argument: appends the resource
// type to the enclosing function's signature, binds the new argument to the
// corresponding global tensor via tf_saved_model.bound_input, redirects all
// uses of the VarHandleOp's result to the new argument, and erases the op.
void SavedModelV1Importer::LiftVariable(mlir::TF::VarHandleOp op) {
  mlir::OpBuilder builder(&module_->getBodyRegion());

  auto func_op = op.getParentOfType<mlir::FuncOp>();
  builder.setInsertionPoint(func_op);

  auto func_type = func_op.getType();

  // Create the new function type by adding variable type to the arguments.
  llvm::SmallVector<mlir::Type, 4> new_input_types(
      func_type.getInputs().begin(), func_type.getInputs().end());
  new_input_types.push_back(op.resource().getType());
  auto new_func_type =
      builder.getFunctionType(new_input_types, func_type.getResults());

  func_op.setType(new_func_type);

  // Bind the argument to the corresponding global tensor op.
  // Note: the type was already updated above, so getNumArguments() - 1 is
  // the index of the argument being added.
  func_op.setArgAttr(func_op.getNumArguments() - 1,
                     "tf_saved_model.bound_input",
                     builder.getSymbolRefAttr(op.shared_name()));

  // Add the newly added function param to entry block's arguments.
  auto new_value = func_op.front().addArgument(op.resource().getType());

  // Remove the VarHandleOp.
  op.getOperation()->replaceAllUsesWith(llvm::ArrayRef<mlir::Value>(new_value));
  op.getOperation()->erase();
}
| |
// Reads the current value of every variable referenced by `ops` from the
// loaded session and creates one tf_saved_model::GlobalTensorOp per unique
// shared name, holding that value as a constant attribute.
Status SavedModelV1Importer::ReadVariablesFromSession(
    const llvm::SmallVectorImpl<mlir::TF::VarHandleOp>& ops) {
  mlir::OpBuilder builder(&module_->getBodyRegion());

  // Find all variables and their corresponding read ops.
  // MapVector both de-duplicates VarHandleOps that share a name and keeps a
  // deterministic iteration order matching the session fetch order below.
  llvm::MapVector<llvm::StringRef, mlir::TF::VarHandleOp>
      variable_names_and_ops;
  for (auto op : ops) {
    variable_names_and_ops[op.shared_name()] = op;
  }

  // Read all resource variables from the session.
  std::vector<std::string> variable_names;
  variable_names.reserve(variable_names_and_ops.size());
  for (const auto& name_and_location : variable_names_and_ops)
    variable_names.push_back(std::string(name_and_location.first));

  // Fetching a resource variable by name yields a scalar ResourceHandle
  // tensor, not the variable's value.
  std::vector<Tensor> resource_tensors;
  TF_RETURN_IF_ERROR(bundle_.GetSession()->Run(
      /*inputs=*/{}, variable_names,
      /*target_node_names=*/{}, &resource_tensors));

  const DeviceMgr* device_manager;
  TF_RETURN_IF_ERROR(bundle_.GetSession()->LocalDeviceManager(&device_manager));

  // Read all underlying tensors of the variables from the session.
  std::vector<Tensor> tensors;
  tensors.reserve(resource_tensors.size());
  for (const auto& resource_tensor : resource_tensors) {
    const auto& resource_handle = resource_tensor.scalar<ResourceHandle>()();

    // Resolve the handle to the Var living in the owning device's resource
    // manager.
    Device* device;
    TF_RETURN_IF_ERROR(
        device_manager->LookupDevice(resource_handle.device(), &device));

    Var* var_ptr;
    TF_RETURN_IF_ERROR(device->resource_manager()->Lookup(
        resource_handle.container(), resource_handle.name(), &var_ptr));
    // RefCountPtr releases the reference Lookup() added.
    core::RefCountPtr<Var> var(var_ptr);

    // The variable tensor is already loaded into corresponding device's
    // resource manager when we load the saved model using LoadSavedModel().
    // Here we just read its value.
    mutex_lock ml(*var->mu());
    tensors.push_back(*var->tensor());
  }

  // `tensors` is index-aligned with `variable_names_and_ops`, so zip them to
  // emit one global tensor per variable, located at its VarHandleOp.
  for (const auto& iter : llvm::zip(variable_names_and_ops, tensors)) {
    const auto& name = std::get<0>(iter).first;
    auto location = std::get<0>(iter).second.getLoc();
    const auto& tensor = std::get<1>(iter);

    // Create tensor attribute for this variable.
    TF_ASSIGN_OR_RETURN(auto tensor_attr, ConvertTensor(tensor, &builder));

    builder.create<mlir::tf_saved_model::GlobalTensorOp>(
        location, builder.getStringAttr(name), tensor_attr,
        mlir::TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
  }

  return Status::OK();
}
| |
| GraphImportConfig::InputArrays SavedModelV1Importer::ParseInputArrays( |
| const std::map<std::string, TensorInfo>& inputs) { |
| GraphImportConfig::InputArrays results; |
| for (const auto& iter : inputs) { |
| const auto& tensor_info = iter.second; |
| |
| // Only dense tensor is supported. |
| DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName); |
| |
| ArrayInfo array_info; |
| array_info.imported_dtype = tensor_info.dtype(); |
| array_info.shape = tensor_info.tensor_shape(); |
| |
| std::vector<std::string> node_names = |
| absl::StrSplit(tensor_info.name(), ':'); |
| |
| results.insert(std::pair<std::string, ArrayInfo>(node_names.at(0), |
| std::move(array_info))); |
| } |
| return results; |
| } |
| |
| std::vector<std::string> SavedModelV1Importer::ParseOutputArrays( |
| const std::map<std::string, TensorInfo>& outputs) { |
| std::vector<std::string> results; |
| for (const auto& iter : outputs) { |
| const auto& tensor_info = iter.second; |
| |
| std::vector<std::string> node_names = |
| absl::StrSplit(tensor_info.name(), ':'); |
| results.push_back(node_names.at(0)); |
| } |
| return results; |
| } |
| |
| } // namespace |
| |
// Rewrites legacy (TF v1) control flow in `graph` into functional form via
// FunctionalizeControlFlow, mutating the graph and function library in place.
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) {
  return FunctionalizeControlFlow(graph, flib_def);
}
| |
| StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir( |
| const GraphDef& graphdef, const GraphDebugInfo& debug_info, |
| const GraphImportConfig& specs, mlir::MLIRContext* context, |
| bool add_default_attributes) { |
| GraphConstructorOptions options; |
| options.allow_internal_ops = true; |
| options.add_default_attributes = add_default_attributes; |
| Graph graph(OpRegistry::Global()); |
| |
| GraphDef preprocessed_graphdef(graphdef); |
| if (add_default_attributes) { |
| TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef)); |
| } |
| TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( |
| options, std::move(preprocessed_graphdef), &graph)); |
| return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs, |
| context); |
| } |
| |
// Converts a tensorflow::Graph (plus its function library and debug info)
// into an MLIR module whose entry function is named "main". When
// specs.upgrade_legacy is set, legacy control flow is functionalized first,
// mutating the graph/flib in place despite the const parameters.
StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
    const Graph& graph, const GraphDebugInfo& debug_info,
    const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
    mlir::MLIRContext* context) {
  // TODO(jpienaar): Remove need to const_cast.
  if (specs.upgrade_legacy) {
    TF_RETURN_IF_ERROR(
        UpgradeLegacyGraph(const_cast<Graph*>(&graph),
                           const_cast<FunctionLibraryDefinition*>(&flib_def)));
  }
  return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
                                   /*func_name=*/"main");
}
| |
// Converts a v2 SavedModel bundle to an MLIR module, restricted to the given
// exported names (empty span means all). Thin wrapper over
// SavedModelImporter::Convert.
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
    absl::Span<std::string> exported_names, bool add_default_attributes) {
  return SavedModelImporter::Convert(saved_model, context, exported_names,
                                     add_default_attributes);
}
| |
// Converts a v1 SavedModel bundle to an MLIR module. Thin wrapper over
// SavedModelV1Importer::Convert.
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
    const SavedModelBundle& saved_model, mlir::MLIRContext* context) {
  return SavedModelV1Importer::Convert(saved_model, context);
}
| |
| std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) { |
| std::string txt_module; |
| { |
| mlir::OpPrintingFlags flags; |
| if (show_debug_info) flags.enableDebugInfo(); |
| llvm::raw_string_ostream os{txt_module}; |
| module.print(os, flags); |
| } |
| return txt_module; |
| } |
| |
| } // namespace tensorflow |