blob: d6e82ee27f81138fd1c27f7f8b41e8d51f086db6 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains the analysis and transformation to rewrite kernel
// functions such that they use a single set of arguments for the strides and
// sizes of operands with equal shapes.
#include <memory>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/IR/AsmState.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#define DEBUG_TYPE "kernel-gen-shapes"
namespace {
using mlir::ArrayRef;
using mlir::SmallVector;
using mlir::Value;
/// Represents either an SSA value or an integer constant. Used to unify
/// operands for operations that take both ssa values and attributes.
///
/// Which alternative is held is tracked by `is_constant`; reading the other
/// alternative is a programming error that is caught by an assert.
struct ValueOrConst {
  /// Wraps the SSA value `v`.
  explicit ValueOrConst(Value v) : value_or_constant(v), is_constant(false) {}
  /// Wraps the integer constant `c`.
  explicit ValueOrConst(int64_t c) : value_or_constant(c), is_constant(true) {}

  /// Returns the wrapped SSA value. Requires `!isConstant()`.
  Value value() const {
    assert(!is_constant);
    return value_or_constant.value;
  }

  /// Returns the wrapped constant. Requires `isConstant()`.
  int64_t constant() const {
    assert(is_constant);
    return value_or_constant.constant;
  }

  /// Returns true if a constant (rather than an SSA value) is held.
  bool isConstant() const { return is_constant; }

 private:
  /// Untagged storage for the two alternatives; `is_constant` is the tag.
  union ValueOrConstStorage {
    explicit ValueOrConstStorage(Value v) : value(v) {}
    // The parameter is int64_t to match both the member it initializes and
    // the public constructor. Taking it as size_t would round-trip the
    // constant through an unsigned type that may be narrower than 64 bits.
    explicit ValueOrConstStorage(int64_t c) : constant(c) {}
    Value value;
    int64_t constant;
  } value_or_constant;
  bool is_constant;
};
/// Hashes `value`: an SSA value is hashed via `mlir::hash_value`, a constant
/// is converted directly into a hash code.
llvm::hash_code hash_value(ValueOrConst value) {
  if (!value.isConstant()) {
    return mlir::hash_value(value.value());
  }
  return static_cast<llvm::hash_code>(value.constant());
}
/// Two ValueOrConst are equal iff they hold the same alternative and the held
/// constants (respectively SSA values) compare equal.
bool operator==(ValueOrConst lhs, ValueOrConst rhs) {
  const bool same_kind = lhs.isConstant() == rhs.isConstant();
  if (!same_kind) return false;
  return lhs.isConstant() ? lhs.constant() == rhs.constant()
                          : lhs.value() == rhs.value();
}
/// Prints `value` to `os`. A constant is printed as an integer; an SSA value
/// is printed as an operand, using an AsmState created for the func::FuncOp
/// that encloses the value's parent region.
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                                     const ValueOrConst &value) {
  if (value.isConstant()) return os << value.constant();

  Value ssa_value = value.value();
  mlir::AsmState state(
      ssa_value.getParentRegion()->getParentOfType<mlir::func::FuncOp>());
  ssa_value.printAsOperand(os, state);
  return os;
}
/// Represents a shape, as either a single SSA value that represents the entire
/// shape vector or as a vector of SSA values representing scalars.
struct ShapeValue {
explicit ShapeValue(Value vector)
: shape({ValueOrConst{vector}}), is_vector(true) {}
explicit ShapeValue(ValueOrConst vector) : shape({vector}), is_vector(true) {
assert(!vector.isConstant());
}
template <typename T>
explicit ShapeValue(T values)
: shape(values.begin(), values.end()), is_vector(false) {}
ValueOrConst vector() const {
assert(is_vector);
return shape.front();
}
ArrayRef<ValueOrConst> scalars() const {
assert(!is_vector);
return llvm::makeArrayRef(shape);
}
bool isVector() const { return is_vector; }
private:
SmallVector<ValueOrConst, 4> shape;
bool is_vector;
};
/// Hashes `shape`: the vector form hashes its single value, the scalar form
/// hashes the sequence of its scalars.
///
/// The shape is taken by const reference because ShapeValue owns a
/// SmallVector; a by-value parameter would copy every entry on each hash
/// computation.
llvm::hash_code hash_value(const ShapeValue &shape) {
  return shape.isVector() ? hash_value(shape.vector())
                          : hash_value(shape.scalars());
}
bool operator==(ShapeValue lhs, ShapeValue rhs) {
if (lhs.isVector()) {
return rhs.isVector() && lhs.vector() == rhs.vector();
} else {
return !rhs.isVector() && lhs.scalars() == rhs.scalars();
}
}
/// Prints `shape` to `os`. The vector form prints its single value; the
/// scalar form prints its scalars as a bracketed, comma-separated list.
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                                     const ShapeValue &shape) {
  if (shape.isVector()) return os << shape.vector();

  os << "[";
  // Empty before the first element, ", " before every subsequent one.
  const char *separator = "";
  for (const ValueOrConst &scalar : shape.scalars()) {
    os << separator << scalar;
    separator = ", ";
  }
  return os << "]";
}
} // namespace
namespace llvm {
/// Allows ShapeValue to be used as a DenseMap key. The empty and tombstone
/// keys are vector-form shapes wrapping the corresponding sentinel keys of
/// DenseMapInfo<mlir::Value>.
template <>
struct DenseMapInfo<ShapeValue> {
  static ShapeValue getEmptyKey() {
    return ShapeValue(DenseMapInfo<mlir::Value>::getEmptyKey());
  }
  static ShapeValue getTombstoneKey() {
    return ShapeValue(DenseMapInfo<mlir::Value>::getTombstoneKey());
  }
  // Keys are taken by const reference: ShapeValue owns a SmallVector, so
  // by-value parameters would copy it on every hash and comparison.
  static unsigned getHashValue(const ShapeValue &shape) {
    return hash_value(shape);
  }
  static bool isEqual(const ShapeValue &LHS, const ShapeValue &RHS) {
    return LHS == RHS;
  }
};
} // namespace llvm
namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
// A basic shape equality inference. This should be superseded by a proper
// inference once available. Until then, we just build this out to the needs of
// the kernel generator project.
class ShapeEqualityKnowledge {
 public:
  /// Checks all operations for potential shape equality of their respective
  /// results.
  ///
  /// Visits every operation nested in `function`. For the handled ops a
  /// symbolic shape is derived from the op's operands and result type and is
  /// associated with the op's result; all other ops are ignored.
  void build(func::FuncOp function) {
    function.walk([&](Operation *op) {
      if (auto reshape = dyn_cast<memref::ReshapeOp>(op)) {
        // The shape operand is a single value describing the whole shape.
        registerAssociation(ShapeValue{reshape.shape()}, reshape.result());
        return;
      }
      if (auto cast = dyn_cast<memref::ReinterpretCastOp>(op)) {
        // Only support fully dynamic sizes for now.
        // TODO(herhut): Fix once the op has canonicalizers that break this.
        for (unsigned int p = 0, e = cast.getResultRank(); p < e; ++p) {
          if (!cast.isDynamicSize(p)) {
            return;
          }
        }
        registerAssociation(ShapeValue{cast.sizes()}, cast.result());
        return;
      }
      if (auto alloc = dyn_cast<memref::AllocOp>(op)) {
        // Construct a symbol representing the allocated shape.
        SmallVector<ValueOrConst, 4> shape;
        ShapedType type = alloc.getResult().getType().cast<ShapedType>();
        fillShapeFromAllocLike(alloc.getDynamicSizes(), type, shape);
        registerAssociation(ShapeValue{shape}, alloc.getResult());
        return;
      }
      if (auto alloc = dyn_cast<tf_framework::TFAllocOp>(op)) {
        // Construct a symbol representing the allocated shape.
        SmallVector<ValueOrConst, 4> shape;
        ShapedType type = alloc.getResult().getType().cast<ShapedType>();
        fillShapeFromAllocLike(alloc.dyn_sizes(), type, shape);
        registerAssociation(ShapeValue{shape}, alloc.getResult());
        return;
      }
    });
  }

  /// Checks whether `one` and `other` are known to have the same shape, i.e.
  /// whether they were placed into the same equivalence class by `build`.
  /// The knowledge is derived from size information only.
  bool haveSameShape(Value one, Value other) {
    return equal_shapes_.isEquivalent(one.getAsOpaquePointer(),
                                      other.getAsOpaquePointer());
  }

 private:
  /// Appends one entry per dimension of `type` to `shape`. A static extent is
  /// appended as a constant; a dynamic extent consumes the next value from
  /// `operands`. `type` must be ranked.
  static void fillShapeFromAllocLike(mlir::OperandRange operands,
                                     ShapedType type,
                                     SmallVectorImpl<ValueOrConst> &shape) {
    assert(type.hasRank());
    auto dynamic_sizes = operands.begin();
    for (auto extent : type.getShape()) {
      shape.push_back(ShapedType::isDynamic(extent)
                          ? ValueOrConst{*(dynamic_sizes++)}
                          : ValueOrConst{extent});
    }
  }

  /// Registers the value `value` to have the shape represented by `shape`. If
  /// `shape` has been registered before, place `value` into the same
  /// equivalence class. Otherwise register `value` as an equivalence class of
  /// its own.
  void registerAssociation(ShapeValue shape, Value value) {
    LLVM_DEBUG({ llvm::dbgs() << "Processing " << value << "\n"; });
    auto insert_symbolic = symbolic_shapes_.insert({shape, value});
    if (insert_symbolic.second) {
      LLVM_DEBUG({ llvm::dbgs() << "New symbolic shape " << shape << "\n"; });
      equal_shapes_.insert(value.getAsOpaquePointer());
      // We have seen this symbolic shape for the first time. Try to match it
      // with a vector or shape we already know and alias classes if possible.
      // This could be based on shape dialect if we weren't late in the
      // lowering.
      tryEvaluateShapeToRoot(shape, value);
    } else {
      // The map already holds a representative value for this shape.
      auto rep = insert_symbolic.first->second;
      LLVM_DEBUG({ llvm::dbgs() << "Aliasing with rep " << rep << "\n"; });
      equal_shapes_.unionSets(rep.getAsOpaquePointer(),
                              value.getAsOpaquePointer());
    }
  }

  /// Follows the definition chains of the ShapeValue `shape` to identify cases
  /// where `shape` is derived from some other value's shape. In such case, the
  /// equivalence classes of that other value and `value` are unioned.
  /// This is based on pattern matching and not complete.
  void tryEvaluateShapeToRoot(ShapeValue shape, Value value) {
    // Just some pattern matching for common cases here. All patterns below
    // revolve around scalars, so there is nothing to do for the vector form.
    if (shape.isVector()) return;

    // Check whether the scalars are all dim operations for some other memref,
    // i.e. whether scalar #i is `dim(candidate, i)` for one common candidate.
    ArrayRef<ValueOrConst> scalars = shape.scalars();
    Value candidate;
    bool all_are_dimops =
        llvm::all_of(llvm::enumerate(scalars), [&candidate](auto p) {
          ValueOrConst val = p.value();
          if (val.isConstant()) return false;
          auto dimOp = val.value().getDefiningOp<memref::DimOp>();
          if (!dimOp) return false;
          // The first dim op seen determines the candidate.
          if (!candidate) candidate = dimOp.source();
          auto index = dimOp.getConstantIndex();
          if (!index.hasValue()) return false;
          return candidate == dimOp.source() &&
                 static_cast<int64_t>(p.index()) == index.getValue();
        });
    // `candidate` stays null if there are no scalars at all.
    if (!all_are_dimops || !candidate) return;

    // The scalars enumerate dimensions 0..n-1 of `candidate`. That only
    // describes the candidate's complete shape if the candidate has exactly n
    // dimensions. Without this check, a value built from just the leading
    // dimensions of a higher-ranked candidate would be treated as having the
    // same shape. If the rank is not statically known, stay conservative.
    auto candidate_type = candidate.getType().dyn_cast<ShapedType>();
    if (!candidate_type || !candidate_type.hasRank() ||
        candidate_type.getRank() != static_cast<int64_t>(scalars.size())) {
      return;
    }

    equal_shapes_.unionSets(candidate.getAsOpaquePointer(),
                            value.getAsOpaquePointer());
  }

  // These are values with identical shapes (or rather their opaque pointers).
  llvm::EquivalenceClasses<void *> equal_shapes_;
  // A map from a value that encodes a shape to a value that has this shape.
  llvm::DenseMap<ShapeValue, Value> symbolic_shapes_;
};
/// For arguments to kernels that have the same shape, use the stride and
/// shape information of the left-most argument inside of the kernel function.
/// That way, llvm can CSE index computations on same-shaped inputs.
struct PropagateShapeKnowledgeToKernels
    : public PropagateShapeKnowledgeToKernelsBase<
          PropagateShapeKnowledgeToKernels> {
  /// Builds shape equality knowledge for the function this pass runs on. For
  /// every gpu::LaunchFuncOp whose kernel resolves to a non-external
  /// LLVM::LLVMFuncOp, the kernel's size (and, for identity layouts, stride)
  /// arguments of a memref operand are replaced with those of an earlier
  /// operand known to have the same shape.
  void runOnOperation() override {
    ShapeEqualityKnowledge knowledge;
    knowledge.build(getOperation());
    getOperation().walk([&](gpu::LaunchFuncOp launch) {
      auto module = launch->getParentOfType<ModuleOp>();
      auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());
      if (!kernel || kernel.isExternal()) return;
      // Memref operands seen so far, each paired with the position of its
      // first kernel argument.
      llvm::SmallVector<std::pair<Value, int>, 4> seen_memrefs;
      // Position of the kernel argument we are currently at.
      int kernel_p = 0;
      for (auto operand : launch.operands()) {
        auto memref = operand.getType().dyn_cast<MemRefType>();
        if (!memref) {
          // Scalar argument, advance kernel position by one.
          kernel_p++;
          continue;
        }
        for (auto previous : seen_memrefs) {
          if (!knowledge.haveSameShape(operand, previous.first)) {
            continue;
          }
          auto previous_type = previous.first.getType().cast<MemRefType>();
          // The replacement below pairs up the kernel arguments of both
          // memrefs position by position. That pairing is only meaningful if
          // both memrefs expand to the same number of size/stride arguments;
          // with differing ranks a stride use could get rewired to a size
          // argument. Skip such a pair and keep looking.
          if (previous_type.getRank() != memref.getRank()) {
            continue;
          }
          // We use the first equality found and replace uses of corresponding
          // size and (potentially) stride information here.
          auto args_to_replace = memref.getRank();
          // If both memrefs have identity layouts, we can also reuse the
          // strides here, as they are the identity strides and hence fully
          // determined by the shape.
          if (previous_type.getLayout().isIdentity() &&
              memref.getLayout().isIdentity()) {
            args_to_replace *= 2;
          }
          int previous_args_pos = previous.second;
          // Skip the three leading per-memref arguments (base, aligned,
          // offset) to get to the first argument to replace.
          auto previous_args = kernel.getArguments()
                                   .drop_front(previous_args_pos + 3)
                                   .take_front(args_to_replace);
          auto current_args = kernel.getArguments()
                                  .drop_front(kernel_p + 3)
                                  .take_front(args_to_replace);
          for (auto pair : llvm::zip(previous_args, current_args)) {
            mlir::BlockArgument prev, curr;
            std::tie(prev, curr) = pair;
            curr.replaceAllUsesWith(prev);
          }
          break;
        }
        seen_memrefs.push_back({operand, kernel_p});
        // Advance base, aligned, offset, strides and sizes many arguments.
        kernel_p += memref.getRank() * 2 + 3;
      }
    });
  }
};
} // namespace
/// Creates a new instance of the PropagateShapeKnowledgeToKernels pass.
std::unique_ptr<OperationPass<func::FuncOp>>
CreatePropagateShapeKnowledgeToKernels() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<PropagateShapeKnowledgeToKernels>();
  return pass;
}
} // namespace transforms
} // namespace kernel_gen
} // namespace mlir