blob: c749af3a1c31da7fc689a7f3ee25f68a288f979c [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
#include <memory>
#include "absl/memory/memory.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // TF:local_config_mlir
#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" // TF:local_config_mlir
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // TF:local_config_mlir
#include "mlir/Dialect/GPU/GPUDialect.h" // TF:local_config_mlir
#include "mlir/Dialect/GPU/Passes.h" // TF:local_config_mlir
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:local_config_mlir
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // TF:local_config_mlir
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:local_config_mlir
#include "mlir/Dialect/Linalg/Passes.h" // TF:local_config_mlir
#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/Region.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace mlir_gpu {
namespace {
using ::mlir::xla_lhlo::FusionOp;
// Following are some small transformations that are required to clean up code
// after lowering from linalg to loops.
// A simple pass that applies lowering of HLO to LHLO only within Fusion
// operations. This is needed, as FusionOp is not closed from above and hence
// nested pass managers can not be applied.
struct FusionToLhloConverter
: public mlir::FunctionPass<FusionToLhloConverter> {
void runOnFunction() override {
auto& ctx = getContext();
mlir::OwningRewritePatternList patterns;
mlir::ConversionTarget target(ctx);
target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>();
::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns);
getFunction().walk([&](FusionOp op) {
if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
signalPassFailure();
}
});
}
};
// Replaces a FusionOp by the operations contained in its region.
struct FusionOpRemover : public mlir::FunctionPass<FusionOpRemover> {
void runOnFunction() override {
getFunction().walk([&](FusionOp op) {
mlir::OpBuilder builder(op);
// FusionOp has a single region with a single block, so we can just walk
// over it and clone operations to the outside.
mlir::BlockAndValueMapping mapping;
for (auto& nested_op : op.region().front().without_terminator()) {
auto clone = builder.clone(nested_op, mapping);
for (auto pair :
llvm::zip(nested_op.getResults(), clone->getResults())) {
mapping.map(std::get<0>(pair), std::get<1>(pair));
}
}
op.erase();
});
}
};
// Rewrite the single-trip loops we get out of linalg into just their bodies.
// TODO(herhut): Make this a general pattern.
struct SingleTripLoopRemoval
: public mlir::FunctionPass<SingleTripLoopRemoval> {
void runOnFunction() override {
auto getConstantValue = [](mlir::Value* value) -> llvm::Optional<int64_t> {
auto definingOp = value->getDefiningOp();
if (!definingOp) return llvm::None;
auto constantOp = llvm::dyn_cast<mlir::ConstantOp>(definingOp);
if (!constantOp) return llvm::None;
auto integer = constantOp.getValue().dyn_cast<mlir::IntegerAttr>();
if (!integer) return llvm::None;
return integer.getInt();
};
getFunction().walk([&](mlir::loop::ForOp forOp) {
auto lower = getConstantValue(forOp.lowerBound());
auto upper = getConstantValue(forOp.upperBound());
auto step = getConstantValue(forOp.step());
if (!lower || !upper || !step) return;
if ((lower.getValue() < upper.getValue()) &&
(lower.getValue() + step.getValue() >= upper.getValue())) {
// This loop has a single trip, so we can move the body in front.
mlir::BlockAndValueMapping mapping;
mlir::OpBuilder b(forOp);
mapping.map(forOp.getInductionVar(), forOp.lowerBound());
for (auto& nested_op : forOp.getBody()->without_terminator()) {
auto clone = b.clone(nested_op, mapping);
for (auto pair :
llvm::zip(nested_op.getResults(), clone->getResults())) {
mapping.map(std::get<0>(pair), std::get<1>(pair));
}
}
forOp.erase();
}
});
}
};
// Simple pass that replaces a load that immediately follows a store to the
// same address with the stored value. This needs generalization.
struct StoreForwardingPass : mlir::FunctionPass<StoreForwardingPass> {
void runOnFunction() override {
getFunction().walk([&](mlir::LoadOp loadOp) {
auto block = loadOp.getOperation()->getBlock();
auto iterator = std::find_if(block->rbegin(), block->rend(),
[&loadOp](mlir::Operation& other) {
return &other == loadOp.getOperation();
});
if (++iterator == block->rend()) return;
mlir::StoreOp storeOp = llvm::dyn_cast<mlir::StoreOp>(&*(iterator));
if (!storeOp) return;
// Check both store to the same value.
if (storeOp.memref() != loadOp.memref()) return;
auto storeIndices = storeOp.getIndices();
auto loadIndices = loadOp.getIndices();
if (!std::equal(storeIndices.begin(), storeIndices.end(),
loadIndices.begin(), loadIndices.end())) {
return;
}
loadOp.replaceAllUsesWith(storeOp.getValueToStore());
loadOp.erase();
});
};
};
// Simple pass that removes temporary buffers that are only written to but
// never read from or that are read but the read value is not used.
// Needs an analysis that proves that loads and stores are side-effect free
// (in bounds, no aliasing, etc.).
struct DeadTempBufferRemoval : mlir::FunctionPass<DeadTempBufferRemoval> {
bool operationConsideredDead(mlir::Operation* op) {
for (auto result : op->getResults()) {
if (!llvm::all_of(result->getUsers(), [&](mlir::Operation* op) {
// Store and Dealloc is OK.
if (llvm::isa<mlir::StoreOp>(op) ||
llvm::isa<mlir::DeallocOp>(op)) {
return true;
}
// Load without uses is also ok.
if (auto loadOp = llvm::dyn_cast<mlir::LoadOp>(op)) {
return loadOp.use_empty();
}
// Subview is ok if it is dead itself.
if (llvm::isa<mlir::SubViewOp>(op)) {
return operationConsideredDead(op);
}
return false;
})) {
return false;
}
}
return true;
}
void recursiveErase(mlir::Operation* op) {
for (auto result : op->getResults()) {
for (auto user : llvm::make_early_inc_range(result->getUsers())) {
recursiveErase(user);
}
}
op->erase();
}
void runOnFunction() override {
getFunction().walk([&](mlir::AllocOp allocOp) {
if (!operationConsideredDead(allocOp)) {
return;
}
// TODO(herhut): There should be a generic helper for this.
recursiveErase(allocOp);
});
}
};
// Neat little helper pass to dump the IR inbetween passes.
struct DumpPass : public mlir::ModulePass<DumpPass> {
void runOnModule() override {
#if DEBUG
getModule().dump();
#endif
}
};
} // namespace
Status LowerLHLOToGPU(mlir::ModuleOp module) {
mlir::PassManager pm(module.getContext());
// First, lower bodies of fusion operations from hlo to lhlo.
pm.addPass(absl::make_unique<FusionToLhloConverter>());
// Next, we can strip the outer fusion operation.
pm.addPass(absl::make_unique<FusionOpRemover>());
// Transform lhlo operations to LinAlg.
pm.addPass(::mlir::xla_lhlo::createLegalizeToLinalgPass());
// Fuse linalg operations. This will yield a single tiled loop nest where
// the inner loops are single trip.
pm.addPass(::mlir::xla_lhlo::createLhloFuseLinalg());
pm.addPass(absl::make_unique<DumpPass>());
// Go from linalg to normal loops.
pm.addPass(::mlir::linalg::createConvertLinalgToLoopsPass());
pm.addPass(absl::make_unique<DumpPass>());
// Canonicalize the code to simplify index computations.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addPass(absl::make_unique<DumpPass>());
// The innermost loops will be single-trip.
pm.addPass(absl::make_unique<SingleTripLoopRemoval>());
pm.addPass(absl::make_unique<DumpPass>());
// Run CSE to ensure that loads and stores to the same subview get
// recognized as such.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
pm.addPass(absl::make_unique<DumpPass>());
// Forward stores to buffers to loads.
pm.addPass(absl::make_unique<StoreForwardingPass>());
pm.addPass(absl::make_unique<DumpPass>());
// Remove now unused temporary buffers.
pm.addPass(absl::make_unique<DeadTempBufferRemoval>());
pm.addPass(absl::make_unique<DumpPass>());
// Coalesce generated loops to have 1d loops.
pm.addPass(::mlir::createLoopCoalescingPass());
// Transform the now 1d loops to gpu launches.
pm.addPass(::mlir::createSimpleLoopsToGPUPass(/*numBlockDims=*/0,
/*numThreadDims=*/1));
// Some basic cleanup.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Take launches to launches with kernels.
pm.addPass(::mlir::createGpuKernelOutliningPass());
if (failed(pm.run(module))) {
return InternalError("Lowering to GPU kernels failed.");
}
return Status::OK();
}
Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) {
// We cannot verify as the signature of the kernel is rewritten.
::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false);
// Rewrite kernel functions to LLVM IR.
auto& kernelPm = pm.nest<::mlir::ModuleOp>();
kernelPm.addPass(::mlir::createLowerGpuOpsToNVVMOpsPass());
// Some basic cleanup.
kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
if (failed(pm.run(module))) {
return InternalError("Lowering to NVVM IR failed.");
}
return Status::OK();
}
StatusOr<mlir::ModuleOp> ExtractKernelModule(mlir::ModuleOp module) {
auto kernelModule = ::mlir::ModuleOp::create(module.getLoc());
// TODO(b/137624192): This also needs to resolve naming conflicts.
module.walk([&kernelModule](mlir::ModuleOp nestedModule) {
if (nestedModule.getAttrOfType<mlir::UnitAttr>(
mlir::gpu::GPUDialect::getKernelModuleAttrName())) {
for (auto& fn : nestedModule) {
kernelModule.push_back(fn.clone());
}
}
});
return kernelModule;
}
} // namespace mlir_gpu
} // namespace xla