| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" |
| |
| #include <iterator> |
| #include <memory> |
| #include <string> |
| #include <tuple> |
| #include <type_traits> |
| #include <unordered_set> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/container/inlined_vector.h" |
| #include "absl/strings/escaping.h" |
| #include "absl/strings/numbers.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/str_join.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/strings/strip.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/DenseSet.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/ADT/Twine.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/SourceMgr.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Diagnostics.h" // from @llvm-project |
| #include "mlir/IR/Identifier.h" // from @llvm-project |
| #include "mlir/IR/Location.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/OpDefinition.h" // from @llvm-project |
| #include "mlir/IR/Types.h" // from @llvm-project |
| #include "mlir/IR/Verifier.h" // from @llvm-project |
| #include "mlir/Pass/PassManager.h" // from @llvm-project |
| #include "tensorflow/cc/saved_model/constants.h" |
| #include "tensorflow/cc/saved_model/loader_util.h" |
| #include "tensorflow/compiler/jit/shape_inference_helpers.h" |
| #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" |
| #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" |
| #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/graph_constructor.h" |
| #include "tensorflow/core/common_runtime/shape_refiner.h" |
| #include "tensorflow/core/framework/attr_value.pb.h" |
| #include "tensorflow/core/framework/function.pb.h" |
| #include "tensorflow/core/framework/graph.pb.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/resource_var.h" |
| #include "tensorflow/core/framework/shape_inference.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/framework/versions.pb.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/graph/node_builder.h" |
| #include "tensorflow/core/graph/tensor_id.h" |
| #include "tensorflow/core/grappler/utils/transitive_fanin.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/path.h" |
| #include "tensorflow/core/platform/protobuf.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/protobuf/graph_debug_info.pb.h" |
| #include "tensorflow/core/protobuf/meta_graph.pb.h" |
| #include "tensorflow/core/protobuf/saved_object_graph.pb.h" |
| #include "tensorflow/core/protobuf/saver.pb.h" |
| #include "tensorflow/core/protobuf/struct.pb.h" |
| #include "tensorflow/core/protobuf/trackable_object_graph.pb.h" |
| #include "tensorflow/stream_executor/lib/statusor.h" |
| |
| static inline absl::string_view StringRefToView(llvm::StringRef ref) { |
| return {ref.data(), ref.size()}; |
| } |
| |
| namespace tensorflow { |
| using mlir::NamedAttrList; |
| using mlir::TensorType; |
| using mlir::tf_saved_model::AssetOp; |
| using mlir::tf_saved_model::GlobalTensorOp; |
| using mlir::tf_saved_model::SessionInitializerOp; |
| using stream_executor::port::StatusOr; |
| |
| namespace { |
| |
| bool IsOutputShapesAttribute(const AttrValue& attr_value, |
| llvm::StringRef attr_name) { |
| return attr_name.compare("_output_shapes") == 0 && |
| attr_value.value_case() == AttrValue::kList; |
| } |
| |
| bool IsResourceOutputShapesAttribute(const AttrValue& attr_value, |
| llvm::StringRef attr_name) { |
| if (attr_name == "_handle_dtypes" || attr_name == "_handle_shapes") |
| return attr_value.value_case() == AttrValue::kList; |
| return false; |
| } |
| |
| void LoadImporterDialects(mlir::MLIRContext& context) { |
| // Load dialects involved in the conversion |
| mlir::DialectRegistry registry; |
| mlir::RegisterAllTensorFlowDialects(registry); |
| context.appendDialectRegistry(registry); |
| context.loadAllAvailableDialects(); |
| } |
| |
| // This class is used to generate new MLIR function name strings that are both |
| // unique in the TF function library `flib_` and unique among the name strings |
| // generated by the class object during its lifetime. |
| // |
| // In theory, this class is not necessary because we should simply take |
| // the TF function name and use it as MLIR function name. However, for some |
| // unknown reasons (callout for investigation in b/142268695), keeping the |
| // function names unchanged in an MLIR roundtrip causes test failures. |
| // TODO(b/142268695) Re-evaluate whether we need this class v.s. directly using |
| // and TF function name as MLIR function name after b/142268695 is root caused. |
| class NameUniquifier : public OpOrArgNameMapper { |
| public: |
| explicit NameUniquifier(const FunctionLibraryDefinition& flib) |
| : flib_(flib) {} |
| |
| private: |
| bool IsUnique(llvm::StringRef name) override { |
| return !flib_.Contains(std::string(name)); |
| } |
| |
| std::string GetName(OpOrVal op_or_val) override { |
| DCHECK(false) << "Unimplemented"; |
| return ""; |
| } |
| |
| const FunctionLibraryDefinition& flib_; |
| }; |
| |
| // Stateful helper class to import a TensorFlow model into an MLIR Module. |
| // |
| // This is the base class that contains common utilities shared between the |
| // GraphDef importer and SavedModel importer. |
| // |
| // A subclass is expected to call `PrepareConvert` first to perform necessary |
| // preparation over the graph and also certain internal bookkeeping data. |
| // Afterwards the other protected methods can be called. |
| class ImporterBase { |
| protected: |
| explicit ImporterBase( |
| const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, |
| const GraphImportConfig& specs, mlir::ModuleOp module, |
| std::unordered_map<std::string, std::string>* tf_name_to_mlir_name, |
| NameUniquifier* function_name_uniquifier, |
| llvm::StringRef function_name_for_debug_info = "") |
| : builder_(module.getContext()), |
| module_(module), |
| context_(module.getContext()), |
| tf_name_to_mlir_name_(tf_name_to_mlir_name), |
| graph_flib_(flib), |
| specs_(specs), |
| debug_info_(debug_info), |
| function_name_for_debug_info_(function_name_for_debug_info), |
| function_name_uniquifier_(function_name_uniquifier), |
| error_handler_(module.getContext()) {} |
| |
| // Returns the inferred function signature of the given function body. Input |
| // types are unranked tensor of the respective datatype in the function and |
| // result types are inferred by the shape_refiner_. Result types need not be |
| // unranked tensors and could be ranked tensors in cases where result type |
| // depends on an op with static output shape like tf.Const. |
| StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody); |
| |
| // Extracts arg and ret nodes from FunctionBody. |
| void GetArgsAndRetsFromFunctionBody( |
| const FunctionBody& fbody, |
| absl::InlinedVector<OutputTensor, 4>* arg_nodes, |
| absl::InlinedVector<OutputTensor, 4>* ret_nodes, |
| absl::InlinedVector<Node*, 4>* control_ret_nodes); |
| |
| // Prepares converting the graph to an MLIR module. This step removes the |
| // backedges of the graph, orders the nodes and infers the shapes. |
| Status PrepareConvert(const Graph& graph); |
| |
| // Converts the prepared graph to a Function and adds it to the module. A set |
| // of nodes from the graph are given to converted to the arguments and returns |
| // of the function. |
| Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type, |
| const absl::InlinedVector<OutputTensor, 4>& arg_nodes, |
| const absl::InlinedVector<OutputTensor, 4>& ret_nodes, |
| const absl::InlinedVector<Node*, 4>& control_ret_nodes, |
| llvm::ArrayRef<mlir::NamedAttribute> attrs); |
| |
| // Finds out the function definition for the given function name from the |
| // graph and converts it to a function of the module. This method is called |
| // on demand because the graph flib_def does not provide an iterator |
| // interface. |
| Status ConvertLibFunction(llvm::StringRef func_name); |
| |
| // Returns the list of nodes in the graph. Nodes are presented in the reverse |
| // order of a post-order depth-first visit starting from the graph's source |
| // nodes. |
| llvm::ArrayRef<Node*> GetOrderedNodes() const { return ordered_nodes_; } |
| |
| // Returns the inferred input type at index `idx` of the `node` in the |
| // context. |
| StatusOr<mlir::Type> InferInputType(const Node& node, int idx, |
| mlir::Builder builder); |
| |
| // Returns the inferred output type at index `idx` of the `node` in the |
| // context. |
| StatusOr<mlir::Type> InferOutputType(const Node& node, int idx, |
| mlir::Builder builder); |
| |
| private: |
| // Most types with subtypes have only one subtype. |
| using ElementSubtypes = llvm::SmallVector<TensorType, 1>; |
| |
| // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all |
| // data type and shape information is maintained by the shape_refiner_. |
| // TODO(jpienaar): Remove once shape inference on import is removed. |
| Status AddNodesToShapeRefiner( |
| std::unordered_map<string, Node*>* node_name_map); |
| |
| // Prune nodes that do not feed into fetch nodes. |
| Status PruneUnreachableNodes( |
| std::unordered_map<string, Node*>* node_name_map); |
| |
| // Converts feeds to Placeholder nodes. |
| Status ConvertFeedsToPlaceholders( |
| std::unordered_map<string, Node*>* node_name_map); |
| |
| // Converts the inferred shape referred to by 'handle' in 'context', with |
| // given element type, and returns an MLIR tensor type. |
| StatusOr<TensorType> ConvertDataTypeAndShape( |
| DataType dtype, const shape_inference::ShapeHandle& handle, |
| const std::vector<shape_inference::ShapeAndType>* handle_subtypes, |
| shape_inference::InferenceContext* context, mlir::Builder builder); |
| |
| // Converts the inferred shape referred to by 'handle' in 'context', with |
| // given element type, and returns an MLIR tensor type. |
| StatusOr<TensorType> ConvertElementTypeAndShape( |
| mlir::Type element_type, const shape_inference::ShapeHandle& handle, |
| shape_inference::InferenceContext* context, mlir::Builder builder); |
| |
| // Converts the inferred subtypes for an element type to corresponding MLIR |
| // types in 'context'. |
| StatusOr<ElementSubtypes> ConvertSubtypes( |
| const std::vector<shape_inference::ShapeAndType>* handle_subtypes, |
| shape_inference::InferenceContext* context, mlir::Builder builder); |
| |
| // Converts the tensor proto into an MLIR elements attribute. |
| StatusOr<mlir::ElementsAttr> ConvertTensorProto(const TensorProto& value) { |
| return ::tensorflow::ConvertTensorProto(value, &builder_); |
| } |
| |
| // Converts func name in graphdef to mlir::SymbolRefAttribute. |
| StatusOr<mlir::FlatSymbolRefAttr> ConvertFunctionCallName( |
| const std::string& func_name); |
| |
| // Converts the given non-function-call AttrValue to an MLIR Attribute. |
| StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value); |
| |
| // Converts the given function-call AttrValue to MLIR Attributes and pushes |
| // them to the given attributes list. For example, if there is a kFunc |
| // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to |
| // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar}, |
| // {base_name.k2 : rfc}}. |
| Status ConvertFunctionCallAttribute(const std::string& base_name, |
| const AttrValue& value, |
| NamedAttrList* attributes); |
| |
| // Helper to create either a tf_executor operation or a TF operation wrapped |
| // in an island. |
| mlir::Operation* CreateOperation( |
| const Node& node, llvm::StringRef node_type_name, |
| const mlir::OperationState& result, |
| const llvm::SmallVectorImpl<mlir::Value>& control_operands); |
| |
| // Converts one NodeDef from the input GraphDef into an Operation and |
| // inserts it into the MLIR module using builder_. |
| Status ConvertNode(const Node& node); |
| |
| // If the input graph represents a while-loop, the edges pointing from a |
| // "NextIteration" node to a "Merge" node add cyclic dependencies and make the |
| // topological sorting impossible. We need to remove these edges from the |
| // input graph to infer shapes and construct a Function. For each |
| // "NextIteration" node, there are two operations, "NextIteration.source" |
| // and "NextIteration.sink" are added to the MLIR module. |
| using BackEdge = BackEdgeHelper::BackEdge; |
| |
| // Removes backedges from the input graph. The removed edges are added back to |
| // to OpBuilder after the remaining graph is converted to the Function. |
| Status RemoveBackedges(const Graph& graph); |
| |
| // Restores backedges removed during shape inference to the final Function. |
| Status AddBackedges(); |
| |
| // Restores a single backedge in the Function by adding a replicated |
| // operation before the dst operation. |
| Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst, |
| int dst_input); |
| |
| // Adds the input arguments and return operation to the function. The |
| // arguments are added as basic block argument. Also the argument types and |
| // the id of the nodes from the input graph needs to be specified. |
| Status ConvertFunctionArgAndRets( |
| mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op, |
| llvm::ArrayRef<mlir::Type> arg_types, |
| const absl::InlinedVector<OutputTensor, 4>& arg_nodes, |
| const absl::InlinedVector<OutputTensor, 4>& ret_nodes, |
| const absl::InlinedVector<Node*, 4>& control_ret_nodes); |
| |
| // Gets the location information of the given node. It uses the |
| // "original_node_name" in the NodeDef to get the corresponding file location |
| // (FileLineColLoc) from the input DebugInfo and returns an CallSiteLoc. If |
| // there are multiple "original_node_names", a FusedLoc is returned. If the |
| // node name couldn't be found in the input DebugInfo, a NameLoc is used as |
| // the location. |
| mlir::Location GetLocation(const Node& node); |
| |
| // Appends the location string for the node to the error message and returns |
| // the combined error status. |
| Status EmitErrorWithLocationStr(const Node& node, const Status& error_status); |
| |
| // Inserts a placeholder node in the graph to replace a feed output tensor, |
| // and returns the new placeholder node and a boolean indicating if the |
| // original input node was removed from the graph. Uses of the feed output |
| // tensor are replaced with this placeholder node. If the feed output tensor |
| // is of a single output node, the control dependencies are forwarded to the |
| // the placeholder node, and the original node will be removed. |
| // Note: This modifies the graph, and so any list of ordered nodes needs to be |
| // reconstructed. |
| StatusOr<std::pair<Node*, bool>> CreatePlaceholderNodeForFeed( |
| const TensorShapeProto& shape, DataType dtype, Node* node, int index, |
| const std::unordered_map<string, Node*>& node_name_map); |
| |
| // Gets the input and output nodes corresponding to the specified input and |
| // output nodes in specs_. If there are no input or output nodes specified, |
| // nodes will be empty. |
| Status GetInputOutputNodes( |
| const std::unordered_map<string, Node*>& node_name_map, |
| std::unordered_set<const Node*>* nodes); |
| |
| llvm::StringSet<>& GetUnmodelledOpTypes() { |
| // All the TF ops encountered that aren't modelled in dialect. |
| static auto* unmodelled_op_types = new llvm::StringSet<>(); |
| return *unmodelled_op_types; |
| } |
| |
| // The input graph with backedges removed. The removed backedges are stored |
| // in the back_edge_helper. |
| BackEdgeHelper back_edge_helper_; |
| // A map between node and output index, for each backedge. |
| absl::flat_hash_map<const Node*, int> back_edge_node_output_; |
| absl::flat_hash_map<const Node*, BackEdge> back_edge_dst_inputs_; |
| // A map between sink and source operation of NextIteration |
| absl::flat_hash_map<mlir::Operation*, mlir::Operation*> |
| next_iteration_sink_source_; |
| |
| // All nodes and version information about the (copied) imported graph. |
| std::unique_ptr<Graph> graph_; |
| std::vector<Node*> ordered_nodes_; |
| |
| // Maps from a Node ID to a MLIR value. |
| using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>; |
| |
| mlir::OpBuilder builder_; |
| mlir::ModuleOp module_; |
| mlir::MLIRContext* context_; |
| std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_; |
| const FunctionLibraryDefinition& graph_flib_; |
| const GraphImportConfig& specs_; |
| const GraphDebugInfo& debug_info_; |
| llvm::StringRef function_name_for_debug_info_; |
| NodeValueMap node_values_; |
| // TODO(jpienaar): Remove once shape inference on import is removed. |
| // The shape_refinner_ will be nullptr if shape inference on import is |
| // not enabled. |
| std::unique_ptr<ShapeRefiner> shape_refiner_ = nullptr; |
| NameUniquifier* function_name_uniquifier_; |
| mlir::StatusScopedDiagnosticHandler error_handler_; |
| |
| protected: |
| // Maps feed as TensorId to new Placeholder node name. |
| absl::flat_hash_map<TensorId, absl::string_view> remapped_feeds_; |
| }; |
| |
| // Returns true if the node with given name has a non primary output that is |
| // used by some other node as an input. Returns false if no outputs are in use |
| // or only the first output is in use. |
| bool HasNonPrimaryOutputInUse(const GraphDef& graph_def, |
| const std::string& node) { |
| for (const auto& node_def : graph_def.node()) { |
| for (const auto& input : node_def.input()) { |
| if (absl::StartsWith(input, node + ":") && input != node + ":0") { |
| return true; |
| } |
| } |
| } |
| return false; |
| } |
| |
| // Updates the given LegacyFedInput node with Placeholder node if it is one of |
| // the inputs. Returns an error if non primary output of the LegacyFedInput node |
| // is in use and therefore can not be replaced by the Placeholder node that only |
| // has a single output. |
| Status UpdateLegacyFedInputNode(const GraphDef& graph_def, |
| const GraphImportConfig::InputArrays& inputs, |
| NodeDef* node) { |
| const std::string& node_name = node->name(); |
| auto it = inputs.find(node_name); |
| |
| // Node is not an input. |
| if (it == inputs.end()) return Status::OK(); |
| |
| if (HasNonPrimaryOutputInUse(graph_def, node_name)) { |
| return errors::InvalidArgument( |
| "LegacyFedInput node ", node->name(), |
| " has non primary output in use and can not be replaced with " |
| "Placeholder node"); |
| } |
| |
| DataType dtype = it->second.imported_dtype; |
| // Uses the existing output type if it isn't specified by the user. |
| if (dtype == DT_INVALID) { |
| dtype = node->attr().at("output_types").list().type(0); |
| } |
| // Update op name, drop inputs and set attributes required by the Placeholder |
| // op. |
| *node->mutable_op() = "Placeholder"; |
| node->clear_attr(); |
| node->clear_input(); |
| AddNodeAttr("dtype", dtype, node); |
| AddNodeAttr("shape", it->second.shape, node); |
| return Status::OK(); |
| } |
| |
| // Preprocesses GraphDef before it can be converted to Graph by, |
| // - Adding the default attributes to each node def if they are missing from |
| // the GraphDef. |
| // - Replacing LegacyFedInput nodes with Placeholder nodes if |
| // convert_legacy_fed_inputs option is enabled. |
| Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) { |
| for (auto& node_def : *graph_def->mutable_node()) { |
| // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One |
| // solution could be have a tool to let users upgrade old serialized graphs. |
| if (specs && specs->convert_legacy_fed_inputs && |
| node_def.op() == "LegacyFedInput") { |
| TF_RETURN_IF_ERROR( |
| UpdateLegacyFedInputNode(*graph_def, specs->inputs, &node_def)); |
| } |
| |
| const tensorflow::OpRegistrationData* op_reg_data = |
| tensorflow::OpRegistry::Global()->LookUp(node_def.op()); |
| if (!op_reg_data) { |
| // This is likely a function call node, so we should continue. |
| continue; |
| } |
| ::tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, &node_def); |
| } |
| return Status::OK(); |
| } |
| |
| // Mapping from node name to feed (index and ArrayInfo). Node name must outlive |
| // this map. |
| using FeedsByNode = absl::flat_hash_map< |
| absl::string_view, |
| absl::flat_hash_map<int, const std::pair<std::string, ArrayInfo>*>>; |
| |
| // Creates from a `GraphImportConfig::InputArrays` a mapping from a feeds output |
| // tensor name to index and ArrayInfo. Keys and values are backed by |
| // `GraphImportConfig::InputArrays`. |
| StatusOr<FeedsByNode> GetFeedsByNode( |
| const GraphImportConfig::InputArrays& inputs) { |
| FeedsByNode feeds_by_node; |
| feeds_by_node.reserve(inputs.size()); |
| |
| for (const auto& input : inputs) { |
| TensorId tensor = ParseTensorName(input.first); |
| if (tensor.index() < 0) |
| return errors::FailedPrecondition( |
| "Feed output tensor must be a data output '", tensor.ToString(), "'"); |
| |
| auto& node = feeds_by_node[tensor.node()]; |
| if (!node.insert({tensor.index(), &input}).second) |
| return errors::FailedPrecondition( |
| "Multiple feeds for the same output tensor '", tensor.ToString(), |
| "'"); |
| } |
| |
| return feeds_by_node; |
| } |
| |
| // Creates a unique name for a node that will be replacing a feed output tensor. |
| std::string GetUniqueNodeName( |
| absl::string_view node_name, int index, |
| const std::unordered_map<string, Node*>& node_name_map) { |
| std::string new_node_name_base = absl::StrCat(node_name, "_", index); |
| int count = 0; |
| std::string new_node_name = new_node_name_base; |
| while (node_name_map.find(new_node_name) != node_name_map.end()) { |
| new_node_name = absl::StrCat(new_node_name_base, "_", count++); |
| } |
| return new_node_name; |
| } |
| |
| Status ImporterBase::RemoveBackedges(const Graph& graph) { |
| // TODO(fengliuai): Converting to GraphDef and back is the easiest way to |
| // clone a graph. |
| // TODO(fengliuai): clone the graph without going to graph_def first. |
| GraphDef graph_def; |
| graph.ToGraphDef(&graph_def); |
| graph_ = absl::make_unique<Graph>(graph.flib_def()); |
| GraphConstructorOptions opts; |
| opts.allow_internal_ops = true; |
| opts.add_default_attributes = false; |
| TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph( |
| opts, std::move(graph_def), graph_.get())); |
| |
| // Remove all the backedges. So the nodes can be added to the shape refiner. |
| TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get())); |
| VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size()) |
| << " backedges."; |
| |
| // Creates a map for quickly identifying whether a node output is a backedge. |
| for (const auto& edge : back_edge_helper_.RemovedEdges()) { |
| if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() && |
| back_edge_node_output_[edge.src] != edge.src_output) { |
| return errors::FailedPrecondition( |
| "More than one of the src node outputs are backedges!"); |
| } |
| back_edge_node_output_[edge.src] = edge.src_output; |
| // We expect a merge to receive a single backedge (multiple NextIteration |
| // nodes feeding into the same merge is unexpected here). |
| DCHECK(!back_edge_dst_inputs_.contains(edge.dst)); |
| back_edge_dst_inputs_[edge.dst] = edge; |
| } |
| |
| // Obtains a RPO ordering, using node names as a tiebreak for stable sorting. |
| GetReversePostOrder( |
| *graph_, &ordered_nodes_, |
| [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); }); |
| return Status::OK(); |
| } |
| |
| Status CopyStackTraces(const Graph& from, Graph* to) { |
| // Copy over the stack traces. |
| // TODO(jpienaar): This really shouldn't be needed, copying the Graph above |
| // and then needing these traversals is unfortunate. |
| std::unordered_map<string, Node*> node_map = from.BuildNodeNameIndex(); |
| for (Node* node : to->nodes()) { |
| if (const Node* old_node = node_map[node->name()]) { |
| if (const std::shared_ptr<AbstractStackTrace>& stack = |
| old_node->GetStackTrace()) { |
| DVLOG(2) << "Stack for " << node->name() << " " |
| << old_node->GetStackTrace()->ToString( |
| AbstractStackTrace::TracePrintingOptions()); |
| node->SetStackTrace(stack); |
| } else { |
| DVLOG(1) << "No stack for " << node->name() << " (" << node |
| << ") in Graph " << &from; |
| } |
| } else { |
| DVLOG(1) << "No stack for " << node->name() << " (" << node |
| << ") in Graph " << &from; |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed( |
| const TensorShapeProto& shape, DataType dtype, Node* node, int index, |
| const std::unordered_map<string, Node*>& node_name_map) { |
| DCHECK_LT(index, node->num_outputs()); |
| const bool update_inplace = node->num_outputs() == 1 && index == 0; |
| std::string new_node_name = |
| update_inplace ? node->name() |
| : GetUniqueNodeName(node->name(), index, node_name_map); |
| |
| Node* placeholder_node; |
| NodeBuilder builder(new_node_name, "Placeholder"); |
| builder.Attr("shape", shape); |
| builder.Attr("dtype", dtype); |
| TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node)); |
| |
| // Update edges from original feed with Placeholder node. |
| std::vector<const Edge*> data_edges; |
| std::vector<const Edge*> control_edges; |
| for (const tensorflow::Edge* edge : node->out_edges()) { |
| if (edge->src_output() == index) { |
| data_edges.push_back(edge); |
| } else if (update_inplace && edge->IsControlEdge()) { |
| control_edges.push_back(edge); |
| } |
| } |
| |
| for (const auto* edge : data_edges) { |
| TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(), |
| edge->dst_input())); |
| } |
| |
| // TODO(lyandy): Preserve control dependencies properly by not forwarding |
| // control dependencies to data outputs and not removing single output nodes. |
| // When a data output is replaced as a feed, unless there is another non feed |
| // data output or an explicit control output used by the same node, transitive |
| // control dependencies are not to be executed. For single output nodes, |
| // Placeholders can be converted to a NoOp if there are no uses, and |
| // PlaceholderWithDefault can be converted to an Identity. |
| for (const auto* edge : control_edges) { |
| graph_->AddControlEdge(placeholder_node, edge->dst()); |
| graph_->RemoveControlEdge(edge); |
| } |
| |
| if (update_inplace) { |
| graph_->RemoveNode(node); |
| } |
| |
| return std::pair<Node*, bool>(placeholder_node, update_inplace); |
| } |
| |
| Status ImporterBase::GetInputOutputNodes( |
| const std::unordered_map<string, Node*>& node_name_map, |
| std::unordered_set<const Node*>* nodes) { |
| auto add_node = [&](absl::string_view name) { |
| auto it = node_name_map.find(std::string(name)); |
| if (it == node_name_map.end()) { |
| return errors::FailedPrecondition( |
| absl::StrCat("Graph does not contain node: ", name)); |
| } |
| nodes->insert(it->second); |
| return Status::OK(); |
| }; |
| |
| // Remap feeds and fetches to newly created Placeholder nodes. |
| for (const auto& input : specs_.inputs) { |
| TensorId tensor = ParseTensorName(input.first); |
| auto remapped_it = remapped_feeds_.find(tensor); |
| if (remapped_it != remapped_feeds_.end()) { |
| TF_RETURN_IF_ERROR(add_node(remapped_it->second)); |
| } else { |
| TF_RETURN_IF_ERROR(add_node(tensor.node())); |
| } |
| } |
| |
| for (const auto& output : specs_.outputs) { |
| TensorId tensor = ParseTensorName(output); |
| auto remapped_it = remapped_feeds_.find(tensor); |
| if (remapped_it != remapped_feeds_.end()) { |
| TF_RETURN_IF_ERROR(add_node(remapped_it->second)); |
| } else { |
| TF_RETURN_IF_ERROR(add_node(tensor.node())); |
| } |
| } |
| |
| for (const auto& control_output : specs_.control_outputs) |
| TF_RETURN_IF_ERROR(add_node(control_output)); |
| |
| return Status::OK(); |
| } |
| |
| // TODO(jpienaar): Remove this post shape inference on import flag is removed. |
| Status ImporterBase::AddNodesToShapeRefiner( |
| std::unordered_map<string, Node*>* node_name_map) { |
| shape_refiner_ = absl::make_unique<ShapeRefiner>(graph_->versions(), |
| graph_->op_registry()); |
| // Some operations (for example "TPUExecute") don't have shape inference |
| // function defined, so we should set this to false for adding nodes with |
| // these types of operations. |
| shape_refiner_->set_require_shape_inference_fns(false); |
| shape_refiner_->set_function_library_for_shape_inference(&graph_flib_); |
| |
| TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs)); |
| |
| // First add all nodes to the refiner. |
| for (Node* node : ordered_nodes_) { |
| // We need to use a TensorFlow node to teach the shape refiner that user |
| // specifies certain data type and shape for the inputs in the `specs_`. |
| // This node shouldn't have any inputs, only have one output and its |
| // output type/shape is only determined by its "named" attributes. (The |
| // attributes should have fixed names so we can use the info from `specs_` |
| // to set the value of them.) `Placeholder` satisfies these constraints. |
| // |
| // Therefore, if the input node isn't a `Placeholder`, we create one and use |
| // it to replace the original input node, so the shape refiner can |
| // successfully propagate the user's input type and shape to the rest of the |
| // graph. |
| bool node_added_to_shape_refiner = false; |
| auto it = feeds_by_node.find(node->name()); |
| if (it != feeds_by_node.end()) { |
| auto op_name = node->op_def().name(); |
| if (op_name != "Placeholder" && op_name != "LegacyFedInput" && |
| op_name != FunctionLibraryDefinition::kArgOp) { |
| for (const auto& output_tensor : it->second) { |
| const int index = output_tensor.first; |
| const ArrayInfo& array_info = output_tensor.second->second; |
| |
| DataType dtype = array_info.imported_dtype; |
| // Uses the existing output type if it isn't specified by the user. |
| if (dtype == DT_INVALID) { |
| dtype = node->output_type(index); |
| } |
| |
| TF_ASSIGN_OR_RETURN( |
| auto placeholder_node_and_removed, |
| CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index, |
| *node_name_map)); |
| |
| Node* placeholder_node = placeholder_node_and_removed.first; |
| if (placeholder_node_and_removed.second) { |
| // Original node has been removed from the graph. |
| node = placeholder_node; |
| node_added_to_shape_refiner = true; |
| } |
| remapped_feeds_[{it->first, index}] = placeholder_node->name(); |
| (*node_name_map)[placeholder_node->name()] = placeholder_node; |
| // Add the new placeholder node to the shape refiner. |
| Status status = shape_refiner_->AddNode(placeholder_node); |
| if (!status.ok()) { |
| return EmitErrorWithLocationStr(*placeholder_node, status); |
| } |
| } |
| } else { |
| auto index_it = it->second.find(0); |
| if (index_it == it->second.end()) { |
| return errors::FailedPrecondition( |
| "Missing feed output tensor at index 0 for node '", node->name(), |
| "'"); |
| } |
| node->AddAttr("shape", index_it->second->second.shape); |
| DataType dtype = index_it->second->second.imported_dtype; |
| // Uses the existing output type if it isn't specified by the user. |
| if (dtype == DT_INVALID) { |
| dtype = node->output_type(0); |
| } |
| node->AddAttr("dtype", dtype); |
| } |
| } |
| if (!node_added_to_shape_refiner) { |
| // Add the node to the shape refiner if the node hasn't been removed. |
| Status status = shape_refiner_->AddNode(node); |
| if (!status.ok()) { |
| return EmitErrorWithLocationStr(*node, status); |
| } |
| } |
| |
| auto set_shape_from_list_attr = [&](const AttrValue* attr) { |
| auto& list = attr->list(); |
| for (auto shape : llvm::enumerate(list.shape())) { |
| auto* node_context = shape_refiner_->GetContext(node); |
| shape_inference::ShapeHandle handle; |
| Status status = |
| node_context->MakeShapeFromShapeProto(shape.value(), &handle); |
| if (!status.ok()) { |
| return EmitErrorWithLocationStr(*node, status); |
| } |
| node_context->set_output(shape.index(), handle); |
| } |
| return Status::OK(); |
| }; |
| |
| // We currently have no other way to get shapes from ReadVariableOp's. |
| // Some graphs seem to have _output_shapes attributes on them, so use that |
| // if possible. |
| // TODO(silvasean): Ideally, we would do this in a separate shape inference |
| // pass to avoid adding complexity to the importer. But right now, we don't |
| // have an MLIR-native shape inference pass, so we need to do this while we |
| // still have the Graph around, i.e. here, in the importer. |
| if (node->op_def().name() == "ReadVariableOp") { |
| // TODO(silvasean): In some graphs, this seems to be annotated on every |
| // node. Why and by whom? |
| // TODO(b/140588338): We should ideally incorporate that information for |
| // all nodes, but right now, this can result in e.g. an Identity node with |
| // signature such as |
| // `(tensor<?x?xf32>) -> tensor<?x9216xf32>` which fails the verifier |
| // (which checks for exact type equality; _output_shapes results in |
| // us shoehorning in the more-precise type on the output). |
| if (const AttrValue* attr = node->attrs().Find("_output_shapes")) |
| TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr)); |
| } |
| |
| // If it is the argument node, the shape handle is set explicitly, so it |
| // can be propagated to the body nodes of the function. |
| if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) { |
| auto* node_context = shape_refiner_->GetContext(node); |
| DCHECK(node_context != nullptr); |
| if (const AttrValue* attr = node->attrs().Find("shape")) { |
| shape_inference::ShapeHandle handle; |
| Status status = |
| node_context->MakeShapeFromShapeProto(attr->shape(), &handle); |
| if (!status.ok()) { |
| return EmitErrorWithLocationStr(*node, status); |
| } |
| node_context->set_output(0, handle); |
| } else if (const AttrValue* attr = node->attrs().Find("_output_shapes")) { |
| TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr)); |
| } else { |
| node_context->set_output(0, node_context->UnknownShape()); |
| } |
| } |
| } |
| |
| // Since we might have inserted and removed nodes from the graph, fix |
| // source/sink edges and reconstruct the RPO ordering of nodes |
| FixupSourceAndSinkEdges(graph_.get()); |
| |
| // Prune nodes in the graph that are not reachable from the output. |
| if (specs_.prune_unused_nodes) { |
| std::unordered_set<const Node*> prune_start; |
| TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start)); |
| if (!prune_start.empty()) { |
| if (PruneForReverseReachability(graph_.get(), prune_start)) { |
| VLOG(1) << "Pruned unused nodes in graphdef"; |
| } else { |
| VLOG(1) << "No unused nodes in graphdef to prune"; |
| } |
| } else { |
| VLOG(1) << "No output nodes specified, skipping pruning"; |
| } |
| } else { |
| VLOG(1) << "Pruning unused nodes in graphdef is disabled"; |
| } |
| |
| // Re-initialize ordered_nodes_ since we might have modified the graph. |
| GetReversePostOrder( |
| *graph_, &ordered_nodes_, |
| [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); }); |
| |
| VLOG(1) << "Inferring graph shapes to fixpoint"; |
| |
| // The "changed" information from UpdateNode can give false positives, so we |
| // create a dedicated method to verify the shapes are not changed before and |
| // after the shape refine. |
| auto same_inferred_shape = [](shape_inference::InferenceContext* c, |
| shape_inference::ShapeHandle s0, |
| shape_inference::ShapeHandle s1) -> bool { |
| if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) { |
| return true; |
| } |
| if (c->Rank(s0) != c->Rank(s1)) { |
| return false; |
| } |
| for (int i = 0; i < c->Rank(s0); ++i) { |
| if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) { |
| int64 val0 = c->Value(c->Dim(s0, i)); |
| int64 val1 = c->Value(c->Dim(s1, i)); |
| // Negative value is treated as unknown so all negative values indicate |
| // the same dimension. |
| if (val0 >= 0 && val1 >= 0 && val0 != val1) return false; |
| } |
| } |
| return true; |
| }; |
| |
| bool changed = true; |
| int i = 0; |
| const int kMaxIterationCount = 2; |
| while (changed && i != kMaxIterationCount) { |
| changed = false; |
| for (const Node* node : ordered_nodes_) { |
| auto* shape_context = shape_refiner_->GetContext(node); |
| DCHECK(shape_context != nullptr); |
| absl::InlinedVector<shape_inference::ShapeHandle, 4> existing; |
| existing.reserve(shape_context->num_outputs()); |
| for (int o = 0; o < shape_context->num_outputs(); ++o) { |
| existing.push_back(shape_context->output(o)); |
| } |
| bool inferred = false; |
| shape_inference::ShapeHandle handle; |
| Status status = |
| shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred); |
| if (!status.ok()) { |
| return EmitErrorWithLocationStr(*node, status); |
| } |
| for (int o = 0; o < shape_context->num_outputs(); ++o) { |
| if (!same_inferred_shape(shape_context, shape_context->output(o), |
| existing[o])) { |
| changed = true; |
| break; |
| } |
| } |
| } |
| ++i; |
| } |
| if (i >= kMaxIterationCount) { |
| LOG(WARNING) << "Graph shapes did not converge to a fixpoint within " |
| << kMaxIterationCount |
| << " iterations. Graph shapes may be conservative."; |
| } |
| VLOG(1) << "Graph shapes were inferred with " << (i - 1) |
| << " extra rounds of analysis to reach a fixpoint."; |
| return Status::OK(); |
| } |
| |
| StatusOr<mlir::Type> ImporterBase::InferInputType(const Node& node, int idx, |
| mlir::Builder builder) { |
| if (specs_.enable_shape_inference) { |
| // TODO(jpienaar): Remove this if shape inference on import flag is removed. |
| ExtendedInferenceContext* shape_context = |
| shape_refiner_->GetExtendedContext(&node); |
| DataType dtype = shape_context->input_type(idx); |
| auto* context = shape_context->get_context(); |
| return ConvertDataTypeAndShape(dtype, context->input(idx), |
| context->input_handle_shapes_and_types(idx), |
| context, builder); |
| } |
| DataType dtype = node.properties()->input_types[idx]; |
| mlir::Type element_type; |
| TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type)); |
| return mlir::UnrankedTensorType::get(element_type); |
| } |
| |
| StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx, |
| mlir::Builder builder) { |
| DataType dtype = node.properties()->output_types[idx]; |
| |
| // Returns output type given inference context. |
| auto shape_ic = [&](shape_inference::InferenceContext* c) { |
| return ConvertDataTypeAndShape(dtype, c->output(idx), |
| c->output_handle_shapes_and_types(idx), c, |
| builder); |
| }; |
| |
| if (specs_.enable_shape_inference) { |
| // TODO(jpienaar): Remove this if shape inference on import flag is removed. |
| ExtendedInferenceContext* shape_context = |
| shape_refiner_->GetExtendedContext(&node); |
| return shape_ic(shape_context->get_context()); |
| } |
| |
| // Treat TensorList init ops specially here as the op requires knowing its |
| // element dtype. |
| // TODO(jpienaar): Reconsider post refactoring shape functions. |
| if (node.type_string() == "TensorListReserve" || |
| node.type_string() == "EmptyTensorList") { |
| mlir::Type etype; |
| if (auto element_dtype = node.attrs().Find("element_dtype")) { |
| TF_RETURN_IF_ERROR( |
| ConvertDataType(element_dtype->type(), builder, &etype)); |
| } |
| return mlir::RankedTensorType::get( |
| {}, mlir::TF::VariantType::get({mlir::UnrankedTensorType::get(etype)}, |
| etype.getContext())); |
| } |
| |
| if (node.IsWhileNode()) { |
| auto* output_shapes = node.attrs().Find("output_shapes"); |
| auto* element_types = node.attrs().Find("T"); |
| if (output_shapes && !output_shapes->list().shape().empty()) { |
| const auto& output_shape = output_shapes->list().shape(idx); |
| const auto& element_type = element_types->list().type(idx); |
| return ConvertToMlirTensorType(output_shape, element_type, &builder); |
| } |
| } |
| |
| auto type_from_array_attr = [&node, &idx, &builder]( |
| absl::string_view output_shape_attr, |
| absl::string_view element_type_attr) { |
| auto* output_shapes = node.attrs().Find(output_shape_attr); |
| auto* element_types = node.attrs().Find(element_type_attr); |
| const auto& output_shape = output_shapes->list().shape(idx); |
| const auto& element_type = element_types->list().type(idx); |
| return ConvertToMlirTensorType(output_shape, element_type, &builder); |
| }; |
| |
| if (node.type_string() == "IteratorGetNext" || |
| node.type_string() == "IteratorGetNextSync" || |
| node.type_string() == "MultiDeviceIteratorGetNextFromShard") |
| return type_from_array_attr("output_shapes", "output_types"); |
| |
| if (node.type_string() == "InfeedDequeueTuple") |
| return type_from_array_attr("shapes", "dtypes"); |
| |
| if (node.type_string() == "InfeedDequeue") { |
| assert(idx == 0); |
| const auto& output_shape = node.attrs().Find("shape")->shape(); |
| const auto& element_type = node.attrs().Find("dtype")->type(); |
| return ConvertToMlirTensorType(output_shape, element_type, &builder); |
| } |
| |
| // Returns a simple, more conservative unranked tensor type. |
| auto default_type = [&]() -> StatusOr<mlir::Type> { |
| mlir::Type element_type; |
| TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type)); |
| return mlir::UnrankedTensorType::get(element_type); |
| }; |
| |
| // Below we only try and do some shape inference for "source" ops which have |
| // no inputs. |
| if (node.num_inputs() > 0) return default_type(); |
| |
| // Do some simply inference here to get the function arguments correct for |
| // this common case. |
| // TODO(jpienaar): Reconsider post refactoring shape functions. |
| if (node.IsArg()) { |
| if (dtype == DT_RESOURCE) { |
| const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes"); |
| const AttrValue* shape_attr = node.attrs().Find("_handle_shapes"); |
| if (dtype_attr && shape_attr) { |
| if (dtype_attr->list().type().empty()) { |
| return errors::InvalidArgument( |
| "Invalid \"_handle_dtypes\" attribute value for _Arg node: ", |
| shape_attr->DebugString()); |
| } |
| if (shape_attr->list().shape().empty()) { |
| return errors::InvalidArgument( |
| "Invalid \"_handle_shapes\" attribute value for _Arg node: ", |
| shape_attr->DebugString()); |
| } |
| DataType dtype = dtype_attr->list().type(0); |
| const TensorShapeProto& shape_proto = shape_attr->list().shape(0); |
| TF_ASSIGN_OR_RETURN( |
| auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder)); |
| return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get( |
| {etype.cast<TensorType>()}, builder.getContext())); |
| } else { |
| return mlir::UnrankedTensorType::get( |
| mlir::TF::ResourceType::get(builder.getContext())); |
| } |
| } else if (auto shape = node.attrs().Find("_output_shapes")) { |
| if (shape->has_list() && shape->list().shape_size() == 1) { |
| return ConvertToMlirTensorType(shape->list().shape().at(0), dtype, |
| &builder); |
| } |
| } |
| } |
| |
| const tensorflow::OpRegistrationData* op_reg_data; |
| TF_RETURN_IF_ERROR( |
| graph_->op_registry()->LookUp(node.type_string(), &op_reg_data)); |
| if (!op_reg_data) { |
| DVLOG(1) << "Skipping inference for unregistered op " << node.type_string(); |
| return default_type(); |
| } |
| if (op_reg_data->shape_inference_fn == nullptr) { |
| DVLOG(1) << "Skipping inference for op without shape function " |
| << node.type_string(); |
| return default_type(); |
| } |
| shape_inference::InferenceContext c(graph_->versions().producer(), |
| node.attrs(), op_reg_data->op_def, |
| std::vector<PartialTensorShape>{}, {}, |
| /*input_tensors_as_shapes=*/{}, {}); |
| TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn)); |
| return shape_ic(&c); |
| } |
| |
| StatusOr<TensorType> ImporterBase::ConvertDataTypeAndShape( |
| DataType dtype, const shape_inference::ShapeHandle& handle, |
| const std::vector<shape_inference::ShapeAndType>* handle_subtypes, |
| shape_inference::InferenceContext* context, mlir::Builder builder) { |
| TF_ASSIGN_OR_RETURN(auto subtypes, |
| ConvertSubtypes(handle_subtypes, context, builder)); |
| |
| mlir::Type element_type; |
| if (dtype == DT_VARIANT) |
| element_type = mlir::TF::VariantType::get(subtypes, context_); |
| else if (dtype == DT_RESOURCE) |
| element_type = mlir::TF::ResourceType::get(subtypes, context_); |
| else |
| TF_RETURN_IF_ERROR( |
| ::tensorflow::ConvertDataType(dtype, builder, &element_type)); |
| |
| return ConvertElementTypeAndShape(element_type, handle, context, builder); |
| } |
| |
| StatusOr<TensorType> ImporterBase::ConvertElementTypeAndShape( |
| mlir::Type element_type, const shape_inference::ShapeHandle& handle, |
| shape_inference::InferenceContext* context, mlir::Builder builder) { |
| if (!context->RankKnown(handle)) { |
| return mlir::UnrankedTensorType::get(element_type); |
| } |
| |
| // Sentinel for an unknown dimension size. getTensorType interprets any |
| // negative value as an unknown dimension. |
| // TODO(jmolloy): Ideally this shouldn't be a local sentinel. |
| const int64_t kUnknownDim = -1; |
| |
| absl::InlinedVector<int64_t, 4> dimensions; |
| int32 rank = context->Rank(handle); |
| dimensions.reserve(rank); |
| for (int i = 0; i < rank; ++i) { |
| auto dim_handle = context->Dim(handle, i); |
| if (!context->ValueKnown(dim_handle)) |
| dimensions.push_back(kUnknownDim); |
| else |
| dimensions.push_back(context->Value(dim_handle)); |
| } |
| |
| return mlir::RankedTensorType::get( |
| llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type); |
| } |
| |
| StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes( |
| const std::vector<shape_inference::ShapeAndType>* handle_subtypes, |
| shape_inference::InferenceContext* context, mlir::Builder builder) { |
| ElementSubtypes subtypes; |
| if (!handle_subtypes) return subtypes; |
| |
| subtypes.reserve(handle_subtypes->size()); |
| for (const auto& subtype : *handle_subtypes) { |
| mlir::Type element_type; |
| TF_RETURN_IF_ERROR( |
| ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type)); |
| TF_ASSIGN_OR_RETURN(TensorType type, |
| ConvertElementTypeAndShape(element_type, subtype.shape, |
| context, builder)); |
| subtypes.push_back(type); |
| } |
| return subtypes; |
| } |
| |
| Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name, |
| const AttrValue& value, |
| NamedAttrList* attributes) { |
| TF_ASSIGN_OR_RETURN(auto func_attr, |
| ConvertFunctionCallName(value.func().name())); |
| attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); |
| |
| for (const auto& it : value.func().attr()) { |
| auto name = absl::StrCat(base_name, ".", it.first); |
| TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second)); |
| attributes->push_back(builder_.getNamedAttr(name, value)); |
| } |
| return Status::OK(); |
| } |
| |
| StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName( |
| const std::string& func_name) { |
| TF_RETURN_IF_ERROR(ConvertLibFunction(func_name)); |
| auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name]; |
| auto func = module_.lookupSymbol<mlir::FuncOp>(mlir_func_name); |
| return builder_.getSymbolRefAttr(func); |
| } |
| |
| StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue( |
| const AttrValue& value) { |
| switch (value.value_case()) { |
| case AttrValue::kFunc: { |
| // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. |
| // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue |
| // will not use this representation. |
| NamedAttrList attrs; |
| for (const auto& func_attr : value.func().attr()) { |
| TF_ASSIGN_OR_RETURN( |
| auto attr, ImporterBase::ConvertAttributeValue(func_attr.second)); |
| attrs.push_back(builder_.getNamedAttr(func_attr.first, attr)); |
| } |
| auto func_attrs = builder_.getDictionaryAttr(attrs); |
| return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs); |
| } |
| case AttrValue::kList: { |
| if (!value.list().func().empty()) { |
| absl::InlinedVector<mlir::Attribute, 8> attrs; |
| for (const auto& item : value.list().func()) { |
| TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name())); |
| if (item.attr_size() != 0) |
| return errors::Unimplemented( |
| "func attributes with non-zero attr.size()"); |
| attrs.push_back(attr); |
| } |
| return builder_.getArrayAttr( |
| llvm::makeArrayRef(attrs.begin(), attrs.end())); |
| } |
| return ConvertNonFuncAttributeValue(value, &builder_); |
| } |
| default: |
| return ConvertNonFuncAttributeValue(value, &builder_); |
| } |
| } |
| |
| void ImporterBase::GetArgsAndRetsFromFunctionBody( |
| const FunctionBody& fbody, absl::InlinedVector<OutputTensor, 4>* arg_nodes, |
| absl::InlinedVector<OutputTensor, 4>* ret_nodes, |
| absl::InlinedVector<Node*, 4>* control_ret_nodes) { |
| arg_nodes->reserve(fbody.arg_nodes.size()); |
| ret_nodes->reserve(fbody.ret_nodes.size()); |
| for (auto arg : fbody.arg_nodes) { |
| arg_nodes->emplace_back(arg, 0); |
| } |
| for (auto ret : fbody.ret_nodes) { |
| ret_nodes->emplace_back(ret, 0); |
| } |
| *control_ret_nodes = fbody.control_ret_nodes; |
| } |
| |
| Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { |
| // If the library function has been converted already, nothing needs to be |
| // done. |
| if (tf_name_to_mlir_name_->find(std::string(func_name)) != |
| tf_name_to_mlir_name_->end()) |
| return Status::OK(); |
| |
| std::string mlir_func_name( |
| function_name_uniquifier_->GetUniqueName(func_name)); |
| (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name; |
| |
| const auto& func_lib = graph_flib_; |
| const auto* func_def = func_lib.Find(std::string(func_name)); |
| if (func_def == nullptr) { |
| return errors::FailedPrecondition( |
| absl::StrCat("Failed to find function '", StringRefToView(func_name), |
| "'. The imported TensorFlow GraphDef is ill-formed.")); |
| } |
| |
| // Converts the function definition to a graph. |
| std::unique_ptr<FunctionBody> fbody; |
| TF_RETURN_IF_ERROR( |
| FunctionDefToBodyHelper(*func_def, AttrSlice(), &func_lib, &fbody)); |
| |
| // Converts the argument and return types to MLIR types. |
| absl::InlinedVector<mlir::NamedAttribute, 8> attributes; |
| attributes.reserve(func_def->attr_size()); |
| for (const auto& name_and_value : func_def->attr()) { |
| // This is a function definition attribute, so it shouldn't contain |
| // kFunc attribute and it is treated as normal one. |
| TF_ASSIGN_OR_RETURN(auto attr, |
| ConvertAttributeValue(name_and_value.second)); |
| std::string attr_name = |
| mangling_util::MangleAttributeName(name_and_value.first); |
| attributes.push_back(builder_.getNamedAttr(attr_name, attr)); |
| } |
| |
| // Checks opdef stateful attribute and import that as Function Attribute |
| if (func_def->signature().is_stateful()) { |
| auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName(); |
| attributes.push_back( |
| builder_.getNamedAttr(stateful_str, builder_.getUnitAttr())); |
| } |
| |
| // Checks for an associated custom gradient function. Adds it to the attribute |
| // list of this function. |
| auto grad_func_name = func_lib.FindGradient(std::string(func_name)); |
| if (!grad_func_name.empty()) { |
| TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name)); |
| auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name]; |
| auto grad_func = module_.lookupSymbol<mlir::FuncOp>(mlir_grad_func_name); |
| auto gradient_attr = builder_.getSymbolRefAttr(grad_func); |
| auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName(); |
| attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr)); |
| } |
| |
| // Converts the graph to an MLIR function and adds it to the module. |
| // We populate the NodeSpec so that all the _Arg ops get their shape |
| // added correctly. |
| GraphImportConfig specs; |
| specs.enable_shape_inference = specs_.enable_shape_inference; |
| for (const auto& name_and_value : func_def->attr()) { |
| if (name_and_value.first == "_input_shapes") { |
| auto& list = name_and_value.second.list(); |
| auto& signature = func_def->signature(); |
| if (list.shape_size() != signature.input_arg_size()) { |
| return errors::FailedPrecondition( |
| "Number of input arguments must be equal to the length of " |
| "_input_shapes attribute in function '", |
| StringRefToView(func_name), "'."); |
| } |
| for (int i = 0; i < list.shape_size(); i++) { |
| auto& input_arg = signature.input_arg(i); |
| auto& array_info = specs.inputs[input_arg.name()]; |
| array_info.imported_dtype = input_arg.type(); |
| array_info.shape = list.shape(i); |
| } |
| } |
| } |
| |
| ImporterBase child_importer(graph_flib_, debug_info_, specs, module_, |
| tf_name_to_mlir_name_, function_name_uniquifier_, |
| func_name); |
| TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph)); |
| |
| TF_ASSIGN_OR_RETURN(auto func_type, |
| child_importer.InferLibFunctionType(*fbody)); |
| |
| absl::InlinedVector<OutputTensor, 4> arg_nodes; |
| absl::InlinedVector<OutputTensor, 4> ret_nodes; |
| absl::InlinedVector<Node*, 4> control_ret_nodes; |
| GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes, |
| &control_ret_nodes); |
| |
| TF_RETURN_IF_ERROR(child_importer.Convert( |
| mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, |
| llvm::makeArrayRef(attributes.begin(), attributes.end()))); |
| return Status::OK(); |
| } |
| |
| Status ImporterBase::PruneUnreachableNodes( |
| std::unordered_map<string, Node*>* node_name_map) { |
| std::unordered_set<const Node*> prune_start; |
| TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start)); |
| |
| if (!prune_start.empty()) { |
| if (PruneForReverseReachability(graph_.get(), prune_start)) { |
| VLOG(1) << "Pruned unused nodes in graphdef"; |
| } else { |
| VLOG(1) << "No unused nodes in graphdef to prune"; |
| } |
| } else { |
| VLOG(1) << "No output nodes specified, skipping pruning"; |
| } |
| return Status::OK(); |
| } |
| |
| Status ImporterBase::ConvertFeedsToPlaceholders( |
| std::unordered_map<string, Node*>* node_name_map) { |
| // Feeds (edges) are converted into single-output placeholder nodes to |
| // simplify the conversion process. |
| TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs)); |
| for (const auto& it : feeds_by_node) { |
| TensorId tensor = ParseTensorName(it.first); |
| auto jt = node_name_map->find(std::string(tensor.node())); |
| if (jt == node_name_map->end()) { |
| return errors::FailedPrecondition( |
| absl::StrCat("Graph does not contain node: ", tensor.node())); |
| } |
| |
| Node* node = jt->second; |
| auto op_name = node->op_def().name(); |
| if (op_name != "Placeholder" && op_name != "LegacyFedInput" && |
| op_name != FunctionLibraryDefinition::kArgOp) { |
| for (const auto& output_tensor : it.second) { |
| const int index = output_tensor.first; |
| const ArrayInfo& array_info = output_tensor.second->second; |
| |
| DataType dtype = array_info.imported_dtype; |
| // Uses the existing output type if it isn't specified by the user. |
| if (dtype == DT_INVALID) { |
| dtype = node->output_type(index); |
| } |
| |
| TF_ASSIGN_OR_RETURN( |
| auto placeholder_node_and_removed, |
| CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index, |
| *node_name_map)); |
| |
| Node* placeholder_node = placeholder_node_and_removed.first; |
| if (placeholder_node->in_edges().empty()) { |
| graph_->AddControlEdge(graph_->source_node(), placeholder_node, |
| true /* skip test for duplicates */); |
| } |
| if (placeholder_node->out_edges().empty()) { |
| graph_->AddControlEdge(placeholder_node, graph_->sink_node(), |
| true /* skip test for duplicates */); |
| } |
| remapped_feeds_[{it.first, index}] = placeholder_node->name(); |
| (*node_name_map)[placeholder_node->name()] = placeholder_node; |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status ImporterBase::PrepareConvert(const Graph& graph) { |
| TF_RETURN_IF_ERROR(RemoveBackedges(graph)); |
| TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get())); |
| |
| auto node_name_map = graph_->BuildNodeNameIndex(); |
| |
| if (specs_.enable_shape_inference) { |
| // TODO(jpienaar): Remove once infer shapes on import flag is removed. |
| TF_RETURN_IF_ERROR(AddNodesToShapeRefiner(&node_name_map)); |
| } else { |
| TF_RETURN_IF_ERROR(ConvertFeedsToPlaceholders(&node_name_map)); |
| } |
| |
| // Prune nodes in the graph that are not reachable from the output. |
| if (specs_.prune_unused_nodes) { |
| TF_RETURN_IF_ERROR(PruneUnreachableNodes(&node_name_map)); |
| } |
| |
| if (!specs_.enable_shape_inference) { |
| // Re-initialize ordered_nodes_ since we might have modified the graph. |
| GetReversePostOrder( |
| *graph_, &ordered_nodes_, |
| [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); }); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status ImporterBase::Convert( |
| llvm::StringRef func_name, mlir::FunctionType func_type, |
| const absl::InlinedVector<OutputTensor, 4>& arg_nodes, |
| const absl::InlinedVector<OutputTensor, 4>& ret_nodes, |
| const absl::InlinedVector<Node*, 4>& control_ret_nodes, |
| llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
| // TODO(b/122040776): Uses debug info for FunctionDef. |
| auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_), |
| func_name, func_type, attrs); |
| |
| module_.push_back(function); |
| // Seeds the builder with an initial block. |
| function.addEntryBlock(); |
| builder_ = mlir::OpBuilder(function.getBody()); |
| |
| // Create the graph operation in which we will convert the individual nodes. |
| auto graph = builder_.create<mlir::tf_executor::GraphOp>( |
| function.getLoc(), func_type.getResults()); |
| builder_.createBlock(&graph.body()); |
| |
| for (const Node* node : ordered_nodes_) { |
| TF_RETURN_IF_ERROR(ConvertNode(*node)); |
| } |
| |
| // Adds the backedges back to the function by creating the source and sink |
| // pairs. |
| TF_RETURN_IF_ERROR(AddBackedges()); |
| |
| TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph, |
| func_type.getInputs(), arg_nodes, |
| ret_nodes, control_ret_nodes)); |
| |
| // TODO(jpienaar): Update post removing shape_refinier_. |
| if (!specs_.enable_shape_inference) { |
| // Refine graph's type given more precise fetch. |
| auto fetch = graph.GetFetch(); |
| bool all_equal = true; |
| for (auto it : |
| llvm::zip_first(graph.getResults(), fetch.getOperandTypes())) { |
| auto rt = std::get<1>(it); |
| if (rt == std::get<0>(it).getType()) continue; |
| std::get<0>(it).setType(rt); |
| all_equal = false; |
| } |
| if (!all_equal) { |
| function.setType(mlir::FunctionType::get(function.getContext(), |
| func_type.getInputs(), |
| graph.getResultTypes())); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status ImporterBase::ConvertFunctionArgAndRets( |
| mlir::FuncOp func, mlir::tf_executor::GraphOp graph_op, |
| llvm::ArrayRef<mlir::Type> arg_types, |
| const absl::InlinedVector<OutputTensor, 4>& arg_nodes, |
| const absl::InlinedVector<OutputTensor, 4>& ret_nodes, |
| const absl::InlinedVector<Node*, 4>& control_ret_nodes) { |
| auto set_attributes_on_func = [&](Node* node, int64_t index, bool is_arg) { |
| for (const auto& node_attr : node->attrs()) { |
| const auto& key = node_attr.first; |
| // Only import optional attributes (e.g., those starting with an |
| // underscore). |
| if (key.empty() || key[0] != '_') continue; |
| // Ignore shape inference attributes as shape information is already |
| // populated in the result type. |
| if (IsOutputShapesAttribute(node_attr.second, key) || |
| IsResourceOutputShapesAttribute(node_attr.second, key)) |
| continue; |
| TF_ASSIGN_OR_RETURN(auto converted_attr, |
| ConvertAttributeValue(node_attr.second)); |
| std::string dialect_attribute = "tf." + key; |
| if (is_arg) { |
| func.setArgAttr(index, dialect_attribute, converted_attr); |
| } else { |
| func.setResultAttr(index, dialect_attribute, converted_attr); |
| } |
| } |
| return Status::OK(); |
| }; |
| |
| auto* bb = &func.front(); |
| llvm::SmallDenseMap<std::pair<Node*, int>, mlir::Value, 4> |
| arg_nodes_to_values; |
| for (int i = 0, e = arg_types.size(); i < e; ++i) { |
| auto& arg_node = arg_nodes[i]; |
| // The lookup can't fail here: otherwise some nodes in the function haven't |
| // be converted to mlir operations and don't have a mapping. |
| mlir::Operation* island = node_values_.find(arg_node.node->id())->second; |
| |
| auto bb_arg = bb->getArgument(i); |
| mlir::Value arg_def = bb_arg; |
| |
| if (island->getNumResults() != 2) |
| return errors::InvalidArgument( |
| "Only feed output tensors of single output nodes are supported"); |
| |
| // Collect mapping of OutputTensor to associated block arg. |
| arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def); |
| island->getResult(0).replaceAllUsesWith(arg_def); |
| // Erase control outputs from feed. |
| auto control_uses = island->getResult(1).getUses(); |
| for (auto& control_use : llvm::make_early_inc_range(control_uses)) |
| control_use.getOwner()->eraseOperand(control_use.getOperandNumber()); |
| |
| if (!arg_node.node->requested_device().empty()) |
| func.setArgAttr( |
| i, "tf.device", |
| builder_.getStringAttr(arg_node.node->requested_device())); |
| |
| if (arg_node.node->IsArg()) { |
| TF_RETURN_IF_ERROR( |
| set_attributes_on_func(arg_node.node, i, /*is_arg=*/true)); |
| } |
| |
| island->dropAllReferences(); |
| island->erase(); |
| } |
| |
| llvm::SmallVector<mlir::Value, 8> inst_to_return; |
| for (auto ret_and_idx : llvm::enumerate(ret_nodes)) { |
| const auto& ret = ret_and_idx.value(); |
| auto* inst = node_values_[ret.node->id()]; |
| if (ret.node->IsRetval()) { |
| if (!ret.node->requested_device().empty()) |
| func.setResultAttr( |
| ret_and_idx.index(), "tf.device", |
| builder_.getStringAttr(ret.node->requested_device())); |
| TF_RETURN_IF_ERROR(set_attributes_on_func(ret.node, ret_and_idx.index(), |
| /*is_arg=*/false)); |
| // Lookup the instruction inside the island |
| auto island_op = llvm::cast<mlir::tf_executor::IslandOp>(inst); |
| mlir::Operation* inner_op = &island_op.GetBody().front(); |
| // Remove kRetOp or kDeviceRetOp operation and return its operand. |
| // kRetOp and kDeviceRetOp should have just one operand unless they have |
| // control dependencies. |
| if (inner_op->getNumOperands() != 1) |
| return errors::Unimplemented("Return node with multiple inputs."); |
| inst_to_return.push_back(inner_op->getOperand(0)); |
| inst->dropAllReferences(); |
| inst->erase(); |
| } else { |
| // Lookup and use block arg if fetch is a feed. |
| auto it = arg_nodes_to_values.find({ret.node, ret.index}); |
| if (it != arg_nodes_to_values.end()) |
| inst_to_return.push_back(it->second); |
| else |
| inst_to_return.push_back(inst->getResult(ret.index)); |
| } |
| } |
| |
| for (Node* control_ret : control_ret_nodes) { |
| auto* inst = node_values_[control_ret->id()]; |
| inst_to_return.push_back(*std::prev(inst->result_end())); |
| } |
| |
| // Terminate the function by adding a Fetch operation to terminate the graph |
| // and a return operation to return the Graph results. |
| builder_.setInsertionPointToEnd(&graph_op.body().front()); |
| builder_.create<mlir::tf_executor::FetchOp>(graph_op.getLoc(), |
| inst_to_return); |
| builder_.setInsertionPointToEnd(bb); |
| builder_.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_), |
| graph_op.getResults()); |
| return Status::OK(); |
| } |
| |
| mlir::Location ImporterBase::GetLocation(const Node& node) { |
| DVLOG(1) << "Getting location for " << node.name() << " " << &node; |
| // TODO(b/142400497): What is the semantic contract for locations? |
| const auto& debug_info = debug_info_.traces(); |
| |
| // Create a location for node `name` in function `function_name`. |
| auto create_location = [&](llvm::StringRef name, |
| llvm::StringRef function_name) -> mlir::Location { |
| // Use the catenation of function and node names as the lookup key into the |
| // debug info. This matches the way that the key is formed on the python |
| // side. |
| // |
| // We also use this as the name for the NameLoc for ops in function, since |
| // otherwise our names could collide across functions. |
| // For ops in the main graph, we omit the "@function_name" (which, would be |
| // just "@" since function_name would be empty) because some code seems to |
| // depend on the name being this way for correctness. |
| std::string debug_info_key = (name + "@" + function_name).str(); |
| std::string name_for_name_loc = |
| function_name.empty() ? name.str() : (name + "@" + function_name).str(); |
| auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_); |
| |
| llvm::SmallVector<mlir::Location, 4> locations; |
| // Prefer stack traces if available, fallback to debug info if not, and then |
| // finally to just name. |
| if (auto stack_trace = node.GetStackTrace()) { |
| DVLOG(1) << "Stack available for " << node.name(); |
| absl::Span<const StackFrame> frames = stack_trace->ToFrames(); |
| locations.reserve(frames.size()); |
| for (const StackFrame& frame : llvm::reverse(frames)) { |
| auto file_name = mlir::Identifier::get(frame.file_name, context_); |
| // Use col 1 as there is no column info in StackTrace. |
| auto file_line_loc = mlir::FileLineColLoc::get( |
| file_name, frame.line_number, 1, context_); |
| locations.push_back(file_line_loc); |
| } |
| } else { |
| DVLOG(1) << "No stack trace for " << node.name(); |
| const auto location_it = debug_info.find(debug_info_key); |
| if (location_it != debug_info.end()) { |
| DVLOG(1) << "Available serialized debug info for " << node.name(); |
| // Convert the stack trace to a chain of mlir::CallSiteLocs. |
| const auto& trace = location_it->second; |
| locations.reserve(trace.file_line_cols_size()); |
| for (const auto& location : trace.file_line_cols()) { |
| const auto& file = debug_info_.files(location.file_index()); |
| auto file_name = mlir::Identifier::get(file, context_); |
| auto file_line_loc = mlir::FileLineColLoc::get( |
| file_name, location.line(), location.col(), context_); |
| locations.push_back(file_line_loc); |
| } |
| } |
| } |
| |
| // If there are no locations in the stack trace, fall back to just a |
| // NameLoc with no child. |
| if (locations.empty()) return mlir::NameLoc::get(name_loc_id, context_); |
| |
| // Use the front FileLineColLoc to generate a NameLoc. |
| mlir::Location node_name_loc = |
| mlir::NameLoc::get(name_loc_id, locations.front()); |
| |
| // If there are more locations then generate a stack trace, otherwise just |
| // return the name loc. |
| auto callsite_locs = llvm::makeArrayRef(locations).drop_front(); |
| return callsite_locs.empty() |
| ? node_name_loc |
| : mlir::CallSiteLoc::get(node_name_loc, callsite_locs); |
| }; |
| |
| // For NextIteration nodes, location is used to pair source and sink nodes. |
| // Hence, we use node name as location to keep it unique. |
| // TODO(prakalps): In future the plan is to use tokens to pair source/sink |
| // nodes. Then NextIteration nodes would not need to be handled separately. |
| if (node.type_string() == "NextIteration") |
| return create_location(node.name(), function_name_for_debug_info_); |
| |
| if (node.GetStackTrace()) |
| return create_location(node.name(), function_name_for_debug_info_); |
| |
| const auto& node_def = node.def(); |
| auto original_nodes = |
| node_def.experimental_debug_info().original_node_names(); |
| auto original_funcs = |
| node_def.experimental_debug_info().original_func_names(); |
| |
| if (original_nodes.empty()) { |
| return create_location(node.name(), function_name_for_debug_info_); |
| } else { |
| // If the original nodes are defined, then we use them to get a list of |
| // call sites, and then fuse them to a single fused location, with the name |
| // of the node_def. |
| llvm::SmallVector<mlir::Location, 4> node_locations; |
| node_locations.reserve(original_nodes.size() + 1); |
| |
| // store the names in the experimental_debug_info |
| for (int i = 0, e = original_nodes.size(); i != e; ++i) { |
| auto node_name = original_nodes[i]; |
| auto func_name = (i < original_funcs.size()) ? original_funcs[i] : ""; |
| node_locations.push_back(create_location(node_name, func_name)); |
| } |
| // store the name of the node_def |
| node_locations.push_back( |
| create_location(node.name(), function_name_for_debug_info_)); |
| return mlir::FusedLoc::get(node_locations, context_); |
| } |
| } |
| |
| Status ImporterBase::EmitErrorWithLocationStr(const Node& node, |
| const Status& error_status) { |
| const mlir::Location location = GetLocation(node); |
| mlir::emitError(location); |
| return error_handler_.Combine(error_status); |
| } |
| |
| mlir::Operation* ImporterBase::CreateOperation( |
| const Node& node, llvm::StringRef node_type_name, |
| const mlir::OperationState& result, |
| const llvm::SmallVectorImpl<mlir::Value>& control_operands) { |
| // For the tf.executor specific operations (not wrapped in an island), we |
| // have an extra returned value for the control result, and we concatenate |
| // control and non-control operands. |
| mlir::SmallVector<mlir::Type, 4> types(result.types); |
| types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext())); |
| mlir::SmallVector<mlir::Value, 4> operands(result.operands); |
| operands.append(control_operands.begin(), control_operands.end()); |
| |
| auto loc = result.location; |
| // Dispatch based on the name and create the appropriate operation. |
| if (node.IsSwitch()) { |
| // Switch and _SwitchN both are in switch class, differentiate based on |
| // op name. |
| if (node.op_def().name() == "_SwitchN") { |
| return builder_.create<mlir::tf_executor::SwitchNOp>(loc, types, operands, |
| result.attributes); |
| } |
| return builder_.create<mlir::tf_executor::SwitchOp>(loc, types, operands, |
| result.attributes); |
| } |
| if (node.IsMerge()) { |
| return builder_.create<mlir::tf_executor::MergeOp>(loc, types, operands, |
| result.attributes); |
| } |
| if (node.IsNextIteration()) { |
| // NextIteration is a bit special, we create a pair of operations that are |
| // linked together through a token returned by the source. |
| // We make use of a separate builder to insert the source at the top of |
| // the block. |
| mlir::OpBuilder builder_at_begin(builder_.getBlock(), |
| builder_.getBlock()->begin()); |
| auto source_op = |
| builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>( |
| loc, operands[0].getType(), result.attributes); |
| return builder_.create<mlir::tf_executor::NextIterationSinkOp>( |
| loc, source_op.token(), operands, result.attributes); |
| } |
| if (node.IsLoopCond()) { |
| return builder_.create<mlir::tf_executor::LoopCondOp>(loc, types, operands, |
| result.attributes); |
| } |
| if (node.IsEnter()) { |
| return builder_.create<mlir::tf_executor::EnterOp>(loc, types, operands, |
| result.attributes); |
| } |
| if (node.IsExit()) { |
| return builder_.create<mlir::tf_executor::ExitOp>(loc, types, operands, |
| result.attributes); |
| } |
| if (node.IsControlTrigger()) { |
| return builder_.create<mlir::tf_executor::ControlTriggerOp>( |
| loc, operands, result.attributes); |
| } |
| // Regular TensorFlow operation are wrapped in a tf_executor.island. |
| auto island = builder_.create<mlir::tf_executor::IslandOp>( |
| result.location, types, control_operands, |
| mlir::ArrayRef<mlir::NamedAttribute>{}); |
| island.body().push_back(new mlir::Block); |
| mlir::OpBuilder island_builder = |
| mlir::OpBuilder::atBlockEnd(&island.GetBody()); |
| |
| // Create the operation inside the island now. |
| mlir::Operation* inner_op = island_builder.createOperation(result); |
| |
| // Sets operand_segment_sizes or result_segment_sizes attribute to the op. |
| const auto set_segment_sizes_attr = |
| [&](const NameRangeMap& arg_ranges, |
| const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, |
| llvm::StringRef attr_name) { |
| std::vector<mlir::Attribute> values; |
| values.reserve(args.size()); |
| for (const auto& arg : args) { |
| auto range = arg_ranges.at(arg.name()); |
| values.push_back( |
| island_builder.getI32IntegerAttr(range.second - range.first)); |
| } |
| auto attr_type = |
| mlir::VectorType::get(args.size(), builder_.getIntegerType(32)); |
| auto attr_value = mlir::DenseElementsAttr::get(attr_type, values); |
| inner_op->setAttr(attr_name, attr_value); |
| }; |
| |
| if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>() || |
| inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) { |
| // The op has multiple variadic operands or results. |
| // Calculate operand and result segment sizes using the OpDef. |
| NameRangeMap input_ranges, output_ranges; |
| // This will fail only if the OpDef is syntactically invalid. |
| // TODO(jpienaar): Convert this CHECK into a properly propagated error. |
| TF_CHECK_OK( |
| NameRangesForNode(node, node.op_def(), &input_ranges, &output_ranges)); |
| if (inner_op->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) { |
| // Add derived "operand_segment_sizes" attr to the created operation. |
| // TODO(b/146937733): Don't use <void> here. |
| set_segment_sizes_attr(input_ranges, node.op_def().input_arg(), |
| mlir::OpTrait::AttrSizedOperandSegments< |
| void>::getOperandSegmentSizeAttr()); |
| } |
| |
| if (inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) { |
| // Add derived "result_segment_sizes" attr to the created operation. |
| // TODO(b/146937733): Don't use <void> here. |
| set_segment_sizes_attr(output_ranges, node.op_def().output_arg(), |
| mlir::OpTrait::AttrSizedResultSegments< |
| void>::getResultSegmentSizeAttr()); |
| } |
| } |
| |
| mlir::OperationName name = inner_op->getName(); |
| if (!name.getAbstractOperation() && |
| // Skip unmodelled ops that are handled differently. |
| (node_type_name != "_Arg" && node_type_name != "_Retval")) { |
| if (GetUnmodelledOpTypes().insert(name.getStringRef()).second) { |
| LOG(INFO) << "Unmodelled op type `" << node.type_string() << "`" |
| << (node.op_def().is_stateful() |
| ? " is stateful but effects not modelled" |
| : " is not stateful but will be treated as such " |
| "conservatively"); |
| } |
| } |
| |
| // Add the terminator for the island |
| island_builder.create<mlir::tf_executor::YieldOp>(result.location, |
| inner_op->getResults()); |
| return island.getOperation(); |
| } |
| |
| Status ImporterBase::ConvertNode(const Node& node) { |
| if (!node.IsOp()) { |
| // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by |
| // Graph and don't exist in GraphDef. |
| return Status::OK(); |
| } |
| |
| // If it is a custom OP, its definition should be found in the library. We |
| // create the MLIR function and insert it to the module if it doesn't exist. |
| std::string node_type_name = node.type_string(); |
| const auto* func_def = graph_flib_.Find(node_type_name); |
| bool convert_to_legacy_call = false; |
| if (func_def) { |
| TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name)); |
| node_type_name = (*tf_name_to_mlir_name_)[node_type_name]; |
| convert_to_legacy_call = true; |
| } |
| |
| auto get_full_op_name = [&](const std::string& op_name) { |
| const char* kTfPrefix = "tf."; |
| return kTfPrefix + op_name; |
| }; |
| |
| std::string op_name = get_full_op_name(node_type_name); |
| if (back_edge_node_output_.contains(&node)) { |
| op_name = op_name + ".sink"; |
| } |
| |
| mlir::OperationState result(GetLocation(node), op_name); |
| for (int i = 0; i < node.num_outputs(); ++i) { |
| // The backedge has been removed, so we shouldn't count the corresponding |
| // output from the src node when converting to an operation. |
| if (back_edge_node_output_.contains(&node) && |
| back_edge_node_output_[&node] == i) { |
| continue; |
| } |
| TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_)); |
| result.types.push_back(type); |
| } |
| |
| // Surprisingly input edges can be nondeterministically ordered. This |
| // particularly seems to be the case for the control edges between _SOURCE |
| // and _SINK that the Graph constructor inserts. Copy the input edges and |
| // sort the edges, but only the control edges, not data edges! |
| // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes. |
| // They'll break roundtripping anyway unless we strip them when converting |
| // back to graphdef. |
| absl::InlinedVector<const Edge*, 8> in_edges(node.in_edges().size()); |
| absl::c_copy(node.in_edges(), in_edges.begin()); |
| absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) { |
| if (e1->IsControlEdge() && !e2->IsControlEdge()) return false; |
| if (!e1->IsControlEdge() && e2->IsControlEdge()) return true; |
| if (e1->IsControlEdge() && e2->IsControlEdge()) |
| return e1->src()->id() < e2->src()->id(); |
| return e1->dst_input() < e2->dst_input(); |
| }); |
| |
| result.operands.reserve(in_edges.size()); |
| |
| // Collect the control operands separately, they will be held by the island. |
| mlir::SmallVector<mlir::Value, 8> control_operands; |
| |
| for (const auto* input_edge : in_edges) { |
| const Node& input_node = *input_edge->src(); |
| if (input_node.IsSource()) { |
| if (in_edges.size() != 1) { |
| return errors::FailedPrecondition( |
| "The node has other inputs besides the _Source node"); |
| } |
| // We don't import the _SOURCE node. |
| continue; |
| } |
| if (input_node.IsArg() && input_edge->IsControlEdge()) { |
| // Currently we have not reached consensus as to what TF function |
| // semantics are (b/133509504). Here we assume that all arguments to a |
| // function should be available before we start execution of any internal |
| // node. This makes the control dependencies between function arguments |
| // and internal nodes redundant, and so we do not import them. The TF |
| // inliner however assumes no such dependency between function args and |
| // internal nodes exists, unless explicitly stated. Since we drop control |
| // dependencies here, it leads to loss of information. If the function is |
| // inlined later, the inliner would not know of these explicit control |
| // dependencies present in the original graph. |
| continue; |
| } |
| if (node_values_.find(input_node.id()) == node_values_.end()) |
| return errors::FailedPrecondition( |
| "Graph not traversed in reverse post order; use seen before def!"); |
| mlir::Operation* inst = node_values_[input_node.id()]; |
| if (input_edge->IsControlEdge()) |
| control_operands.push_back(inst->getResult(inst->getNumResults() - 1)); |
| else |
| result.operands.push_back(inst->getResult(input_edge->src_output())); |
| } |
| |
| using FuncPairType = std::pair<const std::string*, const AttrValue*>; |
| std::vector<FuncPairType> funcs; |
| result.attributes.reserve(node.attrs().size() + 2); |
| auto abstract_op = result.name.getAbstractOperation(); |
| auto derived_op = |
| abstract_op |
| ? abstract_op->getInterface<mlir::DerivedAttributeOpInterface>() |
| : nullptr; |
| for (const auto& name_and_value : node.attrs()) { |
| const auto& attr_name = name_and_value.first; |
| // Skip adding derived attributes to the generated op. |
| if (derived_op && derived_op->isDerivedAttribute(attr_name)) continue; |
| const AttrValue& attr_value = name_and_value.second; |
| |
| // Remove _output_shapes attribute that will be added by the exporter. |
| if (IsOutputShapesAttribute(attr_value, attr_name)) continue; |
| |
| if (attr_value.value_case() == AttrValue::kFunc) { |
| // Attribute iteration order is not defined for protocol buffer Map. |
| // Process function attributes separately in the lexicographical order to |
| // have deterministic order of functions in the constructed IR. |
| funcs.emplace_back(&attr_name, &attr_value); |
| } else { |
| TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value)); |
| result.attributes.push_back(builder_.getNamedAttr(attr_name, attr)); |
| } |
| } |
| |
| auto comparator = [](const FuncPairType& a, const FuncPairType& b) { |
| return *a.first < *b.first; |
| }; |
| std::sort(funcs.begin(), funcs.end(), comparator); |
| for (const auto& func : funcs) { |
| TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second, |
| &result.attributes)); |
| } |
| |
| const auto& node_def = node.def(); |
| result.attributes.push_back(builder_.getNamedAttr( |
| "device", builder_.getStringAttr(std::string(node_def.device())))); |
| |
| // Map user function calls to LegacyCall ops and add the user function name |
| // as an attribute. |
| if (convert_to_legacy_call) { |
| result.name = mlir::OperationName(get_full_op_name("LegacyCall"), context_); |
| mlir::SymbolRefAttr val = builder_.getSymbolRefAttr(node_type_name); |
| result.addAttribute("f", val); |
| |
| if (!result.attributes.get("_disable_call_shape_inference")) { |
| result.addAttribute("_disable_call_shape_inference", |
| builder_.getBoolAttr(false)); |
| } |
| } |
| |
| auto composite_control_flow_op = [&](const std::string& name) { |
| result.name = mlir::OperationName(get_full_op_name(name), context_); |
| bool stateless = absl::StartsWith(node_type_name, "Stateless"); |
| mlir::BoolAttr val = builder_.getBoolAttr(stateless); |
| result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); |
| }; |
| |
| // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common |
| // Case/If/While op in MLIR and add the differentiating attribute. |
| if (node.IsCaseNode()) composite_control_flow_op("Case"); |
| if (node.IsIfNode()) composite_control_flow_op("If"); |
| if (node.IsWhileNode()) { |
| composite_control_flow_op("While"); |
| auto* output_shapes = node.attrs().Find("output_shapes"); |
| if (output_shapes && !output_shapes->list().shape().empty()) |
| result.attributes.push_back( |
| builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr())); |
| } |
| |
| // Register the mapping between the TF node and the newly created operation. |
| node_values_[node.id()] = |
| CreateOperation(node, node_type_name, result, control_operands); |
| return Status::OK(); |
| } |
| |
| // Add the backedges to the CFG. Given a backedge, we replace the original |
| // source and destination operations by two new operations. Most of the |
| // fields of the replacements are copied from the original operations. |
| // However, |
| // - for the src operation, one output is inserted to the front of the output |
| // list. The type of the output is set to the type of the non-control result |
| // of the dst operation, and |
| // - for the dst operation, one operand is inserted to the front of the |
| // operand list. This operand is using the first result of the src |
| // operation. |
| // TODO(fengliuai): Preserve the order of the results and operands if |
| // necessary. |
| Status ImporterBase::AddBackedges() { |
| for (auto it : back_edge_dst_inputs_) { |
| BackEdge& edge = it.second; |
| if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) { |
| return errors::FailedPrecondition( |
| "Invalid backedge; should be from NextIteration to Merge!"); |
| } |
| auto* sink = node_values_[edge.src->id()]; |
| auto* dst = node_values_[edge.dst->id()]; |
| TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input)); |
| } |
| return Status::OK(); |
| } |
| |
| Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst, |
| int dst_input) { |
| // Get the NextIteration.Source operation from the token operand of the sink. |
| mlir::Operation* source = sink->getOperand(0).getDefiningOp(); |
| |
| // Adds the "source" to the operands of the dst by creating a new dst |
| // operation. |
| mlir::OperationState state(dst->getLoc(), dst->getName()); |
| auto num_operands = dst->getNumOperands(); |
| state.operands.reserve(num_operands + 1); |
| for (int input = 0, e = num_operands + 1; input != e; ++input) { |
| if (input < dst_input) { |
| state.operands.push_back(dst->getOperand(input)); |
| } else if (input == dst_input) { |
| state.operands.push_back(source->getResult(0)); |
| } else { |
| state.operands.push_back(dst->getOperand(input - 1)); |
| } |
| } |
| state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end()); |
| state.types.assign(dst->getResultTypes().begin(), |
| dst->getResultTypes().end()); |
| builder_.setInsertionPoint(dst); |
| auto* new_dst = builder_.createOperation(state); |
| |
| // Replaces the output uses of the old operation by the corresponding |
| // result of the new operation, and deletes the old operation. |
| for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) { |
| auto new_output = new_dst->getResult(i); |
| dst->getResult(i).replaceAllUsesWith(new_output); |
| } |
| dst->dropAllReferences(); |
| dst->erase(); |
| return Status::OK(); |
| } |
| |
| StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType( |
| const FunctionBody& fbody) { |
| mlir::Builder builder(context_); |
| |
| // The FunctionBody contains a graph with a single-output _Arg node for each |
| // function argument and a single-input _Retval node for each function return |
| // value. |
| // |
| // We already populated the ShapeRefiner with all the information about the |
| // shapes of these graph edges, so we just query it to build the corresponding |
| // MLIR function type signature. |
| |
| llvm::SmallVector<mlir::Type, 4> arg_types; |
| if (specs_.inputs.empty()) { |
| arg_types.reserve(fbody.arg_types.size()); |
| for (auto arg : fbody.arg_nodes) { |
| // Find node in the graph using the node id instead of using `arg` |
| // directly because the graph has been cloned. |
| auto* node = graph_->FindNodeId(arg->id()); |
| TF_ASSIGN_OR_RETURN(auto type, |
| InferOutputType(*node, /*idx=*/0, builder)); |
| arg_types.push_back(type); |
| } |
| } else { |
| arg_types.reserve(fbody.arg_types.size()); |
| for (const auto& it : llvm::enumerate(specs_.inputs)) { |
| mlir::Type element_type; |
| const auto& node_info = it.value().second; |
| DataType dtype = node_info.imported_dtype; |
| // Uses the existing output type of the arg node if the data type of the |
| // the node isn't specified through the import configuration. |
| if (dtype == DT_INVALID) { |
| auto arg = fbody.arg_nodes[it.index()]; |
| auto* node = graph_->FindNodeId(arg->id()); |
| dtype = node->output_type(0); |
| if (dtype == DT_INVALID) { |
| return errors::InvalidArgument("Input ", it.index(), |
| "has invalid data type"); |
| } |
| } |
| TF_RETURN_IF_ERROR( |
| ::tensorflow::ConvertDataType(dtype, builder, &element_type)); |
| if (node_info.shape.unknown_rank()) { |
| arg_types.push_back(mlir::UnrankedTensorType::get(element_type)); |
| } else { |
| llvm::SmallVector<int64_t, 4> shape; |
| TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); |
| arg_types.push_back(mlir::RankedTensorType::get(shape, element_type)); |
| } |
| } |
| } |
| |
| llvm::SmallVector<mlir::Type, 4> ret_types; |
| ret_types.reserve(fbody.ret_types.size()); |
| for (auto ret : fbody.ret_nodes) { |
| // Find node in the graph using the node id instead of using `ret` directly |
| // because the graph has been cloned. |
| auto* node = graph_->FindNodeId(ret->id()); |
| TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder)); |
| ret_types.push_back(type); |
| } |
| |
| return builder.getFunctionType(arg_types, ret_types); |
| } |
| |
| // Stateful helper class to import a TensorFlow model expressed in GraphDef into |
| // an MLIR Module. |
| // |
| // The nodes defined in the graph are converted to a function called |
| // 'func_name'. All library function definitions are converted to MLIR functions |
| // in the module. |
| class GraphDefImporter : public ImporterBase { |
| public: |
| // Main entry point: converts the given graph to an MLIR Module. |
| static StatusOr<mlir::OwningModuleRef> Convert( |
| mlir::MLIRContext* context, const Graph& graph, |
| const GraphDebugInfo& debug_info, |
| const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, |
| llvm::StringRef func_name); |
| |
| private: |
| explicit GraphDefImporter( |
| const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, |
| const GraphImportConfig& specs, mlir::ModuleOp module, |
| std::unordered_map<std::string, std::string>* tf_name_to_mlir_name, |
| NameUniquifier* function_name_uniquifier) |
| : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name, |
| function_name_uniquifier) {} |
| |
| // Returns the function signature of the main function of converted MLIR |
| // module, the input nodes and output nodes. The type and shape information |
| // for the function arguments are read from `specs`, but the type and shape |
| // information for the function returns are inferred by the shape refiner in |
| // ImporterBase. |
| StatusOr<mlir::FunctionType> InferMainFunctionType( |
| const GraphImportConfig& specs, mlir::MLIRContext* context, |
| absl::InlinedVector<OutputTensor, 4>* arg_nodes, |
| absl::InlinedVector<OutputTensor, 4>* ret_nodes); |
| |
| // Returns the function signature of the main function, alongside input and |
| // output nodes, for function graphs. Arguments and return values are |
| // determined by node op type. Type and shape information of the function are |
| // inferred by the shape refiner in ImporterBase. |
| StatusOr<mlir::FunctionType> GetArgsRetsAndTypesFromFunctionGraph( |
| mlir::MLIRContext* context, |
| absl::InlinedVector<OutputTensor, 4>* arg_nodes, |
| absl::InlinedVector<OutputTensor, 4>* ret_nodes); |
| |
| // Finds the graph's target nodes/function's control ret nodes based on |
| // supplied node names in `control_outputs`. If `control_outputs` are not |
| // unique or a control ret node is missing, an error will be returned. |
| Status GetControlRetsFromGraph( |
| llvm::ArrayRef<std::string> control_outputs, |
| absl::InlinedVector<Node*, 4>* control_ret_nodes); |
| }; |
| |
| StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert( |
| mlir::MLIRContext* context, const Graph& graph, |
| const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, |
| const GraphImportConfig& specs, llvm::StringRef func_name) { |
| LoadImporterDialects(*context); |
| mlir::OwningModuleRef module = |
| mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); |
| std::unordered_map<std::string, std::string> tf_name_to_mlir_name; |
| NameUniquifier function_name_uniquifier(flib_def); |
| |
| GraphDefImporter importer(flib_def, debug_info, specs, module.get(), |
| &tf_name_to_mlir_name, &function_name_uniquifier); |
| |
| TF_RETURN_IF_ERROR(importer.PrepareConvert(graph)); |
| |
| mlir::FunctionType func_type; |
| absl::InlinedVector<OutputTensor, 4> arg_nodes; |
| absl::InlinedVector<OutputTensor, 4> ret_nodes; |
| absl::InlinedVector<Node*, 4> control_ret_nodes; |
| llvm::SmallVector<mlir::NamedAttribute, 1> attrs; |
| if (specs.graph_as_function) { |
| if (specs.prune_unused_nodes || !specs.inputs.empty() || |
| !specs.outputs.empty()) |
| return errors::InvalidArgument( |
| "Pruning of graph is currently unsupported when the main graph is " |
| "converted to a function."); |
| |
| TF_ASSIGN_OR_RETURN(func_type, |
| importer.GetArgsRetsAndTypesFromFunctionGraph( |
| context, &arg_nodes, &ret_nodes)); |
| |
| TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, |
| &control_ret_nodes)); |
| |
| mlir::Builder b(context); |
| std::string s; |
| llvm::raw_string_ostream ss(s); |
| auto node_name = [&](const OutputTensor& tensor) { |
| ss << tensor.node->name(); |
| }; |
| llvm::interleave(arg_nodes, ss, node_name, ","); |
| auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); |
| s.clear(); |
| llvm::interleave(ret_nodes, ss, node_name, ","); |
| auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); |
| s.clear(); |
| llvm::interleave(specs.control_outputs, ss, ","); |
| auto control_outputs = |
| b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); |
| |
| // Under `graph_as_function` mode, `tf.entry_function` is always set as it |
| // is assumed feed, fetch, and target nodes are set correctly. |
| attrs.push_back(b.getNamedAttr( |
| "tf.entry_function", |
| b.getDictionaryAttr({inputs, outputs, control_outputs}))); |
| } else { |
| // Collects the argument and return nodes by looking up the node names |
| // specified by the user. |
| TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType( |
| specs, context, &arg_nodes, &ret_nodes)); |
| |
| TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, |
| &control_ret_nodes)); |
| |
| // TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and |
| // decoding in a centralized place. |
| // Record the input and output mapping. |
| if (!specs.inputs.empty() || !specs.outputs.empty() || |
| !specs.control_outputs.empty()) { |
| mlir::Builder b(context); |
| std::string s; |
| llvm::raw_string_ostream ss(s); |
| llvm::interleave( |
| specs.inputs, ss, |
| [&](const std::pair<std::string, ArrayInfo>& v) { ss << v.first; }, |
| ","); |
| auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); |
| s.clear(); |
| llvm::interleave(specs.outputs, ss, ","); |
| auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); |
| s.clear(); |
| llvm::interleave(specs.control_outputs, ss, ","); |
| auto control_outputs = |
| b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); |
| |
| attrs.push_back(b.getNamedAttr( |
| "tf.entry_function", |
| b.getDictionaryAttr({inputs, outputs, control_outputs}))); |
| } |
| } |
| |
| // Record version info. |
| PopulateTfVersions(module.get(), graph.versions()); |
| |
| TF_RETURN_IF_ERROR(importer.ImporterBase::Convert( |
| func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs)); |
| |
| // Mark main function public, others private. |
| for (auto function : module.get().getOps<mlir::FuncOp>()) { |
| auto visibility = function.getName() == func_name |
| ? mlir::FuncOp::Visibility::Public |
| : mlir::FuncOp::Visibility::Private; |
| function.setVisibility(visibility); |
| } |
| return module; |
| } |
| |
| StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType( |
| const GraphImportConfig& specs, mlir::MLIRContext* context, |
| absl::InlinedVector<OutputTensor, 4>* arg_nodes, |
| absl::InlinedVector<OutputTensor, 4>* ret_nodes) { |
| // Find all the input nodes and output nodes. |
| // Feeds have been remapped to single output nodes (Placeholder), so an exact |
| // name match is sufficient. |
| absl::flat_hash_map<absl::string_view, int> inputs; |
| for (auto input_and_idx : llvm::enumerate(specs.inputs)) { |
| TensorId tensor = ParseTensorName(input_and_idx.value().first); |
| auto remapped_it = remapped_feeds_.find(tensor); |
| if (remapped_it != remapped_feeds_.end()) { |
| inputs.insert({remapped_it->second, input_and_idx.index()}); |
| } else { |
| inputs.insert({tensor.node(), input_and_idx.index()}); |
| } |
| } |
| |
| absl::flat_hash_set<absl::string_view> output_node_names; |
| std::vector<TensorId> outputs; |
| output_node_names.reserve(specs.outputs.size()); |
| for (const auto& output : specs.outputs) { |
| TensorId tensor = ParseTensorName(output); |
| auto remapped_it = remapped_feeds_.find(tensor); |
| if (remapped_it != remapped_feeds_.end()) { |
| output_node_names.insert(remapped_it->second); |
| outputs.push_back({remapped_it->second, 0}); |
| } else { |
| output_node_names.insert(tensor.node()); |
| outputs.push_back(tensor); |
| } |
| } |
| |
| if (!inputs.empty() || !outputs.empty()) { |
| arg_nodes->resize(inputs.size()); |
| ret_nodes->resize(outputs.size()); |
| |
| for (Node* n : GetOrderedNodes()) { |
| // Handle inputs/arguments. |
| auto input_it = inputs.find(n->name()); |
| if (input_it != inputs.end()) { |
| (*arg_nodes)[input_it->second] = {n, 0}; |
| } |
| |
| // Handle outputs/returns. |
| if (output_node_names.contains(n->name())) { |
| for (int i = 0, e = outputs.size(); i != e; ++i) { |
| TensorId tensor = outputs[i]; |
| if (n->name() != tensor.node()) continue; |
| (*ret_nodes)[i] = {n, tensor.index()}; |
| } |
| } |
| } |
| } |
| |
| // Starts to construct the function type. |
| mlir::Builder builder(context); |
| llvm::SmallVector<mlir::Type, 4> arg_types; |
| arg_types.reserve(specs.inputs.size()); |
| int i = 0; |
| for (const auto& it : specs.inputs) { |
| Node* arg_node = arg_nodes->at(i).node; |
| if (arg_node == nullptr) { |
| return errors::InvalidArgument("Input ", it.first, |
| " was not found in graph"); |
| } |
| mlir::Type element_type; |
| const auto& node_info = it.second; |
| DataType imported_dtype = node_info.imported_dtype; |
| // Uses the existing output type of the arg node if the data type of the |
| // the node isn't specified through the import configuration. |
| if (imported_dtype == DT_INVALID) { |
| imported_dtype = arg_node->output_type(0); |
| if (imported_dtype == DT_INVALID) { |
| return errors::InvalidArgument("Input ", i, "has invalid data type"); |
| } |
| } |
| TF_RETURN_IF_ERROR( |
| ::tensorflow::ConvertDataType(imported_dtype, builder, &element_type)); |
| if (node_info.shape.unknown_rank()) { |
| arg_types.push_back(mlir::UnrankedTensorType::get(element_type)); |
| } else { |
| llvm::SmallVector<int64_t, 4> shape; |
| TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); |
| arg_types.push_back(mlir::RankedTensorType::get(shape, element_type)); |
| } |
| i++; |
| } |
| |
| llvm::SmallVector<mlir::Type, 4> ret_types; |
| ret_types.reserve(specs.outputs.size()); |
| for (int i = 0, e = specs.outputs.size(); i != e; ++i) { |
| if (ret_nodes->at(i).node == nullptr) { |
| return errors::InvalidArgument("Output ", specs.outputs[i], |
| " was not found in graph"); |
| } |
| } |
| for (const auto& ret : *ret_nodes) { |
| if (ret.node->num_outputs() <= ret.index) { |
| return errors::InvalidArgument("Invalid output index ", ret.index, |
| " specified for node: ", ret.node->name()); |
| } |
| TF_ASSIGN_OR_RETURN(auto type, |
| InferOutputType(*ret.node, ret.index, builder)); |
| ret_types.push_back(type); |
| } |
| |
| return builder.getFunctionType(arg_types, ret_types); |
| } |
| |
| StatusOr<mlir::FunctionType> |
| GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( |
| mlir::MLIRContext* context, absl::InlinedVector<OutputTensor, 4>* arg_nodes, |
| absl::InlinedVector<OutputTensor, 4>* ret_nodes) { |
| auto add_node = [](Node* node, absl::InlinedVector<OutputTensor, 4>* nodes) { |
| auto* attr = node->attrs().Find("index"); |
| if (!attr) |
| return errors::InvalidArgument(node->type_string(), " node '", |
| node->name(), |
| "' is missing attribute 'index'"); |
| |
| auto index = attr->i(); |
| const int num_nodes = nodes->size(); |
| if (num_nodes < index + 1) nodes->resize(index + 1); |
| |
| if ((*nodes)[index].node != nullptr) |
| return errors::InvalidArgument(node->type_string(), " node '", |
| node->name(), "' has attribute 'index' ", |
| index, " that conflicts with node '", |
| (*nodes)[index].node->name(), "'"); |
| (*nodes)[index] = {node, 0}; |
| |
| return Status::OK(); |
| }; |
| |
| // Collect arg and ret nodes from graph. |
| for (auto* node : GetOrderedNodes()) |
| if (node->IsArg()) |
| TF_RETURN_IF_ERROR(add_node(node, arg_nodes)); |
| else if (node->IsRetval()) |
| TF_RETURN_IF_ERROR(add_node(node, ret_nodes)); |
| |
| // Collect arg and ret types and create function type. |
| mlir::Builder builder(context); |
| llvm::SmallVector<mlir::Type, 4> arg_types; |
| arg_types.reserve(arg_nodes->size()); |
| for (auto arg_node_and_idx : llvm::enumerate(*arg_nodes)) { |
| auto& arg_node = arg_node_and_idx.value(); |
| if (arg_node.node == nullptr) |
| return errors::InvalidArgument("Graph missing _Arg at index ", |
| arg_node_and_idx.index()); |
| |
| TF_ASSIGN_OR_RETURN(auto type, |
| InferOutputType(*arg_node.node, /*idx=*/0, builder)); |
| arg_types.push_back(type); |
| } |
| |
| llvm::SmallVector<mlir::Type, 4> ret_types; |
| ret_types.reserve(ret_nodes->size()); |
| for (auto ret_node_and_idx : llvm::enumerate(*ret_nodes)) { |
| auto& ret_node = ret_node_and_idx.value(); |
| if (ret_node.node == nullptr) |
| return errors::InvalidArgument("Graph missing _Retval at index ", |
| ret_node_and_idx.index()); |
| |
| TF_ASSIGN_OR_RETURN(auto type, |
| InferInputType(*ret_node.node, /*idx=*/0, builder)); |
| ret_types.push_back(type); |
| } |
| |
| return builder.getFunctionType(arg_types, ret_types); |
| } |
| |
| Status GraphDefImporter::GetControlRetsFromGraph( |
| llvm::ArrayRef<std::string> control_outputs, |
| absl::InlinedVector<Node*, 4>* control_ret_nodes) { |
| if (control_outputs.empty()) return Status::OK(); |
| |
| llvm::SmallDenseMap<llvm::StringRef, int32_t> controls_to_idx; |
| for (auto control_and_idx : llvm::enumerate(control_outputs)) |
| controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()}); |
| |
| if (controls_to_idx.size() != control_outputs.size()) |
| return errors::InvalidArgument("Control outputs must be unique"); |
| |
| control_ret_nodes->resize(controls_to_idx.size()); |
| |
| for (auto* node : GetOrderedNodes()) { |
| auto it = controls_to_idx.find(node->name()); |
| if (it != controls_to_idx.end()) (*control_ret_nodes)[it->second] = node; |
| } |
| |
| for (auto node_and_name : llvm::zip(*control_ret_nodes, control_outputs)) |
| if (std::get<0>(node_and_name) == nullptr) |
| return errors::InvalidArgument( |
| "Control output '", std::get<1>(node_and_name), "' is missing"); |
| |
| return Status::OK(); |
| } |
| |
| // Stateful helper class to import a TensorFlow model expressed in SavedModel |
| // into an MLIR Module. |
| class SavedModelObjectGraphImporter : public ImporterBase { |
| public: |
| // Main entry point: converts all functions in the given meta graph to an MLIR |
| // Module. |
| static StatusOr<mlir::OwningModuleRef> Convert( |
| SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names, |
| mlir::MLIRContext* context, bool add_default_attributes); |
| |
| private: |
| explicit SavedModelObjectGraphImporter( |
| const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, |
| const GraphImportConfig& specs, mlir::ModuleOp module, |
| std::unordered_map<std::string, std::string>* tf_name_to_mlir_name, |
| NameUniquifier* function_name_uniquifier) |
| : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name, |
| function_name_uniquifier) {} |
| }; |
| |
| // Determines the names used to reference objects in the SavedObjectGraph. |
| class ObjectNames { |
| public: |
| explicit ObjectNames(const SavedObjectGraph& object_graph, |
| absl::Span<std::string> exported_names); |
| |
| // Gets the names that external users of the SavedModel can use to refer to |
| // this node. |
| llvm::ArrayRef<llvm::StringRef> GetExportedNames(int node_id) const; |
| |
| // Gets the name in the module symbol table for this node. |
| // This name is only used for internal IR references. |
| llvm::StringRef GetSymbolTableName(int node_id) const; |
| |
| private: |
| // In the absence of any other information, use this name as the symbol table |
| // name for this node. |
| std::string GetDefaultSymbolTableName(int node_id) const; |
| // Determines if a name is exported. |
| bool IsExported(const std::string& name); |
| // Main object graph traversal function. |
| void RecursivelyVisitObjectGraph(int node_id); |
| // Gets a stable StringRef from a std::string. |
| llvm::StringRef SaveString(const std::string& s) const; |
| |
| // The object graph we are traversing. |
| const SavedObjectGraph& object_graph_; |
| // The set of names to export. Empty means "export all". |
| std::unordered_set<std::string> names_to_export_; |
| |
| // When we recursively follow the object graph tree structure from the root, |
| // we track its path in the object graph by pushing and popping from here |
| // during traversal. |
| llvm::SmallVector<std::string, 8> path_segments_; |
| // The set of node_id's that are on the current DFS stack. |
| // For cyclic object graphs, this prevents infinite recursion. |
| std::unordered_set<int> on_stack_nodes_; |
| |
| // Key: node_id. |
| // Value: all object names that node_id appears as. |
| // Each object name corresponds to a unique path from the root of the object |
| // graph. |
| // The common intuitive case is when there is only one name for a given |
| // object, which corresponds to the object graph being a tree. |
| // |
| // But, there cases where the object graph is a general graph. For |
| // example, this happens commonly in Keras models, where `foo.bar` is |
| // also reachable via the name `keras_api.foo.bar`. |
| // Cycles are possible too. |
| absl::flat_hash_map<int, std::vector<std::string>> object_names_; |
| |
| // Key: node_id |
| // Value: all names that this object is exported as |
| absl::flat_hash_map<int, llvm::SmallVector<llvm::StringRef, 1>> |
| exported_names_; |
| // Key: node_id |
| // Value: pretty symbol table name to use for internal references to this |
| // object. |
| absl::flat_hash_map<int, llvm::StringRef> pretty_symbol_table_name_; |
| |
| // Stable strings we can take StringRef's into. Used only by the SaveString |
| // method. |
| mutable std::unordered_set<std::string> saved_strings_; |
| }; |
| |
| ObjectNames::ObjectNames(const SavedObjectGraph& object_graph, |
| absl::Span<std::string> exported_names) |
| : object_graph_(object_graph), |
| names_to_export_(exported_names.begin(), exported_names.end()) { |
| // Visit all reachable nodes from the root of the object graph. |
| // This builds up object_names_ to contain all names like `foo.bar` that a |
| // particular node in the graph can be reached from. |
| RecursivelyVisitObjectGraph(/*node_id=*/0); |
| |
| // Populate the exported_names_ map. |
| // TODO(silvasean): Diagnose typos in exported names? |
| for (auto& kv : object_names_) { |
| // Make object names map independent of our particular choice of object |
| // graph traversal. |
| std::sort(kv.second.begin(), kv.second.end(), |
| [](absl::string_view a, absl::string_view b) { |
| // The sort order here influences the "pretty name" we assign |
| // below. We want the most debuggable name to be first. |
| // |
| // Debuggability heuristics: |
| // 1. Names that end in digits are likely to be internal aliases |
| // to the "real" names. |
| // 2. Longer names are more likely to be internal aliases. |
| // |
| // Example set of object names created by Keras for the weight |
| // matrix of a fully connected layer on a trivial FC mnist |
| // model: |
| // - `model.layer-1.kernel` (this is the "best" name) |
| // - `model.keras_api.layers.1.kernel` |
| // - `model.variables.0` |
| // - `model.keras_api.layers.1.keras_api.trainable_variables.0` |
| // - ... 10 more long aliases ending in digits ... |
| return std::make_tuple(isdigit(a.back()), a.size(), a) < |
| std::make_tuple(isdigit(b.back()), b.size(), b); |
| }); |
| for (const std::string& name : kv.second) { |
| if (IsExported(name)) { |
| exported_names_[kv.first].push_back(SaveString(name)); |
| } |
| } |
| } |
| // Create "pretty" symbol table names for nodes where that is applicable. |
| // We could make all symbol table names use the default, which is basically |
| // just the node id. But for debugging purposes, it's nicer if we can mix in |
| // a recognizable object name if we have the information to do so. |
| for (auto& kv : object_names_) { |
| int node_id = kv.first; |
| std::string internal_name = |
| absl::StrCat(GetDefaultSymbolTableName(node_id), "__"); |
| // If the object has an exported name, we prefer that since it is probably |
| // the most recognizable. Otherwise, we grab some non-exported name of the |
| // object. |
| if (exported_names_.find(node_id) != exported_names_.end()) { |
| internal_name += exported_names_[node_id][0].str(); |
| } else { |
| internal_name += object_names_[node_id][0]; |
| } |
| pretty_symbol_table_name_[node_id] = SaveString(internal_name); |
| } |
| } |
| |
| llvm::ArrayRef<llvm::StringRef> ObjectNames::GetExportedNames( |
| int node_id) const { |
| auto it = exported_names_.find(node_id); |
| if (it != exported_names_.end()) { |
| return it->second; |
| } |
| return {}; |
| } |
| |
| llvm::StringRef ObjectNames::GetSymbolTableName(int node_id) const { |
| auto it = pretty_symbol_table_name_.find(node_id); |
| if (it != pretty_symbol_table_name_.end()) { |
| return it->second; |
| } |
| return SaveString(GetDefaultSymbolTableName(node_id)); |
| } |
| |
| std::string ObjectNames::GetDefaultSymbolTableName(int node_id) const { |
| return absl::StrCat("__sm_node", node_id); |
| } |
| |
| bool ObjectNames::IsExported(const std::string& name) { |
| if (names_to_export_.empty()) { |
| return true; |
| } |
| return names_to_export_.find(name) != names_to_export_.end(); |
| } |
| |
| void ObjectNames::RecursivelyVisitObjectGraph(int node_id) { |
| const SavedObject& object = object_graph_.nodes(node_id); |
| |
| switch (object.kind_case()) { |
| case SavedObject::kConstant: |
| case SavedObject::kFunction: |
| case SavedObject::kVariable: { |
| object_names_[node_id].push_back(absl::StrJoin(path_segments_, ".")); |
| break; |
| } |
| default: |
| break; |
| } |
| |
| for (const auto& child_ref : object.children()) { |
| bool on_stack = !on_stack_nodes_.insert(child_ref.node_id()).second; |
| if (on_stack) { |
| // This is a backedge. Don't traverse it. |
| continue; |
| } |
| |
| path_segments_.push_back(child_ref.local_name()); |
| RecursivelyVisitObjectGraph(child_ref.node_id()); |
| path_segments_.pop_back(); |
| |
| on_stack_nodes_.erase(child_ref.node_id()); |
| } |
| } |
| |
| llvm::StringRef ObjectNames::SaveString(const std::string& s) const { |
| return llvm::StringRef(*saved_strings_.insert(s).first); |
| } |
| |
| // Extracts a TensorProto for a Const op from a GraphDef, given an op_name. |
| // Returns nullptr on not found or other mismatch. |
| // This returns a pointer to the actual node within the graph_def so as to |
| // avoid expensive copies. |
| const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def, |
| const std::string& op_name) { |
| const NodeDef* match_node = nullptr; |
| for (const auto& node : graph_def.node()) { |
| if (node.name() == op_name) { |
| match_node = &node; |
| } |
| } |
| |
| if (!match_node) { |
| return nullptr; |
| } |
| |
| auto value_it = match_node->attr().find("value"); |
| if (value_it == match_node->attr().end()) { |
| return nullptr; |
| } |
| |
| if (!value_it->second.has_tensor()) { |
| return nullptr; |
| } |
| |
| return &value_it->second.tensor(); |
| } |
| |
| const TrackableObjectGraph::TrackableObject::SerializedTensor* |
| FindSerializedTensorInTrackable( |
| const TrackableObjectGraph::TrackableObject& trackable_object, |
| StringPiece name) { |
| for (const auto& maybe_serialized_tensor : trackable_object.attributes()) { |
| if (maybe_serialized_tensor.name() == name) { |
| return &maybe_serialized_tensor; |
| } |
| } |
| return nullptr; |
| } |
| |
| Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph, |
| const ObjectNames& object_names) { |
| for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) { |
| const SavedObject& object = object_graph.nodes(node_id); |
| if (object_names.GetExportedNames(node_id).empty()) { |
| continue; |
| } |
| if (object.kind_case() == SavedObject::kFunction) { |
| // We only allow a single input signature to each SavedFunction. |
| // This assumption means we have a 1:1 correspondence between |
| // tf.function <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef |
| // This makes defining the ABI easier (or even well-defined at all). |
| // TODO(silvasean): How to detect a function that doesn't have an |
| // explicitly user-provided input signature, but happens to have been |
| // traced exactly once? |
| if (object.function().concrete_functions_size() != 1) { |
| llvm::SmallVector<std::string, 4> names; |
| for (llvm::StringRef s : object_names.GetExportedNames(node_id)) { |
| names.push_back("'" + s.str() + "'"); |
| } |
| return errors::InvalidArgument( |
| "Exported function with exported name(s) ", |
| absl::StrJoin(names, ", "), |
| " with multiple concrete functions. Add " |
| "@tf.function(input_signature=[...]) on this function, or use a " |
| "narrower list of exported names that excludes this function."); |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
| // Recursively traverses a StructuredValue, linearizing all the leaves. |
| // |
| // This currently only handles the subset of StructuredValue that is needed for |
| // signatures. |
| // |
| // Given a StructuredValue with structure [{"x": leaf0}], the "index path" |
| // needed to reach leaf0 is `[0, "x"]`, as it would be if you were operating on |
| // a Python object (`obj[0]["x"] is leaf0`). Each leaf corresponds to a |
| // linearized function argument or return on a FunctionDef, and hence to an |
| // mlir::FuncOp argument / return. |
| // |
| // This must match the linearization that happens in `tf.nest.flatten`. |
| // In particular, dict values should be linearized in sorted key order. |
| // |
| // The linearized index paths can be returned back to a structured |
| // representation (e.g. to emit C structs matching a signature) with a simple |
| // algorithm that recurses on each run of index paths with identical first |
| // elements. |
| class StructuredValueLinearizer { |
| public: |
| StructuredValueLinearizer(const StructuredValue& value, |
| mlir::MLIRContext* context); |
| |
| // Returns the list of index paths to each leaf of the StructuredValue, |
| // in a linearized order matching `tf.nest.flatten`. |
| // |
| // If an error occurred during the linearization process, an error message |
| // with `error_context` prepended will be included in the returned status. |
| StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> GetLeafIndexPaths( |
| llvm::StringRef error_context) const; |
| |
| private: |
| // Main function that recursively traverses the StructuredValue. |
| void RecursivelyFindLeaves(const StructuredValue& value); |
| |
| mlir::Builder builder_; |
| // The current index path. We push/pop this during recursive traversal of the |
| // StructuredValue. |
| llvm::SmallVector<mlir::Attribute, 4> current_index_path_; |
| // The list of leaf index paths we have discovered so far. |
| llvm::SmallVector<mlir::ArrayAttr, 4> leaf_index_paths_; |
| // If non-empty, an error message to report. |
| std::string error_message_; |
| }; |
| |
| StructuredValueLinearizer::StructuredValueLinearizer( |
| const StructuredValue& value, mlir::MLIRContext* context) |
| : builder_(context) { |
| RecursivelyFindLeaves(value); |
| } |
| |
| StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> |
| StructuredValueLinearizer::GetLeafIndexPaths( |
| llvm::StringRef error_context) const { |
| if (error_message_.empty()) { |
| return llvm::makeArrayRef(leaf_index_paths_); |
| } |
| return errors::InvalidArgument( |
| error_context.str(), error_message_, |
| "This likely means that you have @tf.function " |
| "on an exported function instead of " |
| "@tf.function(input_signature=[...]). Consider annotating an " |
| "input_signature or narrowing your set of " |
| "exported names to not include this function."); |
| } |
| |
| void StructuredValueLinearizer::RecursivelyFindLeaves( |
| const StructuredValue& value) { |
| switch (value.kind_case()) { |
| case StructuredValue::kDictValue: { |
| // Dict values must be linearized in sorted order of keys. |
| const DictValue& dict = value.dict_value(); |
| using FieldTy = protobuf::MapPair<std::string, StructuredValue>; |
| llvm::SmallVector<const FieldTy*, 4> fields; |
| for (auto& field : dict.fields()) { |
| fields.push_back(&field); |
| } |
| llvm::sort(fields, [](const FieldTy* a, const FieldTy* b) { |
| return a->first < b->first; |
| }); |
| for (auto& field : fields) { |
| current_index_path_.push_back(builder_.getStringAttr(field->first)); |
| RecursivelyFindLeaves(field->second); |
| current_index_path_.pop_back(); |
| } |
| return; |
| } |
| case StructuredValue::kTupleValue: { |
| const TupleValue& tuple = value.tuple_value(); |
| for (int i = 0, e = tuple.values_size(); i < e; i++) { |
| current_index_path_.push_back(builder_.getI64IntegerAttr(i)); |
| RecursivelyFindLeaves(tuple.values(i)); |
| current_index_path_.pop_back(); |
| } |
| return; |
| } |
| // We don't differentiate between tuples and lists. |
| case StructuredValue::kListValue: { |
| const ListValue& list = value.list_value(); |
| for (int i = 0, e = list.values_size(); i < e; i++) { |
| current_index_path_.push_back(builder_.getI64IntegerAttr(i)); |
| RecursivelyFindLeaves(list.values(i)); |
| current_index_path_.pop_back(); |
| } |
| return; |
| } |
| case StructuredValue::kTensorSpecValue: { |
| // Base case: record the current path stack as the index path needed to |
| // get to this leaf. |
| leaf_index_paths_.push_back(builder_.getArrayAttr(current_index_path_)); |
| return; |
| } |
| case StructuredValue::kNoneValue: { |
| // Base case: do nothing. |
| // This arises, for example, as the top-level object of an output |
| // signature when there are no return values. |
| return; |
| } |
| default: { |
| llvm::raw_string_ostream os(error_message_); |
| // TODO(silvasean): Use an enumerant name string instead of a number. |
| os << "Unhandled structured value kind " << value.kind_case() |
| << " at index path: <value>"; |
| for (auto path_element : current_index_path_) { |
| os << "."; |
| if (auto integer = path_element.dyn_cast<mlir::IntegerAttr>()) { |
| os << integer.getValue(); |
| } else { |
| auto str = path_element.cast<mlir::StringAttr>(); |
| os << str.getValue(); |
| } |
| } |
| os << "\n"; |
| } |
| } |
| } |
| |
| // For exported functions with bound inputs, rewrite the function |
| // signature to match the requirements of tf_saved_model bound input args. |
| // |
| // The raw imported functions have `tensor<*x!tf.resource>` as the type for |
| // mutable bound inputs and `tensor<...>` as the type for immutable |
| // bound inputs. Here we canonicalize both of them into |
| // `tensor<!tf.resource<tensor<...>>>`. |
| void AdjustBoundInputArgTypes(mlir::ModuleOp module) { |
| mlir::SymbolTable symbol_table(module); |
| for (auto func : module.getOps<mlir::FuncOp>()) { |
| if (!mlir::tf_saved_model::IsExported(func)) continue; |
| mlir::OpBuilder builder(func.getBody()); |
| llvm::SmallVector<mlir::Type, 4> new_input_types; |
| for (int i = 0, e = func.getNumArguments(); i < e; i++) { |
| auto arg = func.getArgument(i); |
| auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType< |
| mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table); |
| if (global_tensor) { |
| auto old_type = arg.getType(); |
| auto new_type = |
| mlir::tf_saved_model::GetBoundInputArgTypeFor(global_tensor); |
| arg.setType(new_type); |
| if (global_tensor.is_mutable()) { |
| auto arg_with_original_type = builder.create<mlir::TF::CastOp>( |
| global_tensor.getLoc(), old_type, arg, |
| /*Truncate=*/builder.getBoolAttr(false)); |
| arg.replaceAllUsesWith(arg_with_original_type); |
| // The RAUW replaces the arg with itself, so we need to set it back. |
| arg_with_original_type.setOperand(arg); |
| } else { |
| auto arg_with_original_type = |
| builder.create<mlir::TF::ReadVariableOp>(global_tensor.getLoc(), |
| old_type, arg); |
| arg.replaceAllUsesWith(arg_with_original_type); |
| // The RAUW replaces the arg with itself, so we need to set it back. |
| arg_with_original_type.setOperand(arg); |
| } |
| } |
| new_input_types.push_back(arg.getType()); |
| } |
| func.setType(mlir::FunctionType::get(module.getContext(), new_input_types, |
| func.getType().getResults())); |
| } |
| } |
| |
| // Marks the visibility of functions in the saved model module. |
| void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) { |
| for (auto func : module.getOps<mlir::FuncOp>()) { |
| auto visibility = mlir::tf_saved_model::IsExported(func) |
| ? mlir::FuncOp::Visibility::Public |
| : mlir::FuncOp::Visibility::Private; |
| func.setVisibility(visibility); |
| } |
| } |
| |
| // Reorder the ops in the module to make testing easier and less dependent |
| // on implementation details such as the order of functions in the |
| // FunctionDefLibrary. |
| // |
| // The order this ensures is: |
| // 1. GlobalTensorOp's |
| // 2. FuncOps's. |
| // |
| // Within each of 1. and 2., ops are sorted by exported name (if |
| // available, and only the first exported name is considered), followed by |
| // non-exported ops. |
| void SortSavedModelModule(mlir::ModuleOp module) { |
| struct NamedGlobalTensor { |
| llvm::StringRef name; |
| GlobalTensorOp global_tensor; |
| }; |
| llvm::SmallVector<NamedGlobalTensor, 8> named_global_tensors; |
| for (auto global_tensor : module.getOps<GlobalTensorOp>()) { |
| auto exported_names = mlir::tf_saved_model::GetExportedNames(global_tensor); |
| // We use stable_sort, so duplicate empty names are fine here. |
| named_global_tensors.push_back( |
| {exported_names.empty() ? "" : exported_names.front(), global_tensor}); |
| } |
| llvm::stable_sort(named_global_tensors, |
| [](const NamedGlobalTensor& a, const NamedGlobalTensor& b) { |
| return std::make_tuple(a.name.empty(), a.name) < |
| std::make_tuple(b.name.empty(), b.name); |
| }); |
| |
| struct NamedFunc { |
| llvm::StringRef name; |
| mlir::FuncOp func; |
| }; |
| llvm::SmallVector<NamedFunc, 8> named_funcs; |
| llvm::SmallVector<mlir::FuncOp, 8> private_funcs; |
| for (auto func : module.getOps<mlir::FuncOp>()) { |
| auto exported_names = mlir::tf_saved_model::GetExportedNames(func); |
| if (!exported_names.empty()) |
| named_funcs.push_back({exported_names.front(), func}); |
| else |
| private_funcs.push_back(func); |
| } |
| llvm::stable_sort(named_funcs, [](const NamedFunc& a, const NamedFunc& b) { |
| return a.name < b.name; |
| }); |
| llvm::stable_sort(private_funcs, [](mlir::FuncOp a, mlir::FuncOp b) { |
| return a.getName() < b.getName(); |
| }); |
| |
| struct NamedAsset { |
| llvm::StringRef name; |
| AssetOp asset; |
| }; |
| llvm::SmallVector<NamedAsset, 4> assets; |
| for (auto asset : module.getOps<AssetOp>()) { |
| assets.push_back({asset.getName(), asset}); |
| } |
| llvm::stable_sort(assets, [](const NamedAsset& a, const NamedAsset& b) { |
| return a.name < b.name; |
| }); |
| |
| // Move onto the front of the module in reverse of the final desired order. |
| for (auto func : llvm::reverse(private_funcs)) { |
| func.getOperation()->moveBefore(&module.getBody()->front()); |
| } |
| for (auto named_func : llvm::reverse(named_funcs)) { |
| named_func.func.getOperation()->moveBefore(&module.getBody()->front()); |
| } |
| for (auto named_global_tensor : llvm::reverse(named_global_tensors)) { |
| named_global_tensor.global_tensor.getOperation()->moveBefore( |
| &module.getBody()->front()); |
| } |
| |
| for (auto asset : assets) { |
| asset.asset.getOperation()->moveBefore(&module.getBody()->front()); |
| } |
| |
| auto initializers = module.getOps<SessionInitializerOp>(); |
| if (!initializers.empty()) { |
| (*initializers.begin()) |
| .getOperation() |
| ->moveBefore(&module.getBody()->front()); |
| } |
| } |
| |
| Status CreateSavedModelIR( |
| const ObjectNames& object_names, mlir::ModuleOp module, |
| const SavedObjectGraph& object_graph, |
| const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name, |
| SavedModelV2Bundle* saved_model) { |
| mlir::OpBuilder builder(module.getBodyRegion()); |
| mlir::SymbolTable symbol_table(module); |
| |
| // Create a side data-structure, indexed by the object_graph node_id to |
| // a TrackableObject that is restorable. |
| absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*> |
| restored_objects; |
| TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore( |
| [&](int saved_node_id, |
| const TrackableObjectGraph::TrackableObject& trackable_object) { |
| restored_objects.insert( |
| std::make_pair(saved_node_id, &trackable_object)); |
| return Status::OK(); |
| })); |
| |
| for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) { |
| const SavedObject& object = object_graph.nodes(node_id); |
| // For correctness, we cannot import functions that don't have exported |
| // names, since they don't necessarily have a well-defined ABI (diagnosed |
| // earlier). |
| // |
| // For variables/constants, pruning them is purely an optimization, |
| // and more complicated since it requires use-def analysis of which |
| // functions use which variables/constants, so we don't do anything |
| // special for them here as part of our initial IR construction. |
| if (object.kind_case() == SavedObject::kFunction) { |
| if (object_names.GetExportedNames(node_id).empty()) { |
| continue; |
| } |
| std::string error_context = |
| "While importing SavedModel function '" + |
| object_names.GetExportedNames(node_id)[0].str() + "': "; |
| const SavedFunction& function = object.function(); |
| auto orig_func = symbol_table.lookup<mlir::FuncOp>( |
| tf_name_to_mlir_name.find(function.concrete_functions(0))->second); |
| mlir::FuncOp func = orig_func; |
| // If there are potentially references to this func from within the |
| // module, create a wrapper around it and decorate the wrapper with the |
| // tf_saved_model attributes instead. |
| if (!mlir::SymbolTable::symbolKnownUseEmpty(orig_func.getName(), |
| &module.getBodyRegion())) { |
| func = orig_func.cloneWithoutRegions(); |
| module.insert(module.getBody()->begin(), func); |
| func.addEntryBlock(); |
| func.setName("__sm_exported_" + orig_func.getName().str()); |
| llvm::SmallVector<mlir::Value, 4> args_as_values; |
| for (auto block_argument : func.getArguments()) { |
| args_as_values.push_back(block_argument); |
| } |
| mlir::OpBuilder body_builder(&func.getBody()); |
| auto call = body_builder.create<mlir::TF::StatefulPartitionedCallOp>( |
| func.getLoc(), orig_func.getType().getResults(), args_as_values, |
| builder.getSymbolRefAttr(orig_func.getName()), |
| /*config=*/builder.getStringAttr(""), |
| /*config_proto=*/builder.getStringAttr(""), |
| /*executor_type=*/builder.getStringAttr("")); |
| body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults()); |
| } |
| func->setAttr( |
| "tf_saved_model.exported_names", |
| builder.getStrArrayAttr(object_names.GetExportedNames(node_id))); |
| const SavedConcreteFunction& concrete_function = |
| object_graph.concrete_functions().at(function.concrete_functions(0)); |
| |
| // We do not handle the other element of this tuple, which corresponds to |
| // Python kwonlyargs, since currently TensorFlow prohibits this in |
| // combination with input_signature: |
| // https://github.com/tensorflow/tensorflow/blob/8cb8627abb5ef83a6fba34f8fd0e4ee430562eb1/tensorflow/python/eager/function.py#L2027-L2030 |
| // Our SavedModel import requires input_signature on the tf.function, so |
| // we never need to handle the kwonlyargs. |
| auto positional_arg_structure = |
| concrete_function.canonicalized_input_signature() |
| .tuple_value() |
| .values(0); |
| StructuredValueLinearizer input_linearizer(positional_arg_structure, |
| builder.getContext()); |
| |
| int bound_input_base = |
| func.getNumArguments() - concrete_function.bound_inputs_size(); |
| TF_ASSIGN_OR_RETURN(auto input_index_paths, |
| input_linearizer.GetLeafIndexPaths( |
| error_context + "in input signature: ")); |
| const int input_index_paths_size = input_index_paths.size(); |
| if (bound_input_base != input_index_paths_size) { |
| return errors::InvalidArgument( |
| error_context, |
| "Argument mismatch between concrete function input signature " |
| "vs underlying FunctionDef for concrete function '", |
| function.concrete_functions(0), "' (", input_index_paths.size(), |
| " vs ", bound_input_base, ")"); |
| } |
| for (auto index_path : llvm::enumerate(input_index_paths)) { |
| func.setArgAttr(index_path.index(), "tf_saved_model.index_path", |
| index_path.value()); |
| } |
| |
| for (auto& bound_input : |
| llvm::enumerate(concrete_function.bound_inputs())) { |
| int arg_index = bound_input_base + bound_input.index(); |
| auto symbol_ref = builder.getSymbolRefAttr( |
| object_names.GetSymbolTableName(bound_input.value())); |
| func.setArgAttr(arg_index, "tf_saved_model.bound_input", symbol_ref); |
| } |
| |
| StructuredValueLinearizer output_linearizer( |
| concrete_function.output_signature(), builder.getContext()); |
| TF_ASSIGN_OR_RETURN(auto output_index_paths, |
| output_linearizer.GetLeafIndexPaths( |
| error_context + "in output signature: ")); |
| if (func.getNumResults() != output_index_paths.size()) { |
| return errors::InvalidArgument( |
| error_context, |
| "Result mismatch between concrete function output signature " |
| "vs underlying FunctionDef for concrete function '", |
| function.concrete_functions(0), "' (", output_index_paths.size(), |
| " vs ", func.getNumResults(), ")"); |
| } |
| for (auto index_path : llvm::enumerate(output_index_paths)) { |
| func.setResultAttr(index_path.index(), "tf_saved_model.index_path", |
| index_path.value()); |
| } |
| } else if (object.kind_case() == SavedObject::kVariable) { |
| const SavedVariable& variable = object.variable(); |
| // Find the trackable in the side data structure. |
| auto variable_trackable_it = restored_objects.find(node_id); |
| if (variable_trackable_it == restored_objects.end()) { |
| return errors::FailedPrecondition("Could not restore saved variable: ", |
| variable.name()); |
| } |
| const auto* serialized_tensor_attr = FindSerializedTensorInTrackable( |
| *variable_trackable_it->second, "VARIABLE_VALUE"); |
| if (!serialized_tensor_attr) { |
| return errors::FailedPrecondition( |
| "Could not find serialized tensor for saved variable: ", |
| variable.name()); |
| } |
| const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key(); |
| |
| // Load it from the reader. |
| Tensor value; |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| saved_model->variable_reader()->Lookup(checkpoint_key, &value), |
| "Could not read checkpoint key from variables bundle: ", |
| checkpoint_key); |
| TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder)); |
| // A variable can have a partially known type, such as tensor<?x27x?xf32>, |
| // even if the initializer is a specific static shape. |
| TF_ASSIGN_OR_RETURN( |
| auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(), |
| &builder)); |
| auto op = builder.create<GlobalTensorOp>( |
| builder.getUnknownLoc(), |
| builder.getStringAttr(object_names.GetSymbolTableName(node_id)), |
| value_attr, |
| /*type=*/mlir::TypeAttr::get(type), |
| /*is_mutable=*/builder.getUnitAttr()); |
| op->setAttr( |
| "tf_saved_model.exported_names", |
| builder.getStrArrayAttr(object_names.GetExportedNames(node_id))); |
| } else if (object.kind_case() == SavedObject::kConstant) { |
| const SavedConstant& constant = object.constant(); |
| const TensorProto* value = ExtractConstTensorFromGraph( |
| saved_model->meta_graph_def().graph_def(), constant.operation()); |
| if (!value) { |
| return errors::FailedPrecondition( |
| "Unable to find const node referenced in object graph: ", |
| constant.operation()); |
| } |
| TF_ASSIGN_OR_RETURN(auto value_attr, |
| ConvertTensorProto(*value, &builder)); |
| auto op = builder.create<GlobalTensorOp>( |
| builder.getUnknownLoc(), |
| builder.getStringAttr(object_names.GetSymbolTableName(node_id)), |
| value_attr, |
| /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()), |
| /*is_mutable=*/nullptr); |
| op->setAttr( |
| "tf_saved_model.exported_names", |
| builder.getStrArrayAttr(object_names.GetExportedNames(node_id))); |
| } |
| } |
| AdjustBoundInputArgTypes(module); |
| module->setAttr("tf_saved_model.semantics", builder.getUnitAttr()); |
| SortSavedModelModule(module); |
| MarkSavedModelFunctionVisibility(module); |
| return Status::OK(); |
| } |
| |
| StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert( |
| SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names, |
| mlir::MLIRContext* context, bool add_default_attributes) { |
| LoadImporterDialects(*context); |
| GraphDebugInfo dummy_debug_info; |
| const GraphDebugInfo& debug_info = |
| saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info; |
| |
| GraphImportConfig specs; |
| specs.prune_unused_nodes = true; |
| mlir::OwningModuleRef module = |
| mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); |
| std::unordered_map<std::string, std::string> tf_name_to_mlir_name; |
| |
| const auto& graphdef = saved_model->meta_graph_def().graph_def(); |
| PopulateTfVersions(module.get(), graphdef.versions()); |
| |
| GraphConstructorOptions options; |
| options.allow_internal_ops = true; |
| options.add_default_attributes = add_default_attributes; |
| Graph graph(OpRegistry::Global()); |
| |
| GraphDef preprocessed_graphdef(graphdef); |
| if (add_default_attributes) { |
| TF_RETURN_IF_ERROR(PreprocessGraphDef(nullptr, &preprocessed_graphdef)); |
| } |
| |
| TF_RETURN_IF_ERROR( |
| ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph)); |
| |
| NameUniquifier function_name_uniquifier(graph.flib_def()); |
| SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs, |
| module.get(), &tf_name_to_mlir_name, |
| &function_name_uniquifier); |
| |
| TF_RETURN_IF_ERROR(importer.PrepareConvert(graph)); |
| |
| auto fn_names = graph.flib_def().ListFunctionNames(); |
| for (const auto& fn_name : fn_names) { |
| TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name)); |
| } |
| |
| if (!saved_model->meta_graph_def().has_object_graph_def()) { |
| return errors::InvalidArgument( |
| "SavedModel does not have an object graph. Please use TF2."); |
| } |
| auto& object_graph = saved_model->meta_graph_def().object_graph_def(); |
| ObjectNames object_names(object_graph, exported_names); |
| |
| // Clean up a couple func's that always seem to be present when importing a |
| // SavedModel. This is not strictly needed, as there is a separate pass that |
| // will clean them up, but this makes staring at the raw IR of minimal |
| // examples quite a bit nicer. |
| for (auto func : llvm::make_early_inc_range(module->getOps<mlir::FuncOp>())) { |
| if (func.getName().startswith("__inference__traced_save_") || |
| func.getName().startswith("__inference__traced_restore_") || |
| func.getName().startswith("__inference_signature_wrapper_")) { |
| func.erase(); |
| } |
| } |
| |
| // Diagnose SavedFunction's with multiple input signatures. |
| TF_RETURN_IF_ERROR( |
| DiagnoseMultipleConcreteFunctions(object_graph, object_names)); |
| |
| // Construct the SavedModel IR. |
| TF_RETURN_IF_ERROR(CreateSavedModelIR(object_names, module.get(), |
| object_graph, tf_name_to_mlir_name, |
| saved_model)); |
| assert(mlir::succeeded(mlir::verify(module.get()))); |
| |
| return module; |
| } |
| |
| class SimpleSavedModelMLIRImportInput : public SavedModelMLIRImportInput { |
| public: |
| static StatusOr<SimpleSavedModelMLIRImportInput> Create( |
| const MLIRImportOptions& import_options, |
| const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info) { |
| DCHECK(meta_graph_def); |
| GraphDef graph_def; |
| if (import_options.enable_grappler) { |
| // Grappler is best-effort. |
| auto statusor = RunGrappler(*meta_graph_def); |
| if (statusor.ok()) { |
| graph_def = std::move(statusor).ValueOrDie(); |
| } else { |
| // If the grappler fails, use the original graph def. |
| LOG(WARNING) << "SimpleSavedModelMLIRImportInput: grappler failed: " |
| << statusor.status(); |
| graph_def = meta_graph_def->graph_def(); |
| } |
| } else { |
| graph_def = meta_graph_def->graph_def(); |
| } |
| |
| auto graph = std::make_unique<Graph>(OpRegistry::Global()); |
| |
| if (import_options.upgrade_legacy) { |
| TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty( |
| graph_def, graph->flib_def().default_registry())); |
| } |
| |
| GraphConstructorOptions graph_ctor_options; |
| graph_ctor_options.allow_internal_ops = true; |
| graph_ctor_options.add_default_attributes = true; |
| TF_RETURN_IF_ERROR( |
| ConvertGraphDefToGraph(graph_ctor_options, graph_def, graph.get())); |
| |
| if (import_options.upgrade_legacy) { |
| // TODO(jpienaar): Remove need to const_cast. |
| TF_RETURN_IF_ERROR(UpgradeLegacyGraph( |
| graph.get(), |
| const_cast<FunctionLibraryDefinition*>(&graph->flib_def()), |
| /*restrict_functionalization_to_tpu_nodes=*/false)); |
| } |
| |
| return SimpleSavedModelMLIRImportInput(meta_graph_def, debug_info, |
| std::move(graph)); |
| } |
| |
| SimpleSavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def, |
| const GraphDebugInfo& debug_info, |
| std::unique_ptr<Graph> graph) |
| : SavedModelMLIRImportInput(meta_graph_def, debug_info), |
| graph_(std::move(graph)) {} |
| |
| StatusOr<const Graph*> GetSubGraph(absl::string_view name, |
| const GraphImportConfig& specs) override { |
| DCHECK(CheckGraphNameValidity(name)); |
| DCHECK(CheckGraphContainsFeedsAndFetches(specs)); |
| return graph_.get(); |
| } |
| |
| private: |
| bool CheckGraphContainsFeedsAndFetches(const GraphImportConfig& specs) const { |
| absl::flat_hash_set<std::string> feed_fetch_nodes; |
| for (const auto& iter : specs.inputs) { |
| TensorId tensor_id = ParseTensorName(iter.first); |
| feed_fetch_nodes.insert(std::string(tensor_id.node())); |
| } |
| for (const auto& output : llvm::concat<const std::string>( |
| specs.outputs, specs.control_outputs)) { |
| TensorId tensor_id = ParseTensorName(output); |
| feed_fetch_nodes.insert(std::string(tensor_id.node())); |
| } |
| |
| for (Node* node : graph_->op_nodes()) { |
| feed_fetch_nodes.erase(node->name()); |
| } |
| |
| return feed_fetch_nodes.empty(); |
| } |
| |
| bool CheckGraphNameValidity(absl::string_view name) const { |
| // If it is one of the signature name, it is valid. |
| const auto& signature_defs = meta_graph_def().signature_def(); |
| if (signature_defs.contains(std::string(name))) return true; |
| |
| // If it is the restore graph name, it is valid. |
| if (meta_graph_def().has_saver_def() && |
| meta_graph_def().saver_def().restore_op_name() == name) |
| return true; |
| |
| // If it is the init graph name, it is valid. |
| std::string init_op_name; |
| if (internal::GetInitOp("", meta_graph_def(), &init_op_name).ok()) { |
| if (init_op_name == name) return true; |
| } |
| |
| return false; |
| } |
| |
| // `graph_` contains the entire graph in the original MetaGraphDef. |
| std::unique_ptr<Graph> graph_; |
| }; |
| |
| // A helper class to import a TensorFlow model expressed in SavedModel V1 into |
| // an MLIR Module in SavedModel dialect. |
| // |
| // TODO(b/179683149): Rename this class to avoid confusion with TFLite. |
| class SavedModelSignatureDefImporterLite { |
| public: |
| // Main entry point: converts all functions (specified by SignatureDefs) in |
| // the given meta graph to an MLIR Module. |
| // |
| // `import_restore` is introduced to control whether restore graph |
| // is imported in eg. SavedModelSignatureDefImporter. Ideally, we don't need |
| // this option to control this as restore graph should be always imported. |
| // However, right now, SavedModelSignatureDefImporter cannot handle restore |
| // graph correctly. |
| // |
| // TODO(chky): Remove import_restore once the restore graph is correctly |
| // handled in SavedModelSignatureDefImporter. |
| static StatusOr<mlir::OwningModuleRef> Convert( |
| SavedModelMLIRImportInput& input, absl::Span<std::string> exported_names, |
| mlir::MLIRContext* context, bool import_restore = true) { |
| LoadImporterDialects(*context); |
| SavedModelSignatureDefImporterLite importer(input, exported_names, context, |
| import_restore); |
| TF_ASSIGN_OR_RETURN(auto module, importer.ConvertSignatures()); |
| |
| SortSavedModelModule(*module); |
| MarkSavedModelFunctionVisibility(*module); |
| |
| return module; |
| } |
| |
| private: |
| SavedModelSignatureDefImporterLite(SavedModelMLIRImportInput& input, |
| absl::Span<std::string> exported_names, |
| mlir::MLIRContext* context, |
| bool import_restore) |
| : input_(input), |
| exported_names_(exported_names), |
| module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))), |
| symbol_table_(module_.get()), |
| import_restore_(import_restore) {} |
| |
| // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function |
| // for each signature. |
| StatusOr<mlir::OwningModuleRef> ConvertSignatures(); |
| Status ConvertSignature(const std::string& sig_def_key, |
| const SignatureDef& signature_def); |
| |
| struct AssetInfo { |
| std::string tensor_name; |
| mlir::tf_saved_model::AssetOp op; |
| }; |
| StatusOr<std::vector<AssetInfo>> ConvertAssets(); |
| // Converts the initialization graph in the SavedModel to an MLIR function. |
| Status ConvertInitializer(const std::string& target_node_name, |
| const std::vector<AssetInfo>& assets); |
| |
| // Converts a graph with feeds and fetches to an MLIR function. |
| StatusOr<mlir::OwningModuleRef> ConvertGraph( |
| const std::string& name, |
| const std::vector<std::pair<std::string, TensorInfo>>& inputs, |
| const std::vector<std::pair<std::string, TensorInfo>>& outputs, |
| const std::vector<std::string> control_outputs); |
| |
| // Moves the functions in `sub_module` to `module_` and skips the duplicate |
| // functions. |
| void MoveConvertedFunctionsToModule(mlir::ModuleOp sub_module); |
| |
| GraphImportConfig::InputArrays ParseInputArrays( |
| llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs); |
| |
| private: |
| SavedModelMLIRImportInput& input_; |
| absl::Span<std::string> exported_names_; |
| mlir::OwningModuleRef module_; |
| mlir::SymbolTable symbol_table_; |
| bool import_restore_ = true; |
| }; |
| |
| StatusOr<std::vector<SavedModelSignatureDefImporterLite::AssetInfo>> |
| SavedModelSignatureDefImporterLite::ConvertAssets() { |
| std::vector<AssetFileDef> asset_file_defs; |
| TF_RETURN_IF_ERROR( |
| internal::GetAssetFileDefs(input_.meta_graph_def(), &asset_file_defs)); |
| |
| std::vector<AssetInfo> results; |
| results.reserve(asset_file_defs.size()); |
| |
| mlir::OpBuilder builder(module_->getBodyRegion()); |
| unsigned i = 0; // Use to generate unique sym_name(s) for duplicate assets. |
| for (const auto& asset : asset_file_defs) { |
| auto asset_op = builder.create<mlir::tf_saved_model::AssetOp>( |
| module_->getLoc(), |
| /*sym_name=*/ |
| builder.getStringAttr( |
| absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())), |
| /*filename=*/ |
| builder.getStringAttr( |
| io::JoinPath(kSavedModelAssetsDirectory, asset.filename()))); |
| |
| results.push_back({asset.tensor_info().name(), asset_op}); |
| } |
| |
| return results; |
| } |
| |
| void SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule( |
| mlir::ModuleOp sub_module) { |
| // Iterate through all functions and insert the ones that do not already exist |
| // in `module_`. |
| for (auto func : sub_module.getOps<mlir::FuncOp>()) { |
| if (symbol_table_.lookup(func.getName())) continue; |
| symbol_table_.insert(func.clone()); |
| } |
| } |
| |
| Status SavedModelSignatureDefImporterLite::ConvertInitializer( |
| const std::string& target_node_name, const std::vector<AssetInfo>& assets) { |
| std::vector<std::pair<std::string, TensorInfo>> inputs; |
| inputs.reserve(assets.size()); |
| for (const auto& asset : assets) { |
| TensorInfo tensor_info; |
| tensor_info.set_name(asset.tensor_name); |
| tensor_info.set_dtype(DT_STRING); |
| tensor_info.mutable_tensor_shape(); |
| inputs.push_back({asset.tensor_name, tensor_info}); |
| } |
| |
| TF_ASSIGN_OR_RETURN(auto sub_module, ConvertGraph(target_node_name, inputs, |
| {}, {target_node_name})); |
| |
| mlir::SymbolTable sub_symbol_table(*sub_module); |
| |
| auto init_func_op = sub_symbol_table.lookup<mlir::FuncOp>(target_node_name); |
| init_func_op.removeAttr("tf.entry_function"); |
| |
| mlir::OpBuilder builder(module_->getBodyRegion()); |
| |
| // Bind asset inputs to asset ops. |
| DCHECK_EQ(init_func_op.getNumArguments(), assets.size()); |
| for (const auto& iter : llvm::enumerate(assets)) { |
| auto asset_op = iter.value().op; |
| init_func_op.setArgAttr(iter.index(), "tf_saved_model.bound_input", |
| builder.getSymbolRefAttr(asset_op.getName())); |
| } |
| |
| // Set the exported name of init function to an reserved name for |
| // tf_saved_model. |
| init_func_op->setAttr( |
| "tf_saved_model.exported_names", |
| builder.getStrArrayAttr({absl::StrCat( |
| "__tf_saved_model_session_initializer_", target_node_name)})); |
| |
| // Move the converted functions to top level MLIR module. |
| MoveConvertedFunctionsToModule(*sub_module); |
| |
| return Status::OK(); |
| } |
| |
| StatusOr<mlir::OwningModuleRef> |
| SavedModelSignatureDefImporterLite::ConvertGraph( |
| const std::string& name, |
| const std::vector<std::pair<std::string, TensorInfo>>& inputs, |
| const std::vector<std::pair<std::string, TensorInfo>>& outputs, |
| const std::vector<std::string> control_outputs) { |
| VLOG(1) << "Importing Signature: " << name; |
| |
| GraphImportConfig specs; |
| specs.prune_unused_nodes = true; |
| specs.inputs = ParseInputArrays(inputs); |
| for (auto& output : outputs) specs.outputs.push_back(output.second.name()); |
| specs.control_outputs = control_outputs; |
| |
| TF_ASSIGN_OR_RETURN(const auto* subgraph, input_.GetSubGraph(name, specs)); |
| |
| // Convert sub-graph to MLIR module. |
| return GraphDefImporter::Convert(module_->getContext(), *subgraph, |
| input_.debug_info(), subgraph->flib_def(), |
| specs, name); |
| } |
| |
| Status SavedModelSignatureDefImporterLite::ConvertSignature( |
| const std::string& sig_def_key, const SignatureDef& signature_def) { |
| // Create local vectors for the input and output and sort them to be |
| // deterministic. We don't want anyone to really depend on the order, client |
| // should lookup argument/result mapping by attribute name. |
| // To avoid accidentally depending on the order we use an unintuitive sorting. |
| std::vector<std::pair<std::string, TensorInfo>> inputs( |
| signature_def.inputs().begin(), signature_def.inputs().end()); |
| llvm::sort(inputs, [](const auto& lhs, const auto& rhs) { |
| return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first; |
| }); |
| std::vector<std::pair<std::string, TensorInfo>> outputs( |
| signature_def.outputs().begin(), signature_def.outputs().end()); |
| llvm::sort(outputs, [](const auto& lhs, const auto& rhs) { |
| return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first; |
| }); |
| |
| // Convert sub-graph to MLIR module. |
| TF_ASSIGN_OR_RETURN(auto sub_module, |
| ConvertGraph(sig_def_key, inputs, outputs, {})); |
| mlir::OpBuilder builder(sub_module->getBodyRegion()); |
| |
| // Find the FuncOp which corresponds to current SignatureDef. |
| mlir::SymbolTable sub_symbol_table(*sub_module); |
| auto func_op = sub_symbol_table.lookup<mlir::FuncOp>(sig_def_key); |
| TF_RET_CHECK(func_op) |
| << "Graphdef importer should have created a function named " |
| << sig_def_key << "."; |
| |
| // Use unique SignatureDef key as exported name. |
| func_op->setAttr("tf_saved_model.exported_names", |
| builder.getStrArrayAttr({sig_def_key})); |
| |
| // Transfer input and output parameter names to index_path attributes. |
| for (auto input_and_idx : llvm::enumerate(inputs)) { |
| func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path", |
| builder.getStrArrayAttr({input_and_idx.value().first})); |
| } |
| for (auto output_and_idx : llvm::enumerate(outputs)) { |
| func_op.setResultAttr( |
| output_and_idx.index(), "tf_saved_model.index_path", |
| builder.getStrArrayAttr({output_and_idx.value().first})); |
| } |
| |
| // Move the converted functions to top level MLIR module. |
| MoveConvertedFunctionsToModule(*sub_module); |
| |
| return Status::OK(); |
| } |
| |
| GraphImportConfig::InputArrays |
| SavedModelSignatureDefImporterLite::ParseInputArrays( |
| llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs) { |
| GraphImportConfig::InputArrays results; |
| for (const auto& iter : inputs) { |
| const auto& tensor_info = iter.second; |
| |
| // Only dense tensor is supported. |
| DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName); |
| |
| VLOG(1) << "Importing Signature Input: input_name = " << iter.first |
| << ", tensor_info = " << tensor_info.DebugString(); |
| |
| ArrayInfo array_info; |
| array_info.imported_dtype = tensor_info.dtype(); |
| |
| if (tensor_info.has_tensor_shape()) { |
| array_info.shape = tensor_info.tensor_shape(); |
| } else { |
| // If there is no tensor shape in the tensor info, conservatively set |
| // unknown_rank to true. |
| array_info.shape.set_unknown_rank(true); |
| } |
| |
| results.insert(std::pair<std::string, ArrayInfo>(tensor_info.name(), |
| std::move(array_info))); |
| } |
| return results; |
| } |
| |
| StatusOr<mlir::OwningModuleRef> |
| SavedModelSignatureDefImporterLite::ConvertSignatures() { |
| const auto& signatures = input_.meta_graph_def().signature_def(); |
| PopulateTfVersions(module_.get(), |
| input_.meta_graph_def().graph_def().versions()); |
| |
| llvm::DenseSet<llvm::StringRef> exported_name_set; |
| exported_name_set.insert(exported_names_.begin(), exported_names_.end()); |
| |
| for (const auto& key_and_signature_def : signatures) { |
| const std::string& sig_def_key = key_and_signature_def.first; |
| const SignatureDef& signature_def = key_and_signature_def.second; |
| |
| // It is safe to skip "__saved_model_init_op" since it is an internal |
| // signature that is not user-accessible. This signature will be handled in |
| // ConvertInitializer(). |
| if (sig_def_key == "__saved_model_init_op") { |
| continue; |
| } |
| if (!exported_name_set.empty() && |
| exported_name_set.count(sig_def_key) == 0) { |
| continue; |
| } |
| |
| TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def)); |
| } |
| |
| TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets()); |
| |
| mlir::OpBuilder builder(module_->getBodyRegion()); |
| llvm::SmallVector<mlir::Attribute, 2> init_sym_refs; |
| |
| if (import_restore_ && input_.meta_graph_def().has_saver_def()) { |
| std::vector<AssetInfo> variable_and_assets; |
| |
| // Create an AssetOp for the variable checkpoint files. The relative |
| // filename is used here. |
| auto variable_filename_op = builder.create<mlir::tf_saved_model::AssetOp>( |
| module_->getLoc(), |
| /*sym_name=*/ |
| builder.getStringAttr("__tf_saved_model_variables"), |
| /*filename=*/ |
| builder.getStringAttr(io::JoinPath(kSavedModelVariablesDirectory, |
| kSavedModelVariablesFilename))); |
| variable_and_assets.push_back( |
| {input_.meta_graph_def().saver_def().filename_tensor_name(), |
| variable_filename_op}); |
| variable_and_assets.insert(variable_and_assets.end(), assets.begin(), |
| assets.end()); |
| |
| const auto& restore_op_name = |
| input_.meta_graph_def().saver_def().restore_op_name(); |
| TF_RETURN_IF_ERROR( |
| ConvertInitializer(restore_op_name, variable_and_assets)); |
| init_sym_refs.push_back(builder.getSymbolRefAttr(restore_op_name)); |
| } |
| |
| std::string init_op_name; |
| TF_RETURN_IF_ERROR( |
| internal::GetInitOp("", input_.meta_graph_def(), &init_op_name)); |
| if (!init_op_name.empty()) { |
| TF_RETURN_IF_ERROR(ConvertInitializer(init_op_name, assets)); |
| init_sym_refs.push_back(builder.getSymbolRefAttr(init_op_name)); |
| } |
| |
| builder.create<mlir::tf_saved_model::SessionInitializerOp>( |
| module_->getLoc(), builder.getArrayAttr(init_sym_refs)); |
| |
| (*module_)->setAttr("tf_saved_model.semantics", builder.getUnitAttr()); |
| |
| SortSavedModelModule(*module_); |
| MarkSavedModelFunctionVisibility(*module_); |
| |
| return std::move(module_); |
| } |
| |
| // A helper class to import a TensorFlow model expressed in SavedModel V1 into |
| // an MLIR Module in SavedModel dialect. In addition to importing the model, it |
| // performs a few graph transformations, including: |
| // 1) Convert read-only ref variables to resource variables |
| // 2) Lift resource variables to global_tensors by using a TF session. |
| class SavedModelSignatureDefImporter { |
| public: |
| // Main entry point: converts all functions (specified by SignatureDefs) in |
| // the given meta graph to an MLIR Module. |
| static StatusOr<mlir::OwningModuleRef> Convert( |
| const SavedModelBundle& bundle, absl::Span<std::string> exported_names, |
| mlir::MLIRContext* context, tensorflow::MLIRImportOptions options) { |
| // debug_info might not be loaded with loader_lite. |
| GraphDebugInfo debug_info; |
| if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info; |
| |
| TF_ASSIGN_OR_RETURN(auto input, |
| SimpleSavedModelMLIRImportInput::Create( |
| options, &bundle.meta_graph_def, debug_info)); |
| |
| TF_ASSIGN_OR_RETURN(auto module, |
| SavedModelSignatureDefImporterLite::Convert( |
| input, exported_names, context, |
| /*import_restore=*/false)); |
| |
| mlir::OpBuilder builder(module->getContext()); |
| (*module)->setAttr("tf_saved_model.under_construction", |
| builder.getUnitAttr()); |
| TF_RETURN_IF_ERROR(LiftVariables(bundle, *module)); |
| module->removeAttr("tf_saved_model.under_construction"); |
| |
| return module; |
| } |
| |
| private: |
| // Lifts the variables in `module`. |
| static Status LiftVariables(const SavedModelBundle& bundle, |
| mlir::ModuleOp module); |
| }; |
| |
| Status SavedModelSignatureDefImporter::LiftVariables( |
| const SavedModelBundle& bundle, mlir::ModuleOp module) { |
| mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); |
| |
| mlir::PassManager pm(module.getContext()); |
| SetCrashReproducer(pm); |
| pm.addNestedPass<mlir::FuncOp>( |
| mlir::tf_executor::CreateTFExecutorGraphPruningPass()); |
| pm.addNestedPass<mlir::FuncOp>( |
| mlir::CreateExecutorDialectToFunctionalConversionPass()); |
| pm.addPass( |
| mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass()); |
| pm.addNestedPass<mlir::FuncOp>( |
| mlir::TF:: |
| CreateConvertReadonlyReferenceVariablesToResourceVariablesPass()); |
| pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass()); |
| pm.addPass( |
| mlir::tf_saved_model::CreateLiftVariablesPass(bundle.GetSession())); |
| pm.addNestedPass<mlir::FuncOp>( |
| mlir::tf_saved_model::CreateDedupBoundInputBindingPass()); |
| if (mlir::failed(pm.run(module))) |
| return diag_handler.Combine(errors::Internal("Failed to lift variables.")); |
| |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| SavedModelMLIRImportInput::~SavedModelMLIRImportInput() {} |
| |
| StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir( |
| const GraphDef& graphdef, const GraphDebugInfo& debug_info, |
| const GraphImportConfig& specs, mlir::MLIRContext* context, |
| bool add_default_attributes) { |
| GraphConstructorOptions options; |
| options.allow_internal_ops = true; |
| options.add_default_attributes = add_default_attributes; |
| Graph graph(OpRegistry::Global()); |
| |
| GraphDef preprocessed_graphdef(graphdef); |
| if (add_default_attributes) { |
| TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef)); |
| } |
| if (specs.upgrade_legacy) { |
| TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty( |
| preprocessed_graphdef, graph.flib_def().default_registry())); |
| } |
| TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( |
| options, std::move(preprocessed_graphdef), &graph)); |
| return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs, |
| context); |
| } |
| |
| StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir( |
| const Graph& graph, const GraphDebugInfo& debug_info, |
| const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, |
| mlir::MLIRContext* context) { |
| // TODO(jpienaar): Remove need to const_cast. |
| if (specs.upgrade_legacy) { |
| TF_RETURN_IF_ERROR( |
| UpgradeLegacyGraph(const_cast<Graph*>(&graph), |
| const_cast<FunctionLibraryDefinition*>(&flib_def), |
| specs.restrict_functionalization_to_tpu_nodes)); |
| } |
| return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs, |
| /*func_name=*/"main"); |
| } |
| |
| stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir( |
| const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def, |
| mlir::MLIRContext* context) { |
| tensorflow::GraphDebugInfo dummy_debug_info; |
| tensorflow::GraphImportConfig specs; |
| specs.enable_shape_inference = false; |
| specs.graph_as_function = true; |
| for (const auto* control_ret_node : fbody->control_ret_nodes) |
| specs.control_outputs.push_back(control_ret_node->name()); |
| return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info, |
| flib_def, specs, |
| fbody->fdef.signature().name()); |
| } |
| |
| StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir( |
| SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, |
| absl::Span<std::string> exported_names, bool add_default_attributes) { |
| return SavedModelObjectGraphImporter::Convert( |
| saved_model, exported_names, context, add_default_attributes); |
| } |
| |
| StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir( |
| const SavedModelBundle& saved_model, absl::Span<std::string> exported_names, |
| mlir::MLIRContext* context, MLIRImportOptions options) { |
| return SavedModelSignatureDefImporter::Convert(saved_model, exported_names, |
| context, options); |
| } |
| |
| StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite( |
| const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info, |
| absl::Span<std::string> exported_names, mlir::MLIRContext* context, |
| MLIRImportOptions options) { |
| TF_ASSIGN_OR_RETURN(auto input, SimpleSavedModelMLIRImportInput::Create( |
| options, &meta_graph_def, debug_info)); |
| return ConvertSavedModelV1ToMlirLite(input, exported_names, context); |
| } |
| |
| StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlirLite( |
| SavedModelMLIRImportInput& input, absl::Span<std::string> exported_names, |
| mlir::MLIRContext* context) { |
| return SavedModelSignatureDefImporterLite::Convert(input, exported_names, |
| context); |
| } |
| |
| std::string MlirModuleToString(mlir::ModuleOp module, |
| mlir::OpPrintingFlags flags) { |
| std::string txt_module; |
| { |
| llvm::raw_string_ostream os{txt_module}; |
| module.print(os, flags); |
| } |
| return txt_module; |
| } |
| |
| std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) { |
| mlir::OpPrintingFlags flags; |
| if (show_debug_info) flags.enableDebugInfo(); |
| return MlirModuleToString(module, flags); |
| } |
| |
| } // namespace tensorflow |