/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
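// Buffer-reuse analysis and annotation pass for the TF kernel generator:
// marks allocations with the argument buffers they may reuse and with the
// function result they are returned as.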
#include <cstddef>
#include <vector>
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Analysis/BufferViewFlowAnalysis.h" // from @llvm-project
#include "mlir/Analysis/Liveness.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
constexpr llvm::StringRef
mlir::kernel_gen::tf_framework::TFAllocOp::kReuseOutputAttrName;
constexpr llvm::StringRef
mlir::kernel_gen::tf_framework::TFAllocOp::kReuseInputCandidatesAttrName;
constexpr llvm::StringRef
mlir::kernel_gen::tf_framework::TFFrameworkDialect::kTFEntryAttrName;
namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
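// Analysis that computes, for every `memref.alloc` in a function, (i) the
// indices of function arguments whose buffers could be reused for the
// allocation because their lifetimes are compatible and (ii) the index of the
// function result at which the allocated buffer is returned.
//
// Illustrative sketch (hypothetical IR; names and shapes chosen for the
// example only):
//
//   func.func @kernel(%arg0: memref<?xf32>) -> memref<?xf32> {
//     %0 = memref.alloc(...) : memref<?xf32>   // reuse candidate: %arg0
//     linalg.generic ... ins(%arg0 : memref<?xf32>)
//                        outs(%0 : memref<?xf32>) { ... }
//     func.return %0 : memref<?xf32>           // output index: 0
//   }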
class BufferReuseAnalysis {
public:
explicit BufferReuseAnalysis(func::FuncOp f) { build(f); }
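// Sentinel output index for allocations whose buffer is never returned or is
// returned at more than one distinct result position.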
static constexpr int32_t kIndexAmbiguous = -1;
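// Returns the indices of the function arguments whose buffers may be reused
// for `op`, or llvm::None if `op` was not analyzed.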
Optional<SmallVector<int32_t, 2>> get_reuse_candidates(memref::AllocOp op) {
auto it = reuse_candidates_.find(op);
if (it == reuse_candidates_.end()) return llvm::None;
return it->second;
}
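// Returns the function result index at which the buffer allocated by `op` is
// returned, or kIndexAmbiguous if no unique such position exists.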
Optional<int32_t> get_output_index(memref::AllocOp op) {
auto it = output_indices_.find(op);
if (it == output_indices_.end()) return llvm::None;
return it->second;
}
private:
void build(func::FuncOp &f) {
BufferViewFlowAnalysis aliases(f);
find_output_indices(f, aliases);
find_reuse_candidates(f, aliases);
}
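// For every allocation, record the unique result position at which (an alias
// of) the allocated buffer is returned by `func.return`, if any.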
void find_output_indices(func::FuncOp &f, BufferViewFlowAnalysis &aliases) {
f.walk([&](memref::AllocOp alloc_op) {
int32_t output_index = kIndexAmbiguous;
int count_return_uses = 0;
auto buffer_aliases = aliases.resolve(alloc_op.getResult());
for (Value alias : buffer_aliases) {
for (auto &use : alias.getUses()) {
if (isa<func::ReturnOp>(use.getOwner())) {
int32_t index = use.getOperandNumber();
if (count_return_uses++ == 0)
output_index = index;
else if (output_index != index)
output_index = kIndexAmbiguous;
}
}
}
output_indices_[alloc_op] = output_index;
});
}
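// Compute, per block, which argument buffers each allocation may reuse, based
// on buffer aliasing and liveness information.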
void find_reuse_candidates(func::FuncOp &f, BufferViewFlowAnalysis &aliases) {
Liveness liveness(f);
f.walk([&](Block *block) {
find_reuse_candidates(block, aliases, liveness.getLiveness(block),
f.getArguments());
});
}
void find_reuse_candidates(Block *block, BufferViewFlowAnalysis &aliases,
const LivenessBlockInfo *liveness,
ArrayRef<BlockArgument> arguments) {
for (Operation &op : *block) {
auto alloc_op = dyn_cast<memref::AllocOp>(op);
if (!alloc_op) continue;
// Find first use of the newly allocated buffer within this block.
Value new_buffer = alloc_op.getResult();
Operation *first_reuse = find_first_use_in_block(new_buffer, block);
assert((first_reuse == nullptr || first_reuse->getBlock() == block) &&
"Expected first use in same block if found.");
// Find reuse candidates for the allocation under consideration.
SmallVector<int32_t, 2> local_reuse_candidates;
for (BlockArgument old_buffer : arguments) {
if (!old_buffer.getType().isa<BaseMemRefType>()) continue;
// Lifetime criterion: only reuse a buffer if it is no longer live at the
// point where the new buffer is first used.
bool lifetimes_compatible = true;
for (Value old_buffer_alias : aliases.resolve(old_buffer)) {
if (first_reuse == nullptr) {
// If the first use is beyond the end of this block we look at the
// block end. An argument buffer that is already reusable there is
// certainly reusable at any later actual use. Otherwise, lifetimes
// are incompatible.
if (liveness->isLiveOut(old_buffer_alias)) {
lifetimes_compatible = false;
break;
}
} else {
// A buffer is reusable if
// i) its last use is before the point of reuse, or
// ii) its last use is also its first reuse and the operation
// allows for local reuse.
// Otherwise, lifetimes are incompatible.
Operation *last_use =
liveness->getEndOperation(old_buffer_alias, &block->front());
assert(last_use != nullptr && last_use->getBlock() == block &&
"Expected last use in same block.");
if (first_reuse->isBeforeInBlock(last_use)) {
lifetimes_compatible = false;
break;
}
if (first_reuse == last_use &&
!can_reuse_locally(first_reuse, old_buffer_alias, new_buffer)) {
lifetimes_compatible = false;
break;
}
}
}
if (lifetimes_compatible) {
// All criteria are fulfilled 🙂.
int32_t old_buffer_index = old_buffer.getArgNumber();
local_reuse_candidates.push_back(old_buffer_index);
}
}
reuse_candidates_[&op] = local_reuse_candidates;
}
}
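// Returns the operation in `block` that (transitively) contains the first use
// of `value`, or nullptr if `value` has no use within `block`.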
Operation *find_first_use_in_block(Value value, Block *block) {
Operation *first_use = nullptr;
for (Operation *op : value.getUsers()) {
Operation *ancestor_op = block->findAncestorOpInBlock(*op);
if (ancestor_op == nullptr) continue;
if (first_use == nullptr || ancestor_op->isBeforeInBlock(first_use))
first_use = ancestor_op;
}
return first_use;
}
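// Returns all function arguments of memref type.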
std::vector<Value> get_buffer_arguments(func::FuncOp &f) {
std::vector<Value> buffer_arguments;
for (BlockArgument arg : f.getArguments()) {
if (arg.getType().isa<BaseMemRefType>()) buffer_arguments.push_back(arg);
}
return buffer_arguments;
}
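// Determines whether `op` may safely read `old_buffer` and write `new_buffer`
// through the same underlying memory, i.e. whether, per memory location, the
// last use of `old_buffer` happens no later than the write to `new_buffer`.
// Currently only `linalg.generic` ops are analyzed; everything else is
// conservatively rejected.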
bool can_reuse_locally(Operation *op, Value old_buffer, Value new_buffer) {
// For now, we support only memrefs with the same memory layout.
auto old_buffer_ty = old_buffer.getType().dyn_cast<MemRefType>();
auto new_buffer_ty = new_buffer.getType().dyn_cast<MemRefType>();
if (!old_buffer_ty || !new_buffer_ty ||
old_buffer_ty.getLayout() != new_buffer_ty.getLayout())
return false;
if (auto generic_op = dyn_cast<linalg::GenericOp>(op)) {
SmallVector<OpOperand *> op_operands =
generic_op.getInputAndOutputOperands();
auto old_it = llvm::find_if(op_operands, [&](OpOperand *op_operand) {
return op_operand->get() == old_buffer;
});
auto new_it = llvm::find_if(op_operands, [&](OpOperand *op_operand) {
return op_operand->get() == new_buffer;
});
assert(old_it != op_operands.end() && new_it != op_operands.end() &&
"Expect `old/new_buffer` to be operand of `op`.");
auto is_projection = [](AffineMap map) {
// Allow dropping dimensions but no permutations.
int64_t i = -1;
for (AffineExpr expr : map.getResults()) {
auto dim_expr = expr.dyn_cast<AffineDimExpr>();
if (!dim_expr || dim_expr.getPosition() <= i) return false;
i = dim_expr.getPosition();
}
return true;
};
// If `linalg.generic` indexing maps are the same for input and output
// buffer then the last use of the input buffer happens before its first
// reuse (per memory location). Since we know that the inputs and outputs
// have the same size we also know that when one side has an identity map
// and the other side only drops dimensions, these dimensions have to be
// of size 1.
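// Illustrative example (hypothetical shapes): if the old buffer uses the
// identity map (d0, d1) -> (d0, d1) and the new buffer uses the projection
// (d0, d1) -> (d1), then the dropped dimension d0 must have size 1, since
// both buffers hold the same number of elements.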
AffineMap old_indexing_map = generic_op.getTiedIndexingMap(*old_it);
AffineMap new_indexing_map = generic_op.getTiedIndexingMap(*new_it);
return (old_indexing_map == new_indexing_map &&
old_indexing_map.isProjectedPermutation()) ||
(old_indexing_map.isIdentity() &&
is_projection(new_indexing_map)) ||
(is_projection(old_indexing_map) && new_indexing_map.isIdentity());
}
return false;
}
DenseMap<Operation *, SmallVector<int32_t, 2>> reuse_candidates_;
DenseMap<Operation *, int32_t> output_indices_;
};
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
struct BufferReusePass : public BufferReusePassBase<BufferReusePass> {
void runOnOperation() override {
if (!getOperation()->getAttrOfType<UnitAttr>(
tf_framework::TFFrameworkDialect::kTFEntryAttrName))
return;
BufferReuseAnalysis analysis(getOperation());
// Annotate IR with reuse candidates and output indices per allocation.
Builder builder(&getContext());
getOperation().walk([&](memref::AllocOp op) {
if (auto output_index = analysis.get_output_index(op)) {
auto attr = builder.getI32IntegerAttr(*output_index);
op.getOperation()->setAttr(
tf_framework::TFAllocOp::kReuseOutputAttrName, attr);
}
if (auto reuse_candidates = analysis.get_reuse_candidates(op)) {
auto attr = builder.getI32ArrayAttr(*reuse_candidates);
op.getOperation()->setAttr(
tf_framework::TFAllocOp::kReuseInputCandidatesAttrName, attr);
}
});
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>> CreateBufferReusePass() {
return std::make_unique<BufferReusePass>();
}
} // namespace transforms
} // namespace kernel_gen
} // namespace mlir