| /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <memory> |
| #include <stdexcept> |
| #include <utility> |
| |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/iterator_range.h" |
| #include "llvm/IR/Constants.h" |
| #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" // from @llvm-project |
| #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" // from @llvm-project |
| #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project |
| #include "mlir/Conversion/LLVMCommon/Pattern.h" // from @llvm-project |
| #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project |
| #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project |
| #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project |
| #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project |
| #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project |
| #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinDialect.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/OperationSupport.h" // from @llvm-project |
| #include "mlir/IR/TypeRange.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Transforms/DialectConversion.h" // from @llvm-project |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/xla/ir/xla_framework.h" |
| #include "tensorflow/compiler/mlir/xla/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/xla/transforms/xla_passes_detail.h" |
| |
| namespace mlir { |
| namespace mhlo { |
| namespace { |
| |
| // Given a FuncOp with only memref args/outputs, create a new function that |
| // wraps/unwraps xla_framework.buffer types and then calls the original |
| // function. |
| // |
| // For example: |
| // func @func_to_outline(%arg0: memref<?xf32>) -> memref<?xf32> |
| // |
| // Will generate: |
| // func @func_to_outline_xla_framework(%arg0: !xla_framework.buffer) |
| // -> !xla_framework.buffer attributes {xla_entry = true} { |
| // %0 = xla_framework.buffer_to_mem %arg0 : memref<?xf32> |
| // %1 = call @func_to_outline(%0) : (memref<?xf32>) -> memref<?xf32> |
| // %2 = xla_framework.mem_to_buffer %1 : memref<?xf32> |
| // return %2 : !xla_framework.buffer |
| // } |
| struct OutlineXLAFunc : public RewritePattern { |
| explicit OutlineXLAFunc(MLIRContext *context, PatternBenefit benefit = 1) |
| : RewritePattern(func::FuncOp::getOperationName(), benefit, context) {} |
| |
| static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, |
| bool argAttrs, |
| SmallVectorImpl<NamedAttribute> &result) { |
| for (const auto &attr : attrs) { |
| if (attr.getName() == SymbolTable::getSymbolAttrName() || |
| attr.getName() == FunctionOpInterface::getTypeAttrName() || |
| attr.getName() == "std.varargs" || |
| (argAttrs && attr.getName() == func::FuncOp::getArgDictAttrName())) |
| continue; |
| result.push_back(attr); |
| } |
| } |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| auto func = dyn_cast<func::FuncOp>(op); |
| auto ctx = rewriter.getContext(); |
| auto loc = func.getLoc(); |
| SmallVector<Location> locs(func.getFunctionType().getNumInputs(), loc); |
| |
| // Functions should only be outlined once and should only use memrefs |
| if (!func) return failure(); |
| if (llvm::any_of(op->getOperandTypes(), |
| [](Type t) { return !t.isa<MemRefType>(); }) || |
| op->getNumResults() != 0) |
| return failure(); |
| if (func->hasAttr("outlined")) return failure(); |
| func->setAttr("outlined", BoolAttr::get(ctx, true)); |
| |
| // Prepare new func attribute information |
| func.setSymNameAttr(mlir::StringAttr::get(ctx, func.getName())); |
| SmallVector<Type> operands(func.getFunctionType().getNumInputs(), |
| ::mlir::xla_framework::BufferType::get(ctx)); |
| SmallVector<Type> result_array(func.getFunctionType().getNumResults(), |
| ::mlir::xla_framework::BufferType::get(ctx)); |
| auto func_type = FunctionType::get(ctx, operands, result_array); |
| SmallVector<NamedAttribute> attrs; |
| filterFuncAttributes(func->getAttrs(), true, attrs); |
| SmallVector<DictionaryAttr> arg_attrs; |
| func.getAllArgAttrs(arg_attrs); |
| |
| // The wrapper function will have the same name but with _xla_framework |
| // appended and will be annotated with the attribute "xla_entry". |
| auto outline_func = rewriter.create<func::FuncOp>( |
| loc, func.getSymName().str() + "_xla_framework", func_type, attrs, |
| arg_attrs); |
| outline_func->setAttr("outlined", BoolAttr::get(ctx, true)); |
| outline_func->setAttr("xla_entry", BoolAttr::get(ctx, true)); |
| auto *b = rewriter.createBlock(&outline_func.getBody(), {}, |
| func_type.getInputs(), locs); |
| |
| // Unwrap arguments |
| SmallVector<Value> args; |
| for (const auto &t : llvm::enumerate(func.getFunctionType().getInputs())) { |
| args.push_back(rewriter.create<xla_framework::XLABufferToMemOp>( |
| loc, t.value(), b->getArgument(t.index()))); |
| } |
| |
| auto call = rewriter.create<func::CallOp>( |
| loc, func.getSymName(), func.getFunctionType().getResults(), args); |
| // Wrap results |
| SmallVector<Value> results; |
| for (auto t : call.getResults()) { |
| results.push_back(rewriter.create<xla_framework::MemToXLABufferOp>( |
| loc, ::mlir::xla_framework::BufferType::get(ctx), t)); |
| } |
| |
| rewriter.create<func::ReturnOp>(loc, results); |
| return success(); |
| } |
| }; |
| |
| class OutlineWithXLAFrameworkPass |
| : public OutlineWithXLAFrameworkBase<OutlineWithXLAFrameworkPass> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<xla_framework::XLAFrameworkDialect, mlir::BuiltinDialect>(); |
| } |
| |
| public: |
| explicit OutlineWithXLAFrameworkPass() {} |
| |
| void runOnOperation() override { |
| ModuleOp m = getOperation(); |
| |
| // Populate type conversions. |
| MLIRContext *ctx = m.getContext(); |
| |
| // Populate patterns. |
| RewritePatternSet patterns(&getContext()); |
| patterns.add<OutlineXLAFunc>(ctx); |
| // Set target. |
| |
| if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) { |
| signalPassFailure(); |
| } |
| m->walk([](func::FuncOp f) { |
| if (f->hasAttr("outlined")) f->removeAttr("outlined"); |
| }); |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<OperationPass<ModuleOp> > CreateOutlineWithXLAFrameworkPass() { |
| return std::make_unique<OutlineWithXLAFrameworkPass>(); |
| } |
| |
| } // namespace mhlo |
| } // namespace mlir |