| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" |
| |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/DenseSet.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/Twine.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/OpImplementation.h" // from @llvm-project |
| #include "mlir/IR/PatternMatch.h" // from @llvm-project |
| #include "mlir/IR/SymbolTable.h" // from @llvm-project |
| #include "mlir/IR/TypeUtilities.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
| |
| namespace mlir { |
| namespace tf_saved_model { |
| |
| //===----------------------------------------------------------------------===// |
| // Utilities |
| //===----------------------------------------------------------------------===// |
| |
| static bool IsStrArrayAttr(Attribute attr) { |
| auto array = attr.dyn_cast<ArrayAttr>(); |
| if (!array) return false; |
| |
| return llvm::all_of(array, |
| [](Attribute attr) { return attr.isa<StringAttr>(); }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TensorFlowSavedModelDialect Op's |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) { |
| if (!t1.isa<TensorType>() || !t2.isa<TensorType>()) { |
| return failure(); |
| } |
| return verifyCompatibleShape(t1.cast<TensorType>(), t2.cast<TensorType>()); |
| } |
| |
| LogicalResult GlobalTensorOp::verify() { |
| GlobalTensorOp global_tensor = *this; |
| if (failed(VerifyTensorTypesCompatible(global_tensor.type(), |
| global_tensor.value().getType()))) { |
| return global_tensor.emitError() << "'type' and 'value' attributes should " |
| "have compatible tensor types"; |
| } |
| if (!global_tensor.is_mutable()) { |
| if (!global_tensor.type().cast<TensorType>().hasStaticShape()) { |
| return global_tensor.emitError() |
| << "'type' attribute for immutable 'tf_saved_model.global_tensor' " |
| "should have a static shape"; |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult SessionInitializerOp::verify() { |
| SessionInitializerOp session_initializer = *this; |
| mlir::SymbolTable symbol_table( |
| session_initializer->getParentOfType<ModuleOp>()); |
| |
| for (auto sym_ref : session_initializer.initializers()) { |
| auto init_func_op = symbol_table.lookup<mlir::func::FuncOp>( |
| sym_ref.cast<FlatSymbolRefAttr>().getValue()); |
| |
| if (!init_func_op) |
| return session_initializer.emitOpError() |
| << "the initializer function does not exist"; |
| |
| if (!init_func_op.getFunctionType().getResults().empty()) |
| return session_initializer.emitOpError() |
| << "the initializer function should have no output"; |
| |
| auto exported_names = GetExportedNames(init_func_op); |
| |
| if (exported_names.empty()) |
| return session_initializer.emitOpError() |
| << "the initializer function should be exported"; |
| |
| if (exported_names.size() != 1) |
| return session_initializer.emitOpError() |
| << "the initializer function should have only one exported names"; |
| } |
| |
| return success(); |
| } |
| |
| } // namespace tf_saved_model |
| } // namespace mlir |
| |
| #define GET_OP_CLASSES |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" |
| |
| namespace mlir { |
| namespace tf_saved_model { |
| |
| //===----------------------------------------------------------------------===// |
| // TensorFlowSavedModelDialect Dialect |
| //===----------------------------------------------------------------------===// |
| |
| TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context) |
| : Dialect(/*name=*/"tf_saved_model", context, |
| TypeID::get<TensorFlowSavedModelDialect>()) { |
| // The TensorFlow Dialect is needed in the verifier and other routines |
| // associated to this dialect. It makes little sense anyway to use the |
| // SavedModel dialect without the TensorFlow Dialect. |
| context->loadDialect<TF::TensorFlowDialect>(); |
| |
| addOperations< |
| #define GET_OP_LIST |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" |
| >(); |
| } |
| |
| static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) { |
| auto attr = named_attr.getValue().dyn_cast<ArrayAttr>(); |
| if (!attr) { |
| return op->emitError() |
| << "'tf_saved_model.index_path' attribute should be an ArrayAttr"; |
| } |
| for (auto element : attr) { |
| if (element.isa<StringAttr>()) { |
| continue; |
| } |
| if (auto integer = element.dyn_cast<IntegerAttr>()) { |
| if (integer.getValue().getBitWidth() == 64) { |
| continue; |
| } |
| } |
| return op->emitError() << "'tf_saved_model.index_path' elements should " |
| "be strings or 64-bit integers"; |
| } |
| return mlir::success(); |
| } |
| |
| Type GetBoundInputArgTypeFor(mlir::Operation *op) { |
| if (auto global_tensor = llvm::dyn_cast<GlobalTensorOp>(op)) { |
| auto type = global_tensor.type().cast<TensorType>(); |
| return RankedTensorType::get( |
| {}, TF::ResourceType::get({type}, type.getContext())); |
| } |
| |
| if (auto asset = llvm::dyn_cast<AssetOp>(op)) { |
| return RankedTensorType::get({}, TF::StringType::get(asset.getContext())); |
| } |
| |
| op->emitError() << "unknown symbol operation"; |
| return {}; |
| } |
| |
| static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics, |
| Type arg_type, |
| mlir::Operation *symbol_op) { |
| auto expected_type = GetBoundInputArgTypeFor(symbol_op); |
| if (!expected_type) return failure(); |
| |
| if (arg_type != expected_type) { |
| return op_for_diagnostics->emitError() |
| << "bound input with type " << arg_type << " expected to have type " |
| << expected_type; |
| } |
| return success(); |
| } |
| |
| LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute( |
| Operation *op, unsigned region_index, unsigned arg_index, |
| NamedAttribute named_attr) { |
| if (named_attr.getName() == "tf_saved_model.bound_input") { |
| if (!named_attr.getValue().isa<FlatSymbolRefAttr>()) { |
| return op->emitError() << "'tf_saved_model.bound_input' attribute should " |
| "be a FlatSymbolRefAttr"; |
| } |
| auto symbol_name = |
| named_attr.getValue().cast<FlatSymbolRefAttr>().getValue(); |
| auto module = op->getParentOfType<ModuleOp>(); |
| mlir::Operation *symbol_op = module.lookupSymbol(symbol_name); |
| if (!symbol_op) { |
| return op->emitError() << "'tf_saved_model.bound_input' attribute must " |
| "reference a valid symbol, got invalid symbol '" |
| << symbol_name << "'"; |
| } |
| auto arg_type = cast<func::FuncOp>(op).getArgument(arg_index).getType(); |
| return VerifyBoundInputArgType(op, arg_type, symbol_op); |
| } |
| if (named_attr.getName() == "tf_saved_model.index_path") { |
| return VerifyIndexPath(op, named_attr); |
| } |
| |
| return op->emitError() << "unknown tf_saved_model dialect arg attribute '" |
| << named_attr.getName().getValue() << "'"; |
| } |
| |
| LogicalResult TensorFlowSavedModelDialect::verifyRegionResultAttribute( |
| Operation *op, unsigned region_index, unsigned result_index, |
| NamedAttribute named_attr) { |
| if (named_attr.getName() == "tf_saved_model.index_path") { |
| return VerifyIndexPath(op, named_attr); |
| } |
| |
| return op->emitError() << "unknown tf_saved_model dialect result attribute '" |
| << named_attr.getName().getValue() << "'"; |
| } |
| |
| static bool HasAnyTfSavedModelArgAttr(func::FuncOp func) { |
| for (int i = 0, e = func.getNumArguments(); i < e; i++) { |
| if (func.getArgAttr(i, "tf_saved_model.index_path") || |
| func.getArgAttr(i, "tf_saved_model.bound_input")) { |
| return true; |
| } |
| } |
| for (int i = 0, e = func.getNumResults(); i < e; i++) { |
| if (func.getResultAttr(i, "tf_saved_model.index_path") || |
| func.getResultAttr(i, "tf_saved_model.bound_input")) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| static LogicalResult VerifySavedModelModule( |
| ModuleOp module, TensorFlowSavedModelDialect *dialect) { |
| auto exported_names_ident = |
| StringAttr::get(dialect->getContext(), "tf_saved_model.exported_names"); |
| // Check that there are no duplicated exported_names. |
| DenseMap<StringRef, Operation *> exported_name_to_op; |
| for (auto &op : module) { |
| auto attr = op.getAttr(exported_names_ident); |
| if (!attr) continue; |
| // If this verifier is called before we verify the |
| // 'tf_saved_model.exported_names' attribute, then it might be invalid. |
| // Forward to the dialect's verification to establish that precondition. |
| if (failed(dialect->verifyOperationAttribute( |
| &op, {exported_names_ident, attr}))) { |
| return failure(); |
| } |
| for (auto str : attr.cast<ArrayAttr>()) { |
| auto exported_name = str.cast<StringAttr>().getValue(); |
| auto p = exported_name_to_op.insert({exported_name, &op}); |
| if (!p.second) { |
| return op.emitError() |
| .append("duplicate exported name '", exported_name, "'") |
| .attachNote(p.first->getSecond()->getLoc()) |
| .append("previously seen here"); |
| } |
| } |
| } |
| for (auto func : module.getOps<func::FuncOp>()) { |
| const bool is_exported = IsExported(func); |
| |
| if (is_exported && !func.isPublic()) { |
| return func.emitError() |
| << "exported function @" << func.getName() << " should be public"; |
| } |
| |
| if (!is_exported && func.isPublic()) { |
| return func.emitError() << "non-exported function @" << func.getName() |
| << " should be private"; |
| } |
| if (!is_exported && HasAnyTfSavedModelArgAttr(func)) { |
| return func.emitError() << "can only apply 'tf_saved_model' argument " |
| "attributes to exported functions"; |
| } |
| } |
| |
| auto session_initializers = module.getOps<SessionInitializerOp>(); |
| if (!session_initializers.empty() && |
| !llvm::hasSingleElement(session_initializers)) { |
| return (*++session_initializers.begin()).emitError() |
| << "there must be no more than one session_initializer op"; |
| } |
| |
| auto is_init = [&session_initializers](mlir::func::FuncOp func) { |
| if (session_initializers.empty()) return false; |
| auto init_syms = (*session_initializers.begin()).initializers(); |
| return std::any_of( |
| init_syms.begin(), init_syms.end(), [&](Attribute sym_ref) { |
| return sym_ref.cast<FlatSymbolRefAttr>().getValue() == func.getName(); |
| }); |
| }; |
| |
| SymbolTable symbol_table(module); |
| auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion()); |
| if (!symbol_uses.has_value()) { |
| return module.emitError() << "modules with 'tf_saved_model.semantics' must " |
| "have analyzable symbol uses"; |
| } |
| for (auto symbol_use : *symbol_uses) { |
| auto func = symbol_table.lookupNearestSymbolFrom<func::FuncOp>( |
| symbol_use.getUser(), symbol_use.getSymbolRef()); |
| if (func && IsExported(func)) { |
| // If it is an init function, then it can be used by the unique |
| // session_initializer op. |
| if (is_init(func) && |
| llvm::isa<SessionInitializerOp>(symbol_use.getUser())) |
| continue; |
| |
| return symbol_use.getUser() |
| ->emitError("exported function cannot be internally referenced") |
| .attachNote(func.getLoc()) |
| .append("references this exported function"); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult VerifyExportedFunc(func::FuncOp func) { |
| bool reached_bound_inputs = false; |
| auto module = func->getParentOfType<ModuleOp>(); |
| for (int i = 0, e = func.getNumArguments(); i < e; i++) { |
| if (func.getArgAttr(i, "tf_saved_model.bound_input")) { |
| reached_bound_inputs = true; |
| continue; |
| } |
| if (func.getArgAttr(i, "tf_saved_model.index_path")) { |
| if (reached_bound_inputs) { |
| return func.emitError() |
| << "all 'tf_saved_model.index_path' arg attributes should " |
| "precede all 'tf_saved_model.bound_input' arg attributes"; |
| } |
| continue; |
| } |
| if (func.getArgAttr(i, "tf.resource_name")) { |
| if (module->getAttr("tf_saved_model.under_construction")) continue; |
| return func.emitError() << "'tf.resource_name' attribute is not allowed " |
| "unless it is being under construction"; |
| } |
| return func.emitError() |
| << "all arguments should have 'tf_saved_model.index_path', " |
| "'tf_saved_model.bound_input' or 'tf.resource_name' attributes"; |
| } |
| llvm::SmallDenseSet<StringRef, 8> unique_bound_inputs; |
| for (int i = 0, e = func.getNumArguments(); i < e; i++) { |
| if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>( |
| i, "tf_saved_model.bound_input")) { |
| if (!unique_bound_inputs.insert(attr.getValue()).second) { |
| if (module->getAttr("tf_saved_model.under_construction")) continue; |
| return func.emitError() |
| << "duplicate 'tf_saved_model.bound_input' binding"; |
| } |
| } |
| } |
| |
| for (int i = 0, e = func.getNumResults(); i < e; i++) { |
| if (!func.getResultAttr(i, "tf_saved_model.index_path")) { |
| return func.emitError() << "all results should have " |
| "'tf_saved_model.index_path' attributes"; |
| } |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute( |
| Operation *op, NamedAttribute named_attr) { |
| if (named_attr.getName() == "tf_saved_model.exported_names") { |
| if (!isa<func::FuncOp, GlobalTensorOp>(op)) { |
| return op->emitError() << "'tf_saved_model.exported_names' must be on a " |
| "'func' or 'tf_saved_model.global_tensor' op"; |
| } |
| if (!IsStrArrayAttr(named_attr.getValue())) { |
| return op->emitError() |
| << "'tf_saved_model.exported_names' must be an array of strings"; |
| } |
| if (!op->getParentOp()->getAttr("tf_saved_model.semantics")) { |
| return op->emitError() |
| << "'tf_saved_model.exported_names' must be on an op " |
| "whose immediate parent has attribute " |
| "'tf_saved_model.semantics'"; |
| } |
| if (auto func = dyn_cast<func::FuncOp>(op)) { |
| if (failed(VerifyExportedFunc(func))) { |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| if (named_attr.getName() == "tf_saved_model.semantics") { |
| auto module = dyn_cast<ModuleOp>(op); |
| if (!module) { |
| return op->emitError() << "'tf_saved_model.semantics' must " |
| "be on a module op"; |
| } |
| return VerifySavedModelModule(module, this); |
| } |
| if (named_attr.getName() == "tf_saved_model.under_construction") { |
| return success(); |
| } |
| |
| return op->emitError() << "unknown tf_saved_model dialect attribute '" |
| << named_attr.getName().getValue() << "'"; |
| } |
| |
| SmallVector<StringRef, 2> GetExportedNames(Operation *op) { |
| SmallVector<StringRef, 2> ret; |
| auto exported_names = |
| op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names"); |
| if (exported_names) { |
| for (auto name : exported_names) { |
| ret.push_back(name.cast<StringAttr>().getValue()); |
| } |
| } |
| return ret; |
| } |
| |
| bool IsExported(Operation *op) { |
| auto exported_names = |
| op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names"); |
| return exported_names && !exported_names.empty(); |
| } |
| |
| bool HasTfSavedModelSemantics(ModuleOp module) { |
| return module->getAttr("tf_saved_model.semantics") != nullptr; |
| } |
| |
| Operation *LookupBoundInput(func::FuncOp func, int arg_index, |
| const SymbolTable &symbol_table) { |
| auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>( |
| arg_index, "tf_saved_model.bound_input"); |
| if (!attr) return nullptr; |
| return symbol_table.lookup(attr.getValue()); |
| } |
| |
| SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) { |
| auto initializers = op.getOps<SessionInitializerOp>(); |
| if (initializers.empty()) return {}; |
| return *initializers.begin(); |
| } |
| |
| class OptimizeSessionInitializerPattern |
| : public OpRewritePattern<SessionInitializerOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(SessionInitializerOp op, |
| PatternRewriter &rewriter) const override { |
| SymbolTable symbol_table(op->getParentOfType<ModuleOp>()); |
| |
| SmallVector<func::FuncOp, 2> to_remove; |
| SmallVector<mlir::Attribute, 2> to_keep; |
| for (auto sym_ref : op.initializers()) { |
| auto init_func_op = symbol_table.lookup<mlir::func::FuncOp>( |
| sym_ref.cast<FlatSymbolRefAttr>().getValue()); |
| |
| // The init function can only be referenced from the SessionInitializerOp. |
| // And there is at most one SessionInitializerOp in the module. So if both |
| // ops have no other uses or have one NoOp only, they can be simply |
| // erased. |
| auto &operations = init_func_op.front().getOperations(); |
| if ((operations.size() == 1 && |
| operations.front().hasTrait<OpTrait::IsTerminator>()) || |
| (operations.size() == 2 && |
| dyn_cast<mlir::TF::NoOp>(operations.front()) && |
| operations.back().hasTrait<OpTrait::IsTerminator>())) { |
| to_remove.push_back(init_func_op); |
| } else { |
| to_keep.push_back(sym_ref); |
| } |
| } |
| |
| for (auto func_op : to_remove) rewriter.eraseOp(func_op); |
| |
| if (to_keep.empty()) |
| rewriter.eraseOp(op); |
| else |
| op->setAttr("initializers", rewriter.getArrayAttr(to_keep)); |
| |
| return success(); |
| } |
| }; |
| |
| void SessionInitializerOp::getCanonicalizationPatterns( |
| RewritePatternSet &results, MLIRContext *context) { |
| results.add<OptimizeSessionInitializerPattern>(context); |
| } |
| |
| SmallVector<StringRef, 2> GetSessionInitializerExportedName(ModuleOp op) { |
| auto session_initializer_op = GetSessionInitializerOp(op); |
| if (!session_initializer_op) return {}; |
| |
| SymbolTable symbol_table(op); |
| |
| SmallVector<StringRef, 2> results; |
| for (auto sym_ref : session_initializer_op.initializers()) { |
| auto init_func_op = symbol_table.lookup<mlir::func::FuncOp>( |
| sym_ref.cast<FlatSymbolRefAttr>().getValue()); |
| auto exported_names = GetExportedNames(init_func_op); |
| assert(exported_names.size() == 1); |
| results.push_back(exported_names[0]); |
| } |
| |
| return results; |
| } |
| |
| } // namespace tf_saved_model |
| } // namespace mlir |