| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>
#include <string>
#include <utility>
| |
| #include "absl/memory/memory.h" |
| #include "absl/strings/string_view.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/None.h" |
| #include "llvm/ADT/Optional.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project |
| #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/SymbolTable.h" // from @llvm-project |
| #include "mlir/IR/Value.h" // from @llvm-project |
| #include "mlir/IR/Visitors.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "mlir/Transforms/DialectConversion.h" // from @llvm-project |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project |
| #include "mlir/Transforms/InliningUtils.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" |
| #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h" |
| #include "tensorflow/compiler/mlir/tfr/passes/passes.h" |
| #include "tensorflow/compiler/mlir/tfr/utils/utils.h" |
| #include "tensorflow/core/lib/monitoring/counter.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| auto* tf_core_op_expansion_op_counter = |
| monitoring::Counter<1>::New("/tensorflow/core/op_expansion/op_counter", |
| "The number of composite op expanded.", "name"); |
| } |
| |
| void IncreaseOpExpansionExecuteCounterByOne(const std::string& op_name) { |
| tf_core_op_expansion_op_counter->GetCell(op_name)->IncrementBy(1); |
| } |
| |
| } // namespace tensorflow |
| |
| //===----------------------------------------------------------------------===// |
| // The pass to decompose unregistered TF ops with the TFR compose function. |
| // |
| namespace mlir { |
| namespace TFR { |
| |
| namespace { |
| |
| // Quantize the float value based on given scale and zero point attributes. |
| Attribute Quantize(float value, Attribute scale_attr, Attribute zp_attr, |
| OpBuilder builder) { |
| double scale = scale_attr.cast<FloatAttr>().getValueAsDouble(); |
| int64_t zp = zp_attr.cast<IntegerAttr>().getInt(); |
| |
| int quantized = static_cast<int>(std::round(value / scale) + zp); |
| quantized = |
| std::min(quantized, static_cast<int>(std::numeric_limits<int8_t>::max())); |
| quantized = |
| std::max(quantized, static_cast<int>(std::numeric_limits<int8_t>::min())); |
| return builder.getI32IntegerAttr(quantized); |
| } |
| |
| // Decompose the TF ops with the registered composition library. |
| class DecomposeTFOpsPass |
| : public PassWrapper<DecomposeTFOpsPass, OperationPass<func::FuncOp>> { |
| public: |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DecomposeTFOpsPass) |
| |
| explicit DecomposeTFOpsPass(llvm::Optional<ModuleOp> external_tfr_module) |
| : external_tfr_module_(external_tfr_module) {} |
| |
| StringRef getArgument() const final { return "tfr-decompose"; } |
| |
| StringRef getDescription() const final { |
| return "Decompose TF ops with the registered composition library."; |
| } |
| |
| void runOnOperation() override; |
| |
| private: |
| // Apply canonicalization, mainly constant folding, on the function. |
| void ApplyCanonicalization(); |
| |
| // Rewrite unregistered TF ops to TFR func call ops. Return failure if all the |
| // ops are registered or the compose function doesn't exist. |
| LogicalResult RewriteUnregisteredTFOps(); |
| |
| // Inline the TFR func call ops. |
| LogicalResult InlineTFRFuncCalls(); |
| |
| // Optional external symbol table to look up the TFR function. |
| llvm::Optional<ModuleOp> external_tfr_module_; |
| }; |
| |
| #include "tensorflow/compiler/mlir/tfr/passes/generated_decompose.inc" |
| |
| void DecomposeTFOpsPass::ApplyCanonicalization() { |
| func::FuncOp func = getOperation(); |
| RewritePatternSet patterns(&getContext()); |
| |
| populateWithGenerated(patterns); |
| populateCanonicalizationPatterns(func, patterns); |
| |
| (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); |
| } |
| |
| LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { |
| func::FuncOp func = getOperation(); |
| SymbolTable table(external_tfr_module_.hasValue() |
| ? *external_tfr_module_ |
| : func->getParentOfType<ModuleOp>()); |
| OpBuilder builder(func); |
| bool changed = false; |
| func.walk([&table, &builder, &changed](Operation* op) { |
| // Only the un-registered ops requires decomposition. The remaining ones |
| // either will be constant folded or lowered by the rules defined in the |
| // bridge. |
| if (op->isRegistered()) { |
| return WalkResult::advance(); |
| } |
| |
| // Find out the compose function |
| auto compose_func_name = GetComposeFuncName(op->getName().getStringRef()); |
| auto compose_func = table.lookup<TFRFuncOp>(compose_func_name); |
| if (!compose_func || compose_func.isExternal()) { |
| // There are no decomposition methods defined for this op, skip. |
| return WalkResult::advance(); |
| } |
| |
| // Make sure all the attributes are valid. An attribute is valid when it is |
| // in the signature or it is allowed explicitly. |
| auto compose_func_signature = |
| table.lookup<TFRFuncOp>(compose_func_name + "_"); |
| if (!compose_func_signature) compose_func_signature = compose_func; |
| auto defined_attrs = compose_func_signature.getDefinedAttributeNames(); |
| if (failed(ValidateAttrs(op, defined_attrs))) { |
| return WalkResult::interrupt(); |
| } |
| |
| tensorflow::IncreaseOpExpansionExecuteCounterByOne( |
| op->getName().getStringRef().str()); |
| |
| auto compose_func_type = compose_func.getFunctionType(); |
| builder.setInsertionPoint(op); |
| TFRTensorType unconstrainted_tensor_type = builder.getType<TFRTensorType>(); |
| |
| // Create the new operands. This is mapping the operands from the target |
| // TF ops to the TFR function arguments. If the TFR function argument is |
| // a tensor_list, a "tfr.build_list" op is used to concat the available |
| // TF op operands. If the TFR function argument isn't a tensor/tensor_list, |
| // a constant is created by using the attribute stored in the TF op or the |
| // default value in the argument attribute. |
| llvm::SmallVector<Value, 4> new_operands; |
| for (auto arg : llvm::enumerate(compose_func_type.getInputs())) { |
| if (auto tensor_type = arg.value().dyn_cast<TFRTensorType>()) { |
| auto casted = builder.create<CastOp>(op->getLoc(), tensor_type, |
| op->getOperand(arg.index())); |
| new_operands.push_back(casted); |
| } else if (auto list_type = arg.value().dyn_cast<TFRTensorListType>()) { |
| llvm::SmallVector<Value, 4> variadic_operands; |
| for (int i = arg.index(); i < op->getNumOperands(); i++) { |
| auto casted = builder.create<CastOp>( |
| op->getLoc(), unconstrainted_tensor_type, op->getOperand(i)); |
| variadic_operands.push_back(casted); |
| } |
| auto build_list_op = builder.create<BuildListOp>( |
| op->getLoc(), list_type, variadic_operands); |
| new_operands.push_back(build_list_op.out()); |
| } else { |
| auto attr_name = compose_func.getArgAttrOfType<StringAttr>( |
| arg.index(), kAttrArgumentNameAttr); |
| auto attribute = op->getAttr(attr_name.getValue()); |
| if (!attribute) { |
| attribute = |
| compose_func.getArgAttr(arg.index(), kAttrArgumentDefaultAttr); |
| } |
| if (!attribute && attr_name.getValue() == "out_type") { |
| auto type = op->getResult(0).getType(); |
| if (type.isa<TensorType>()) { |
| type = type.cast<TensorType>().getElementType(); |
| } |
| attribute = TypeAttr::get(type); |
| } |
| Value attr_cst; |
| // Wrap these special attributes as a special TFR constant, so the SSA |
| // value has a valid type to be used as TFR function argument. These |
| // attributes are not expected to be manipulated by the lowering passes. |
| if (attribute.isa<TypeAttr>() || attribute.isa<ArrayAttr>() || |
| attribute.isa<StringAttr>() || attribute.isa<FlatSymbolRefAttr>()) { |
| TFRAttrType output_type = TFRAttrType::get(builder.getContext()); |
| attr_cst = |
| builder.create<ConstOp>(op->getLoc(), output_type, attribute); |
| } else { |
| attr_cst = |
| builder.create<mlir::arith::ConstantOp>(op->getLoc(), attribute); |
| } |
| new_operands.push_back(attr_cst); |
| } |
| } |
| |
| // Create the TFR call op |
| auto new_op = builder.create<CallOp>( |
| op->getLoc(), compose_func_type.getResults(), |
| SymbolRefAttr::get(builder.getContext(), compose_func.getName()), |
| new_operands); |
| |
| // Replace the use of the old op. This is mapping the results from the |
| // target TF ops to the TFR function returns. If the TFR function return is |
| // a tensor_list, "tfr.get_element" op is used to extract the required TF |
| // op result. |
| llvm::SmallVector<Value, 4> new_results; |
| for (auto res : llvm::enumerate(compose_func_type.getResults())) { |
| if (res.value().dyn_cast<TFRTensorType>()) { |
| new_results.push_back(new_op.getResult(res.index())); |
| } else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) { |
| for (int i = res.index(), j = 0; i < op->getNumResults(); i++, j++) { |
| auto index = builder.create<mlir::arith::ConstantOp>( |
| op->getLoc(), builder.getIndexAttr(j)); |
| auto element_op = builder.create<GetElementOp>( |
| op->getLoc(), unconstrainted_tensor_type, |
| new_op.getResult(res.index()), index.getResult()); |
| new_results.push_back(element_op.out()); |
| } |
| } |
| } |
| for (auto res : llvm::zip(op->getResults(), new_results)) { |
| auto casted = builder.create<CastOp>( |
| op->getLoc(), std::get<0>(res).getType(), std::get<1>(res)); |
| std::get<0>(res).replaceAllUsesWith(casted.out()); |
| } |
| |
| // Copy all the unregisted attributes to the new op. |
| if (failed(CopyAllowedUnregisteredAttrs(op, new_op, defined_attrs))) { |
| return WalkResult::interrupt(); |
| } |
| |
| op->erase(); |
| changed |= true; |
| return WalkResult::advance(); |
| }); |
| |
| // If `changed` is false, it is considered as a failure, so the recursive |
| // rewrite will stop. |
| return success(changed); |
| } |
| |
| LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() { |
| // The Inliner will automatically use the registered dialect inliner. |
| InlinerInterface inliner(&getContext()); |
| func::FuncOp func = getOperation(); |
| SymbolTable table(external_tfr_module_.hasValue() |
| ? *external_tfr_module_ |
| : func->getParentOfType<ModuleOp>()); |
| |
| // The inliner only inlines the TFR call op. |
| bool changed = false; |
| auto walk_result = func.walk([&](CallOp call_op) { |
| auto callee = table.lookup<TFRFuncOp>(call_op.callee()); |
| if (!callee || callee.isExternal()) return WalkResult::advance(); |
| |
| // Record the boundary of the inlined operations. The inlined operation will |
| // be inserted between these two operations. |
| Operation* inlined_point = call_op.getOperation(); |
| Operation* after_inlined_point = |
| &*std::next(Block::iterator(call_op.getOperation())); |
| |
| // Use the inliner to replace all the uses of the call_op by its |
| // composition. |
| if (failed(inlineCall(inliner, |
| cast<CallOpInterface>(call_op.getOperation()), |
| cast<CallableOpInterface>(callee.getOperation()), |
| callee.getCallableRegion(), |
| /**shouldCloneInLinedRegion=*/true))) { |
| // This failure is usually because the decompose function is not defined. |
| // This call will be raised to TF ops. |
| return WalkResult::interrupt(); |
| } |
| |
| // Propagate all the attributes to the inlined operations, which are defined |
| // by the two boundary operations. |
| PropagateAttrsToOperations(call_op, Block::iterator(inlined_point), |
| Block::iterator(after_inlined_point)); |
| |
| // Remove the call_op to finish the op expansion. |
| call_op.erase(); |
| changed |= true; |
| return WalkResult::advance(); |
| }); |
| |
| if (walk_result.wasInterrupted()) { |
| signalPassFailure(); |
| return failure(); |
| } |
| |
| // If `changed` is false, it is considered as a failure, so the recursive |
| // rewrite will stop. |
| return success(changed); |
| } |
| |
| void DecomposeTFOpsPass::runOnOperation() { |
| // Set a maximum iteration threshold in case there are infinite loops in the |
| // call stack. |
| int max_iterators = 10; |
| do { |
| // canonicalization |
| ApplyCanonicalization(); |
| |
| // rewrite unregistered tf ops. Failed either because no ops can be |
| // decomposed or the compose function isn't defined. |
| auto rewrite_status = RewriteUnregisteredTFOps(); |
| // inline the tfr call op until there are no tfr.call op can be inlined. |
| auto inline_status = InlineTFRFuncCalls(); |
| |
| if (failed(rewrite_status) && failed(inline_status)) { |
| break; |
| } |
| } while (max_iterators-- >= 0); |
| } |
| |
| } // namespace |
| |
| // Creates an instance of the pass to decompose the TF ops. |
| std::unique_ptr<OperationPass<func::FuncOp>> CreateDecomposeTFOpsPass( |
| llvm::Optional<ModuleOp> tfr_module) { |
| return std::make_unique<DecomposeTFOpsPass>(tfr_module); |
| } |
| |
| static PassRegistration<DecomposeTFOpsPass> pass([] { |
| return CreateDecomposeTFOpsPass(); |
| }); |
| |
| } // namespace TFR |
| } // namespace mlir |