blob: e7d52c288d5c58d3a7ea738d8817affa77971667 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Target/LLVMIR.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
namespace xla {
namespace cpu {
namespace {
// Lowers `module` through a fixed MLIR pass pipeline and translates the result
// into an LLVM IR module.
//
// Passes run in the order they are added: Linalg ops are first rewritten into
// loops, then the Linalg, Vector and Standard dialects are converted to the
// LLVM dialect. A failing pipeline is a fatal CHECK failure, not a Status.
// NOTE(review): mlir::translateModuleToLLVMIR presumably returns null when
// translation fails; the result is returned unchecked, so callers should
// verify it before use.
std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module) {
  mlir::PassManager pm(module->getContext());
  pm.addPass(mlir::createConvertLinalgToLoopsPass());
  pm.addPass(mlir::createConvertLinalgToLLVMPass());
  pm.addPass(mlir::createConvertVectorToLLVMPass());
  pm.addPass(mlir::createLowerToLLVMPass());

  const mlir::LogicalResult pipeline_result = pm.run(*module);
  CHECK(mlir::succeeded(pipeline_result));

  return mlir::translateModuleToLLVMIR(*module);
}
// Appends to `args` the values forming an expanded memref descriptor for the
// buffer `op_val` whose XLA shape is `op_shape`, in this order:
//   allocated pointer, aligned pointer, offset,
//   sizes   (one per dimension, in logical dimension order),
//   strides (one per dimension, in logical dimension order).
// This is the per-memref argument list this file passes when calling the
// MLIR-generated function.
//
// `op_val` must have LLVM pointer type (llvm::cast asserts this in
// assertion-enabled builds). Nested array types in its pointee are peeled off,
// so both pointers point at the innermost element type. The offset is always
// 0. The strides are dense strides derived from the shape's minor-to-major
// layout.
void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args,
                        llvm::IRBuilder<> *b, const Shape &op_shape,
                        llvm::Value *op_val) {
  // Strip e.g. [N x [M x float]]* down to float*.
  llvm::Type *elem_ptr_type = op_val->getType();
  for (;;) {
    llvm::Type *pointee =
        llvm::cast<llvm::PointerType>(elem_ptr_type)->getElementType();
    auto *array_type = llvm::dyn_cast<llvm::ArrayType>(pointee);
    if (array_type == nullptr) break;
    elem_ptr_type = array_type->getElementType()->getPointerTo();
  }
  llvm::Value *data_ptr = b->CreateBitCast(op_val, elem_ptr_type);

  // Dense strides: walk the dimensions from most-minor to most-major. Each
  // dimension gets the product of the extents of all dimensions more minor
  // than it.
  llvm::SmallVector<int64_t, 4> strides(op_shape.rank(), 1);
  int64_t running_product = 1;
  for (int64 dim : LayoutUtil::MinorToMajor(op_shape)) {
    strides[dim] = running_product;
    running_product *= op_shape.dimensions(dim);
  }

  auto push_i64 = [args, b](int64_t value) {
    args->push_back(b->getInt64(value));
  };
  args->push_back(data_ptr);  // Allocated pointer.
  args->push_back(data_ptr);  // Aligned pointer (same as allocated).
  push_i64(0);                // Offset.
  for (int64 extent : op_shape.dimensions()) {
    push_i64(extent);  // Sizes.
  }
  for (int64_t stride : strides) {
    push_i64(stride);  // Strides.
  }
}
} // namespace
// Builds an MLIR function named `func_name`, lowers it to LLVM IR, links it
// into the LLVM module that `b` is currently emitting into, and emits a call
// to it at `b`'s insertion point.
//
// The generated function returns nothing. Its parameters are one memref for
// the result (shape `result_shape`) followed by one memref per entry of
// `operand_shapes`, in order. `emitter` is invoked with an OpBuilder created
// on the function body (after its entry block has been added) and with the
// function itself, and is responsible for filling in the body.
//
// At the call site, `result_ptr` and `operand_ptrs[i]` are each expanded into a
// memref descriptor (see BuildViewForBuffer) and passed as arguments.
//
// Returns an error if a shape cannot be converted to a memref type, if
// `operand_shapes` and `operand_ptrs` differ in length, if the lowered module
// could not be produced or linked, or if the linked module has no function
// named `func_name`. A failure inside the MLIR lowering pipeline itself is a
// fatal CHECK failure in MakeLLVMModule.
Status EmitMlirFuncAndCall(
    mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
    llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
    llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
    llvm::function_ref<void(mlir::OpBuilder *, mlir::FuncOp)> emitter) {
  // The call-site loop below indexes both arrays in lockstep, so every operand
  // shape must be paired with exactly one operand buffer.
  TF_RET_CHECK(operand_shapes.size() == operand_ptrs.size())
      << "operand_shapes has " << operand_shapes.size()
      << " entries but operand_ptrs has " << operand_ptrs.size();

  llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent();
  mlir::Builder mlir_builder(context);

  // Get memref types for the output (first) and the inputs.
  TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType(
                                                 result_shape, mlir_builder));
  std::vector<mlir::Type> operand_types = {ret_memref};
  operand_types.reserve(operand_shapes.size() + 1);
  for (const Shape &operand_shape : operand_shapes) {
    TF_ASSIGN_OR_RETURN(
        mlir::Type op_memref,
        ConvertTensorShapeToMemRefType(operand_shape, mlir_builder));
    operand_types.push_back(op_memref);
  }

  // Create the function and call the emission callback to populate its body.
  mlir::Location loc = mlir::UnknownLoc::get(context);
  auto function = mlir::FuncOp::create(
      loc, func_name, mlir::FunctionType::get(operand_types, {}, context));
  function.addEntryBlock();
  mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc);
  mlir_module->push_back(function);
  mlir::OpBuilder op_builder(&function.getBody());
  emitter(&op_builder, function);

  // Lower to LLVM IR. Translation can yield null, so check before use.
  auto mlir_llvm_module = MakeLLVMModule(std::move(mlir_module));
  TF_RET_CHECK(mlir_llvm_module != nullptr)
      << "Failed to translate MLIR function " << func_name.str()
      << " to LLVM IR";

  // Give the lowered module the host module's data layout and target triple,
  // so the linker sees matching module properties.
  mlir_llvm_module->setDataLayout(llvm_module->getDataLayout());
  mlir_llvm_module->setTargetTriple(llvm_module->getTargetTriple());

  // Link it into the main LLVM module. The callback preserves every global
  // whose name is not in GVS and internalizes the named ones that are; per the
  // LLVM Linker API, GVS holds the globals linked in from the source module.
  // linkModules returns true on error.
  const bool link_failed = llvm::Linker::linkModules(
      *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None,
      [](llvm::Module &M, const llvm::StringSet<> &GVS) {
        llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) {
          return !GV.hasName() || (GVS.count(GV.getName()) == 0);
        });
      });
  TF_RET_CHECK(!link_failed)
      << "Failed to link MLIR function " << func_name.str()
      << " into the LLVM module";

  // And leave behind a call to the function generated by MLIR.
  llvm::Function *func = llvm_module->getFunction(func_name);
  TF_RET_CHECK(func != nullptr)
      << "Linked module has no function named " << func_name.str();
  llvm::SmallVector<llvm::Value *, 4> op_vals;
  BuildViewForBuffer(&op_vals, b, result_shape, result_ptr);
  for (size_t i = 0; i < operand_shapes.size(); ++i) {
    BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]);
  }
  b->CreateCall(func, op_vals);
  return Status::OK();
}
} // namespace cpu
} // namespace xla