/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <vector>
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Analysis/BufferAliasAnalysis.h" // from @llvm-project
#include "mlir/Analysis/Liveness.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
// Needed to build `llvm::EquivalenceClasses` of `mlir::Value`s.
namespace mlir {
static bool operator<(const Value &lhs, const Value &rhs) {
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
}
} // namespace mlir
constexpr llvm::StringRef
mlir::kernel_gen::tf_framework::TFAllocOp::kReuseOutputAttrName;
constexpr llvm::StringRef
mlir::kernel_gen::tf_framework::TFAllocOp::kReuseInputCandidatesAttrName;
constexpr llvm::StringRef
mlir::kernel_gen::tf_framework::TFFrameworkDialect::kTFEntryAttrName;
namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
/// A temporary buffer size analysis that is correct but may be incomplete: two
/// buffers are related only if they can be proven to be of equal size.
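/// For example, `memref<4x2xf32>` and `memref<8xf32>` with trivial layout hold
/// the same number of `f32` elements and are placed in the same size
/// equivalence class, whereas dynamically shaped buffers are only related
/// through aliasing or through matching `linalg.generic` operands.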
class BufferSizeAnalysis {
public:
BufferSizeAnalysis(FuncOp f, const BufferAliasAnalysis &aliases) {
build(f, aliases);
}
bool is_same_size(Value a, Value b) { return ecs_.isEquivalent(a, b); }
private:
void build(FuncOp &f, const BufferAliasAnalysis &aliases) {
auto buffers = find_buffer_values(f);
// Memrefs with a statically known, equal number of elements, the same element
// type, and equal symbol-free affine layout maps must be of the same size.
int n = buffers.size();
for (int i = 0; i < n; ++i) {
for (int j = i + 1; j < n; ++j) {
Value a = buffers[i];
Value b = buffers[j];
auto a_ty = a.getType().dyn_cast<MemRefType>();
auto b_ty = b.getType().dyn_cast<MemRefType>();
if (a_ty && b_ty && a_ty.hasStaticShape() && b_ty.hasStaticShape() &&
a_ty.getNumElements() == b_ty.getNumElements() &&
a_ty.getElementType() == b_ty.getElementType() &&
affine_maps_symbol_free_and_equal(a_ty.getAffineMaps(),
b_ty.getAffineMaps())) {
ecs_.unionSets(a, b);
}
}
}
// Operands to `linalg.generic` that have the same element type and layout and
// are accessed through the same permutation indexing map must be of the same
// size.
f.walk([&](linalg::GenericOp genericOp) {
auto operand_buffers = genericOp.getInputsAndOutputBuffers();
int n = operand_buffers.size();
for (int i = 0; i < n; ++i) {
for (int j = i + 1; j < n; ++j) {
Value a = operand_buffers[i];
Value b = operand_buffers[j];
auto a_ty = a.getType().dyn_cast<MemRefType>();
auto b_ty = b.getType().dyn_cast<MemRefType>();
if (a_ty && b_ty && a_ty.getElementType() == b_ty.getElementType() &&
a_ty.getAffineMaps() == b_ty.getAffineMaps()) {
AffineMap map_i = genericOp.getIndexingMap(i);
AffineMap map_j = genericOp.getIndexingMap(j);
if (map_i == map_j && map_i.isPermutation()) ecs_.unionSets(a, b);
}
}
}
});
// All aliases of a memref share the same underlying buffer and, hence, the
// same buffer size.
for (auto e : aliases) {
Value value = e.getFirst();
if (!value.getType().isa<BaseMemRefType>()) continue;
for (Value alias : e.getSecond()) {
assert(alias.getType().isa<BaseMemRefType>() &&
"Expected aliases of memref to be memrefs.");
ecs_.unionSets(value, alias);
}
}
}
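// Symbol-free affine maps together with a static shape fully determine the
// buffer size; maps with symbols may depend on runtime values, so such
// buffers are conservatively left unrelated.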
bool affine_maps_symbol_free_and_equal(ArrayRef<AffineMap> as,
ArrayRef<AffineMap> bs) {
auto is_symbol_free = [](AffineMap map) {
return map.getNumSymbols() == 0;
};
return llvm::all_of(as, is_symbol_free) &&
llvm::all_of(bs, is_symbol_free) && as == bs;
}
llvm::SmallVector<Value, 8> find_buffer_values(FuncOp f) {
llvm::SmallVector<Value, 8> buffers;
f.walk([&](Operation *op) {
for (Value val : op->getResults())
if (val.getType().isa<BaseMemRefType>()) buffers.push_back(val);
});
f.walk([&](Block *block) {
for (Value val : block->getArguments()) {
if (val.getType().isa<BaseMemRefType>()) buffers.push_back(val);
}
});
return buffers;
}
llvm::EquivalenceClasses<Value> ecs_;
};
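/// For every allocation in a function, this analysis computes (i) which
/// function arguments hold buffers that the allocation may reuse (same size,
/// compatible lifetime) and (ii) the result position at which the allocated
/// buffer is returned, or `kIndexAmbiguous` if the buffer is not returned or
/// is returned at more than one position.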
class BufferReuseAnalysis {
public:
explicit BufferReuseAnalysis(FuncOp f) { build(f); }
static constexpr int32_t kIndexAmbiguous = -1;
Optional<SmallVector<int32_t, 2>> get_reuse_candidates(AllocOp op) {
auto it = reuse_candidates_.find(op);
if (it == reuse_candidates_.end()) return llvm::None;
return it->second;
}
Optional<int32_t> get_output_index(AllocOp op) {
auto it = output_indices_.find(op);
if (it == output_indices_.end()) return llvm::None;
return it->second;
}
private:
void build(FuncOp &f) {
BufferAliasAnalysis aliases(f);
find_output_indices(f, aliases);
find_reuse_candidates(f, aliases);
}
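// For every allocation, finds the operand position at which the allocated
// buffer (or one of its aliases) is passed to the function's return. If the
// buffer is not returned, or is returned at more than one position, the
// allocation maps to `kIndexAmbiguous`.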
void find_output_indices(FuncOp &f, BufferAliasAnalysis &aliases) {
f.walk([&](AllocOp alloc_op) {
int32_t output_index = kIndexAmbiguous;
int count_return_uses = 0;
auto buffer_aliases = aliases.resolve(alloc_op.getResult());
for (Value alias : buffer_aliases) {
for (auto &use : alias.getUses()) {
if (isa<ReturnOp>(use.getOwner())) {
int32_t index = use.getOperandNumber();
if (count_return_uses++ == 0)
output_index = index;
else if (output_index != index)
output_index = kIndexAmbiguous;
}
}
}
output_indices_[alloc_op] = output_index;
});
}
void find_reuse_candidates(FuncOp &f, BufferAliasAnalysis &aliases) {
Liveness liveness(f);
BufferSizeAnalysis size_equivalences(f, aliases);
f.walk([&](Block *block) {
find_reuse_candidates(block, aliases, liveness.getLiveness(block),
size_equivalences, f.getArguments());
});
}
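// Finds, for every allocation in `block`, the function arguments whose
// buffers satisfy both the size and the lifetime criterion below and records
// them as reuse candidates.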
void find_reuse_candidates(Block *block, BufferAliasAnalysis &aliases,
const LivenessBlockInfo *liveness,
BufferSizeAnalysis &size_equivalences,
ArrayRef<BlockArgument> arguments) {
for (Operation &op : *block) {
auto alloc_op = dyn_cast<AllocOp>(op);
if (!alloc_op) continue;
// Find first use of the newly allocated buffer within this block.
Value new_buffer = alloc_op.getResult();
Operation *first_reuse = find_first_use_in_block(new_buffer, block);
assert((first_reuse == nullptr || first_reuse->getBlock() == block) &&
"Expected first use in same block if found.");
// Find reuse candidates for this allocation.
SmallVector<int32_t, 2> local_reuse_candidates;
for (BlockArgument old_buffer : arguments) {
if (!old_buffer.getType().isa<BaseMemRefType>()) continue;
// Size criterion: Do not reuse buffers of different size as they may be
// too small.
if (!size_equivalences.is_same_size(new_buffer, old_buffer)) continue;
// Lifetime criterion: Only reuse buffers that are no longer alive at the
// point of first reuse.
bool lifetimes_compatible = true;
for (Value old_buffer_alias : aliases.resolve(old_buffer)) {
if (first_reuse == nullptr) {
// If the new buffer has no use within this block, check reusability at the
// block end: an argument buffer that is already reusable there is certainly
// reusable at any later actual use. Otherwise, lifetimes are incompatible.
if (liveness->isLiveOut(old_buffer_alias)) {
lifetimes_compatible = false;
break;
}
} else {
// A buffer is reusable if
// i) its last use is before the point of reuse, or
// ii) its last use is also its first reuse and the operation
// allows for local reuse.
// Otherwise, lifetimes are incompatible.
Operation *last_use =
liveness->getEndOperation(old_buffer_alias, &block->front());
assert(last_use != nullptr && last_use->getBlock() == block &&
"Expected last use in same block.");
if (first_reuse->isBeforeInBlock(last_use)) {
lifetimes_compatible = false;
break;
}
if (first_reuse == last_use &&
!can_reuse_locally(first_reuse, old_buffer_alias, new_buffer)) {
lifetimes_compatible = false;
break;
}
}
}
if (lifetimes_compatible) {
// All criteria are fulfilled 🙂.
int32_t old_buffer_index = old_buffer.getArgNumber();
local_reuse_candidates.push_back(old_buffer_index);
}
}
reuse_candidates_[&op] = local_reuse_candidates;
}
}
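// Returns the operation in `block` that contains the first use of `value`, or
// nullptr if `value` has no use in this block.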
Operation *find_first_use_in_block(Value value, Block *block) {
Operation *first_use = nullptr;
for (Operation *op : value.getUsers()) {
Operation *ancestor_op = block->findAncestorOpInBlock(*op);
if (ancestor_op == nullptr) continue;
if (first_use == nullptr || ancestor_op->isBeforeInBlock(first_use))
first_use = ancestor_op;
}
return first_use;
}
std::vector<Value> get_buffer_arguments(FuncOp &f) {
std::vector<Value> buffer_arguments;
for (BlockArgument arg : f.getArguments()) {
if (arg.getType().isa<BaseMemRefType>()) buffer_arguments.push_back(arg);
}
return buffer_arguments;
}
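// Returns true if `old_buffer` can be reused by `new_buffer` even though its
// last use is the reusing operation `op` itself. This holds if `op` reads
// every memory location of `old_buffer` before overwriting the corresponding
// location of `new_buffer`; currently this is only recognized for
// `linalg.generic` with identical permutation indexing maps for both buffers.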
bool can_reuse_locally(Operation *op, Value old_buffer, Value new_buffer) {
// For now, we support only memrefs with the same memory layout.
auto old_buffer_ty = old_buffer.getType().dyn_cast<MemRefType>();
auto new_buffer_ty = new_buffer.getType().dyn_cast<MemRefType>();
if (!old_buffer_ty || !new_buffer_ty ||
old_buffer_ty.getAffineMaps() != new_buffer_ty.getAffineMaps())
return false;
if (auto generic_op = dyn_cast<linalg::GenericOp>(op)) {
assert(llvm::find(op->getOperands(), old_buffer) !=
op->getOperands().end() &&
llvm::find(op->getOperands(), new_buffer) !=
op->getOperands().end() &&
"Expect `old/new_buffer` to be operand of `op`.");
// If the `linalg.generic` indexing maps for the input and the output buffer
// are the same permutation, then for every memory location the last use of
// the input buffer happens before its first reuse.
auto operand_buffers = generic_op.getInputsAndOutputBuffers();
int old_index =
llvm::find(operand_buffers, old_buffer) - operand_buffers.begin();
int new_index =
llvm::find(operand_buffers, new_buffer) - operand_buffers.begin();
AffineMap old_indexing_map = generic_op.getIndexingMap(old_index);
AffineMap new_indexing_map = generic_op.getIndexingMap(new_index);
return old_indexing_map == new_indexing_map &&
old_indexing_map.isPermutation();
}
return false;
}
DenseMap<Operation *, SmallVector<int32_t, 2>> reuse_candidates_;
DenseMap<Operation *, int32_t> output_indices_;
};
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
struct BufferReusePass : public BufferReusePassBase<BufferReusePass> {
void runOnFunction() override {
if (!getFunction().getAttrOfType<UnitAttr>(
tf_framework::TFFrameworkDialect::kTFEntryAttrName))
return;
BufferReuseAnalysis analysis(getFunction());
// Annotate IR with reuse candidates and output indices per allocation.
Builder builder(&getContext());
getFunction().walk([&](AllocOp op) {
if (auto output_index = analysis.get_output_index(op)) {
auto attr = builder.getI32IntegerAttr(*output_index);
op.getOperation()->setAttr(
tf_framework::TFAllocOp::kReuseOutputAttrName, attr);
}
if (auto reuse_candidates = analysis.get_reuse_candidates(op)) {
auto attr = builder.getI32ArrayAttr(*reuse_candidates);
op.getOperation()->setAttr(
tf_framework::TFAllocOp::kReuseInputCandidatesAttrName, attr);
}
});
}
};
} // namespace
std::unique_ptr<FunctionPass> CreateBufferReusePass() {
return std::make_unique<BufferReusePass>();
}
} // namespace transforms
} // namespace kernel_gen
} // namespace mlir