| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Location.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/Matchers.h" // from @llvm-project |
| #include "mlir/IR/SymbolTable.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Transforms/RegionUtils.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" |
| #include "tensorflow/compiler/mlir/lite/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| |
| namespace mlir { |
| namespace TFL { |
| namespace { |
| |
| // This pass outlines the cond/body region of the TFL WhileOp into functions and |
| // replaces the regions with calls to these outlined functions. |
| class WhileOutlinePass |
| : public mlir::PassWrapper<WhileOutlinePass, OperationPass<ModuleOp>> { |
| void getDependentDialects(DialectRegistry& registry) const override { |
| registry.insert<TF::TensorFlowDialect>(); |
| } |
| |
| public: |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WhileOutlinePass) |
| |
| explicit WhileOutlinePass() {} |
| |
| StringRef getArgument() const final { |
| // This is the argument used to refer to the pass in |
| // the textual format (on the commandline for example). |
| return "tfl-while-loop-outline"; |
| } |
| StringRef getDescription() const final { |
| // This is a brief description of the pass. |
| return "Hoist while op regions into functions"; |
| } |
| |
| private: |
| void runOnOperation() override; |
| |
| // Outlines the regions of the WhileOp's cond and body and insert function |
| // calls instead, |
| void OutlineWhile(WhileOp while_op); |
| |
| // Get unique name by using the loc to name mapping. |
| std::string GetName(Operation* op, StringRef suffix); |
| |
| tensorflow::OpOrArgLocNameMapper mapper_; |
| }; |
| |
| std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) { |
| return (mapper_.GetUniqueName(op) + suffix).str(); |
| } |
| |
| // Returns whether the WhileOp is already outlined (e.g., only consists of calls |
| // to functions). |
| bool IsAlreadyOutlined(WhileOp while_op) { |
| auto just_call = [](Region& region) { |
| auto it = region.front().begin(); |
| if (!isa<func::CallOp>(*it)) return false; |
| ++it; |
| if (!isa<YieldOp>(*it)) return false; |
| return true; |
| }; |
| return just_call(while_op.body()) && just_call(while_op.cond()); |
| } |
| |
| bool IsCompatibleTypeWithTFLCastOp(Type type) { |
| auto elemType = getElementTypeOrSelf(type); |
| // F32 and BF16 types are allowed. |
| if (elemType.isBF16() || elemType.isF32()) return true; |
| |
| // I1, I8 I16, I32, I64 types are allowed. |
| if (elemType.isInteger(1) || elemType.isInteger(8) || |
| elemType.isInteger(16) || elemType.isInteger(32) || |
| elemType.isInteger(64)) |
| return true; |
| |
| // Complex<F<32>> is allowed. |
| if (elemType.isa<ComplexType>() && |
| elemType.cast<ComplexType>().getElementType().isF32()) |
| return true; |
| |
| // QUINT8 and UI8 are allowed. |
| if (elemType.isa<TF::Quint8Type>() || |
| (elemType.isInteger(8) && elemType.cast<IntegerType>().isUnsigned())) |
| return true; |
| |
| return false; |
| } |
| |
| func::FuncOp CreateOutlineFunc(StringRef name, Region& region, |
| bool passthru_extra_args, int num_loop_carried, |
| const llvm::SetVector<Value>& extern_values, |
| const SmallVectorImpl<Type>& types, |
| Location loc) { |
| MLIRContext* context = loc.getContext(); |
| OpBuilder builder(context); |
| FunctionType type; |
| if (passthru_extra_args) { |
| type = FunctionType::get(context, types, types); |
| } else { |
| SmallVector<Type, 4> result_types; |
| auto operands = region.front().getTerminator()->getOperandTypes(); |
| result_types.append(operands.begin(), operands.end()); |
| type = FunctionType::get(context, types, result_types); |
| } |
| |
| auto outlined_func = builder.create<func::FuncOp>(loc, name, type); |
| outlined_func.getBody().takeBody(region); |
| Region& func_region = outlined_func.getBody(); |
| |
| // Replace all external uses with block args and update uses. |
| llvm::SmallVector<Value, 4> new_args; |
| new_args.reserve(extern_values.size()); |
| Block& block = func_region.front(); |
| for (Value value : extern_values) { |
| auto arg = block.addArgument(value.getType(), loc); |
| replaceAllUsesInRegionWith(value, arg, func_region); |
| new_args.push_back(arg); |
| } |
| |
| // Replace yield op with return. |
| Operation* yield_op = outlined_func.getBody().front().getTerminator(); |
| OpBuilder b(yield_op); |
| llvm::SmallVector<Value, 4> args; |
| auto loop_carried_yield_operands = |
| yield_op->getOperands().take_front(num_loop_carried); |
| args.reserve(loop_carried_yield_operands.size() + new_args.size()); |
| if (passthru_extra_args) { |
| // Add operands of yield to the return, inserting casts if needed. |
| for (auto it : llvm::zip_first(loop_carried_yield_operands, types)) { |
| auto value = std::get<0>(it); |
| auto type = std::get<1>(it); |
| if (value.getType() == type) { |
| args.push_back(value); |
| } else { |
| if (IsCompatibleTypeWithTFLCastOp(value.getType()) && |
| IsCompatibleTypeWithTFLCastOp(type)) { |
| auto cast = b.create<CastOp>(yield_op->getLoc(), type, value); |
| args.push_back(cast); |
| } else { |
| auto cast = b.create<TF::CastOp>(yield_op->getLoc(), type, value); |
| args.push_back(cast); |
| } |
| } |
| } |
| args.append(new_args.begin(), new_args.end()); |
| } else { |
| args.append(yield_op->operand_begin(), yield_op->operand_end()); |
| } |
| b.create<func::ReturnOp>(yield_op->getLoc(), args); |
| yield_op->erase(); |
| SymbolTable(region.getParentOfType<ModuleOp>()).insert(outlined_func); |
| outlined_func.setPrivate(); |
| return outlined_func; |
| } |
| |
| // Replace region with call to outline function. |
| void ReplaceRegionWithCall(StringRef name, Region& region, |
| bool passthru_extra_args, int num_loop_carried, |
| const llvm::SetVector<Value>& extern_values, |
| const SmallVectorImpl<Type>& types, Location loc) { |
| auto func = CreateOutlineFunc(name, region, passthru_extra_args, |
| num_loop_carried, extern_values, types, loc); |
| OpBuilder b(region); |
| // The body of the region is empty/has been outlined into the function. |
| auto block = b.createBlock(®ion); |
| SmallVector<Value, 4> new_operands; |
| new_operands.reserve(types.size()); |
| for (Type t : llvm::makeArrayRef(types).drop_back(extern_values.size())) |
| new_operands.push_back(block->addArgument(t, loc)); |
| for (Value v : extern_values) new_operands.push_back(v); |
| auto call = b.create<func::CallOp>(loc, func, new_operands); |
| b.create<YieldOp>(loc, call.getResults()); |
| } |
| |
| void WhileOutlinePass::OutlineWhile(WhileOp while_op) { |
| OpBuilder builder(&getContext()); |
| // Collect external values used. |
| llvm::SetVector<Value> extern_values; |
| |
| // The basic block arguments correspond to values that are loop carried, while |
| // all those post are loop independent. Initialize extern_values with while_op |
| // not loop carried operands. |
| auto num_loop_carried = while_op.cond().getNumArguments(); |
| auto not_carried_operands = |
| while_op.getOperands().drop_front(num_loop_carried); |
| extern_values.insert(not_carried_operands.begin(), |
| not_carried_operands.end()); |
| auto old_extern_values_size = extern_values.size(); |
| |
| llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()}; |
| for (auto it : llvm::enumerate(regions)) { |
| llvm::SetVector<Value> region_extern_values; |
| getUsedValuesDefinedAbove(*it.value(), region_extern_values); |
| |
| // Sink down constants into the functions. |
| for (auto extern_value : region_extern_values) { |
| if (!matchPattern(extern_value, m_Constant())) { |
| extern_values.insert(extern_value); |
| continue; |
| } |
| // Add constant at start of region. |
| auto const_builder = |
| OpBuilder(&it.value()->front(), it.value()->front().begin()); |
| auto const_value = const_builder.clone(*extern_value.getDefiningOp()); |
| replaceAllUsesInRegionWith(extern_value, const_value->getResult(0), |
| *it.value()); |
| } |
| } |
| |
| bool has_extra_extern_values = old_extern_values_size != extern_values.size(); |
| // If an extern value is already an operand post the loop carried operands, |
| // then it need not be passed in again. |
| // Compute all the extra operands that have to be added to the while. |
| llvm::SetVector<Value> extra_operands; |
| if (has_extra_extern_values) { |
| auto new_extern = |
| extern_values.getArrayRef().drop_front(old_extern_values_size); |
| extra_operands.insert(new_extern.begin(), new_extern.end()); |
| } |
| |
| // Skip if already just calls. |
| if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return; |
| |
| // Collect new types. |
| SmallVector<Type, 4> types; |
| types.reserve(extra_operands.size() + while_op.getNumOperands()); |
| for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type); |
| for (Value operand : extern_values) types.push_back(operand.getType()); |
| |
| // Create outline function from region. Optional pass extra arguments through |
| // to yield. |
| ReplaceRegionWithCall(GetName(while_op.getOperation(), "_cond"), |
| while_op.cond(), false, num_loop_carried, extern_values, |
| types, while_op.getLoc()); |
| ReplaceRegionWithCall(GetName(while_op.getOperation(), "_body"), |
| while_op.body(), true, num_loop_carried, extern_values, |
| types, while_op.getLoc()); |
| |
| // If there are extern values used then the result type of the while has to |
| // change, so replace with new while op. |
| if (extra_operands.empty()) return; |
| |
| const int operands_size = while_op.getNumOperands() + extra_operands.size(); |
| SmallVector<Value, 4> operands; |
| operands.reserve(operands_size); |
| operands.append(while_op.getOperands().begin(), while_op.getOperands().end()); |
| operands.append(extra_operands.begin(), extra_operands.end()); |
| SmallVector<Type, 4> new_types; |
| new_types.reserve(operands_size); |
| new_types.append(while_op.getResultTypes().begin(), |
| while_op.getResultTypes().end()); |
| for (auto extra_operand : extra_operands) |
| new_types.push_back(extra_operand.getType()); |
| |
| auto new_while_op = OpBuilder(while_op).create<WhileOp>( |
| while_op.getLoc(), new_types, operands, while_op->getAttrs()); |
| new_while_op.cond().takeBody(while_op.cond()); |
| new_while_op.body().takeBody(while_op.body()); |
| while_op.replaceAllUsesWith( |
| new_while_op.getResults().take_front(while_op.getNumResults())); |
| while_op.erase(); |
| } |
| |
| void WhileOutlinePass::runOnOperation() { |
| getOperation().walk( |
| [&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); }); |
| } |
| } // namespace |
| |
| // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass. |
| std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() { |
| return std::make_unique<WhileOutlinePass>(); |
| } |
| |
| static PassRegistration<WhileOutlinePass> pass; |
| |
| } // namespace TFL |
| } // namespace mlir |