// blob: 3187f4f80ef75db61b1797df1e39d3d7e5f7eb97 [file] [log] [blame]
//===- LinalgOps.h - Linalg Operations --------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_LINALG_LINALGOPS_H_
#define MLIR_LINALG_LINALGOPS_H_
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Linalg/IR/LinalgTraits.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
class OperationFolder;
namespace linalg {
/// linalg.LoadOp is the counterpart of load; it operates on a ViewType where
/// load operates on a MemRefType.
///
/// ```{.mlir}
///    %0 = linalg.load %V[%c0] : !linalg.view<?xf32>
/// ```
class LoadOp
    : public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
  // Operand layout: the view is operand 0, every following operand is an
  // index.
  enum { ViewOperand = 0, FirstIndexOperand = 1 };

public:
  using Op::Op;

  // Hooks to customize the behavior of this op.
  static llvm::StringRef getOperationName() { return "linalg.load"; }
  static void build(Builder *b, OperationState *result, Value *view,
                    ArrayRef<Value *> indices = {});
  LogicalResult verify();
  static ParseResult parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);

  // Op-specific functionality.
  // The view being read from.
  Value *getView() { return getOperand(ViewOperand); }
  // The type of the view operand; the cast assumes it is a ViewType.
  ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
  // Rank of the view operand's type.
  unsigned getRank() { return getViewType().getRank(); }
  // All operands that follow the view.
  Operation::operand_range getIndices() {
    auto firstIndex = operand_begin() + FirstIndexOperand;
    return {firstIndex, operand_end()};
  }
};
/// The "linalg.range" op creates a linalg.range from 3 values of type `index`
/// that represent, in order, the min, max and step values of the range.
///
/// ```{.mlir}
///    %3 = linalg.range %0:%1:%2 : !linalg.range
/// ```
class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
                          OpTrait::OneResult, OpTrait::HasNoSideEffect> {
  // Positions of the three operands.
  enum { MinOperand = 0, MaxOperand = 1, StepOperand = 2 };

public:
  using Op::Op;

  // Hooks to customize the behavior of this op.
  static llvm::StringRef getOperationName() { return "linalg.range"; }
  static void build(Builder *b, OperationState *result, Value *min, Value *max,
                    Value *step);
  LogicalResult verify();
  static ParseResult parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);

  // Op-specific functionality: one accessor per operand.
  Value *min() { return getOperand(MinOperand); }
  Value *max() { return getOperand(MaxOperand); }
  Value *step() { return getOperand(StepOperand); }
};
// Forward declaration; SliceOp::getBaseViewOp() returns a ViewOp.
class ViewOp;

/// The "linalg.slice" op produces a linalg.view which is a subview of a given
/// base view. This allows defining a subregion within the underlying buffer to
/// operate on only a subset of the buffer.
///
/// A "linalg.slice" op takes a base view and a variadic number of indexings and
/// produces a linalg.view of the same elemental type as the buffer. An indexing
/// is either:
///   1. a linalg.range, in which case it does not reduce the rank of the parent
///      view.
///   2. an index, in which case it reduces the rank of the parent view by one.
///
/// The parent view must be a base view (i.e. either a function argument or has
/// been produced by a linalg.view op). In other words, chains of
/// linalg.slice operations cannot be constructed in the IR. This defines away
/// problems related to keeping track of which dimensions of the base view have
/// been rank-reduced.
///
/// Examples:
///   1. rank-preserving slice:
///
/// ```{.mlir}
///    %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, !linalg.range,
///    !linalg.range, !linalg.view<?x?xf32>
/// ```
///
///   2. rank-reducing slice (from 2-D to 1-D):
///
/// ```{.mlir}
///    %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index,
///    !linalg.range, !linalg.view<?xf32>
/// ```
///
///   3. rank-reducing slice (from 2-D to 0-D):
///
/// ```{.mlir}
///    %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index, index,
///    !linalg.view<f32>
/// ```
class SliceOp : public Op<SliceOp, OpTrait::VariadicOperands,
                          OpTrait::OneResult, OpTrait::HasNoSideEffect> {
  // Operand layout: the base view is operand 0, the indexings follow.
  enum { BaseViewOperand = 0, FirstIndexingOperand = 1 };

public:
  using Op::Op;

  // Hooks to customize the behavior of this op.
  static llvm::StringRef getOperationName() { return "linalg.slice"; }
  static void build(Builder *b, OperationState *result, Value *base,
                    llvm::ArrayRef<Value *> indexings);
  LogicalResult verify();
  static ParseResult parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);

  // Op-specific functionality.
  // Accessors for the view produced by this op (its result type).
  ViewType getViewType() { return getType().cast<ViewType>(); }
  unsigned getRank() { return getViewType().getRank(); }
  Type getElementType() { return getViewType().getElementType(); }

  // Accessors for the base view this op slices.
  Value *getBaseView() { return getOperand(BaseViewOperand); }
  ViewOp getBaseViewOp();
  ViewType getBaseViewType();
  unsigned getBaseViewRank() { return getBaseViewType().getRank(); }

  // Get all the indexings in this view, i.e. every operand after the base
  // view.
  Operation::operand_range getIndexings() {
    auto firstIndexing = operand_begin() + FirstIndexingOperand;
    return {firstIndexing, operand_end()};
  }
  // Get the underlying indexing at a given rank, counted from the first
  // indexing operand.
  Value *getIndexing(unsigned rank) {
    Operation::operand_range indexings = getIndexings();
    return *(indexings.begin() + rank);
  }
  // Get the subset of indexings that are of RangeType.
  SmallVector<Value *, 8> getRanges();
};
/// linalg.StoreOp is the counterpart of affine.store; it operates on a
/// ViewType where affine.store operates on a MemRefType.
///
/// ```{.mlir}
///    linalg.store %f, %V[%c0] : !linalg.view<?xf32>
/// ```
class StoreOp
    : public Op<StoreOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
  // Operand layout: the stored value, then the view, then the indices.
  enum { ValueOperand = 0, ViewOperand = 1, FirstIndexOperand = 2 };

public:
  using Op::Op;

  // Hooks to customize the behavior of this op.
  static llvm::StringRef getOperationName() { return "linalg.store"; }
  static void build(Builder *b, OperationState *result, Value *valueToStore,
                    Value *view, ArrayRef<Value *> indices = {});
  LogicalResult verify();
  static ParseResult parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);

  // Op-specific functionality.
  // The value written by this op.
  Value *getValueToStore() { return getOperand(ValueOperand); }
  // The view being written to.
  Value *getView() { return getOperand(ViewOperand); }
  // The type of the view operand; the cast assumes it is a ViewType.
  ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
  // Rank of the view operand's type.
  unsigned getRank() { return getViewType().getRank(); }
  // All operands that follow the view.
  Operation::operand_range getIndices() {
    auto firstIndex = operand_begin() + FirstIndexOperand;
    return {firstIndex, operand_end()};
  }
};
/// Returns the name-mangled library call name, used to disambiguate between
/// different overloads at the C level. The name mangling scheme is basic and
/// uses MLIR type names:
///   1. form a string which is the concatenation of the linalg op name with all
///      the operand type names, separated by underscores;
///   2. drop the `linalg.` prefix, and the `<`, `>`, `?` symbols from the type.
/// Assumes `op` is a LinalgOp.
///
/// NOTE(review): every example below also ends in an `_impl` suffix, which the
/// two steps above do not mention -- confirm against the implementation.
///
/// Examples:
///
/// 1. linalg.fill(%A, %f) : !linalg.view<f32>, f32
///   name mangles into `linalg_fill_viewf32_f32_impl`
///
/// 2. linalg.dot(%A, %B, %C) :
///      !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
///   name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl`
///
/// 3. linalg.matmul(...) :
///      !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
///   name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl`
std::string generateLibraryCallName(Operation *op);
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.h.inc"
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc"
/// Stream-insertion operator for SubViewOp::Range; returns a raw_ostream
/// reference (presumably `os`, to allow chaining -- the definition is not in
/// this header). SubViewOp itself is not declared in the visible part of this
/// file; presumably it comes from the generated op classes included above.
/// NOTE(review): `range` is taken by non-const reference, so const ranges and
/// temporaries cannot be streamed -- confirm whether that is intentional.
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, SubViewOp::Range &range);
/// Returns the list of maps that map loops to operands of a Linalg op.
/// The i-th affine map maps loop indices to the subscripts that are used when
/// accessing the i-th operand.
/// For instance, a matmul that can be written in index notation as:
/// `A(i, k) * B(k, j) -> C(i, j)` will have the following, ordered, list of
/// affine maps:
///
/// ```{.mlir}
///    (
///      (i, j, k) -> (i, k),
///      (i, j, k) -> (k, j),
///      (i, j, k) -> (i, j)
///    )
/// ```
///
/// Only permutation maps are currently supported.
SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
/// A LinalgOp behaves like a base class for the Linalg operations that are
/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
/// method and dispatches to the appropriate LinalgLibraryOp.
/// This allows writing generic passes, like tiling, for all current and future
/// LinalgOps without requiring templating and dispatch in multiple places.
///
/// Every query below forwards to `impl`, the type-erased Model selected for
/// the wrapped operation. `impl` is only set by the `LinalgOp(Operation *)`
/// constructor; instances created through any constructor inherited from Op
/// keep it null, so the queries must not be called on them.
class LinalgOp : public Op<LinalgOp> {
public:
  using Op::Op;

  /// Wraps `op` and resolves, once, the Model matching its concrete op class.
  /// The candidate classes are the op list generated into
  /// LinalgLibraryOps.cpp.inc; if `op` is none of them, dispatch ends in
  /// llvm_unreachable.
  LinalgOp(Operation *op) : Op<LinalgOp>(op) {
    impl = ModelDispatch<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
        >::dispatch(op);
  }

  /// Returns true iff `op` is an instance of one of the generated library op
  /// classes. This hides the `classof` of the Op base.
  static bool classof(Operation *op) {
    return ModelDispatch<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
        >::classof(op);
  }

  // Loop-count queries, forwarded to the concrete op.
  unsigned getNumParallelLoops() {
    return impl->getNumParallelLoops(getOperation());
  }
  unsigned getNumReductionLoops() {
    return impl->getNumReductionLoops(getOperation());
  }
  unsigned getNumWindowLoops() {
    return impl->getNumWindowLoops(getOperation());
  }
  /// Total number of loops: parallel + reduction + window.
  unsigned getNumLoops() {
    return getNumParallelLoops() + getNumReductionLoops() + getNumWindowLoops();
  }

  // Input/output count queries, forwarded to the concrete op.
  unsigned getNumInputs() { return impl->getNumInputs(getOperation()); }
  unsigned getNumOutputs() { return impl->getNumOutputs(getOperation()); }
  unsigned getNumInputsAndOutputs() {
    return impl->getNumInputsAndOutputs(getOperation());
  }

  // Input accessors, forwarded to the concrete op.
  Value *getInput(unsigned i) { return impl->getInput(getOperation(), i); }
  llvm::Optional<unsigned> getIndexOfInput(Value *view) {
    return impl->getIndexOfInput(getOperation(), view);
  }
  ViewType getInputViewType(unsigned i) {
    return impl->getInputViewType(getOperation(), i);
  }
  Operation::operand_range getInputs() {
    return impl->getInputs(getOperation());
  }

  // Output accessors, forwarded to the concrete op.
  Value *getOutput(unsigned i) { return impl->getOutput(getOperation(), i); }
  llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
    return impl->getIndexOfOutput(getOperation(), view);
  }
  ViewType getOutputViewType(unsigned i) {
    return impl->getOutputViewType(getOperation(), i);
  }
  Operation::operand_range getOutputs() {
    return impl->getOutputs(getOperation());
  }
  Operation::operand_range getInputsAndOutputs() {
    return impl->getInputsAndOutputs(getOperation());
  }

  /// Builds a new op of the same concrete class as this one at `loc`, with
  /// the given `operands` and `attributes`, and returns it wrapped as a
  /// LinalgOp.
  LinalgOp create(OpBuilder &builder, Location loc, ArrayRef<Value *> operands,
                  ArrayRef<NamedAttribute> attributes) {
    return LinalgOp(impl->create(builder, loc, operands, attributes));
  }

private:
  /// Type-erased interface: one pure virtual method per query exposed above,
  /// each taking the wrapped Operation explicitly.
  struct Concept {
    virtual ~Concept() = default;
    virtual unsigned getNumInputs(Operation *op) = 0;
    virtual unsigned getNumOutputs(Operation *op) = 0;
    virtual unsigned getNumInputsAndOutputs(Operation *op) = 0;
    virtual unsigned getNumParallelLoops(Operation *op) = 0;
    virtual unsigned getNumReductionLoops(Operation *op) = 0;
    virtual unsigned getNumWindowLoops(Operation *op) = 0;
    virtual Value *getInput(Operation *op, unsigned i) = 0;
    virtual llvm::Optional<unsigned> getIndexOfInput(Operation *op,
                                                     Value *view) = 0;
    virtual ViewType getInputViewType(Operation *op, unsigned i) = 0;
    virtual Operation::operand_range getInputs(Operation *op) = 0;
    virtual Value *getOutput(Operation *op, unsigned i) = 0;
    virtual llvm::Optional<unsigned> getIndexOfOutput(Operation *op,
                                                      Value *view) = 0;
    virtual ViewType getOutputViewType(Operation *op, unsigned i) = 0;
    virtual Operation::operand_range getOutputs(Operation *op) = 0;
    virtual Operation::operand_range getInputsAndOutputs(Operation *op) = 0;
    virtual Operation *create(OpBuilder &builder, Location loc,
                              ArrayRef<Value *> operands,
                              ArrayRef<NamedAttribute> attributes) = 0;
  };

  /// The implementation is inspired from Sean Parent's concept-based
  /// polymorphism. A key difference is that the set of classes erased is
  /// statically known, which alleviates the need for using dynamic memory
  /// allocation.
  /// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
  /// virtual table and generate a singleton object for each instantiation of
  /// this class.
  /// We pay the cost of initialization once on construction (find which class
  /// to dispatch to) and then a virtual dispatch on every call.
  ///
  /// Each override casts `op` to ConcreteOp and calls the method of the same
  /// name on it.
  template <typename ConcreteOp> struct Model : public Concept {
    /// Returns the per-ConcreteOp singleton (a function-local static).
    static Model<ConcreteOp> &instance() {
      static Model<ConcreteOp> singleton;
      return singleton;
    }
    unsigned getNumInputs(Operation *op) override {
      return cast<ConcreteOp>(op).getNumInputs();
    }
    unsigned getNumOutputs(Operation *op) override {
      return cast<ConcreteOp>(op).getNumOutputs();
    }
    unsigned getNumInputsAndOutputs(Operation *op) override {
      return cast<ConcreteOp>(op).getNumInputsAndOutputs();
    }
    unsigned getNumParallelLoops(Operation *op) override {
      return cast<ConcreteOp>(op).getNumParallelLoops();
    }
    unsigned getNumReductionLoops(Operation *op) override {
      return cast<ConcreteOp>(op).getNumReductionLoops();
    }
    unsigned getNumWindowLoops(Operation *op) override {
      return cast<ConcreteOp>(op).getNumWindowLoops();
    }
    Value *getInput(Operation *op, unsigned i) override {
      return cast<ConcreteOp>(op).getInput(i);
    }
    llvm::Optional<unsigned> getIndexOfInput(Operation *op,
                                             Value *view) override {
      return cast<ConcreteOp>(op).getIndexOfInput(view);
    }
    ViewType getInputViewType(Operation *op, unsigned i) override {
      return cast<ConcreteOp>(op).getInputViewType(i);
    }
    Operation::operand_range getInputs(Operation *op) override {
      return cast<ConcreteOp>(op).getInputs();
    }
    Value *getOutput(Operation *op, unsigned i) override {
      return cast<ConcreteOp>(op).getOutput(i);
    }
    llvm::Optional<unsigned> getIndexOfOutput(Operation *op,
                                              Value *view) override {
      return cast<ConcreteOp>(op).getIndexOfOutput(view);
    }
    ViewType getOutputViewType(Operation *op, unsigned i) override {
      return cast<ConcreteOp>(op).getOutputViewType(i);
    }
    Operation::operand_range getOutputs(Operation *op) override {
      return cast<ConcreteOp>(op).getOutputs();
    }
    Operation::operand_range getInputsAndOutputs(Operation *op) override {
      return cast<ConcreteOp>(op).getInputsAndOutputs();
    }
    // The new op is created with an empty result type list.
    Operation *create(OpBuilder &builder, Location loc,
                      ArrayRef<Value *> operands,
                      ArrayRef<NamedAttribute> attributes) override {
      return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
                                        attributes);
    }
  };

  /// Model of the wrapped op's concrete class. Defaults to null so that
  /// instances built through a constructor inherited from Op, which never run
  /// the dispatch in `LinalgOp(Operation *)`, hold a well-defined value rather
  /// than an indeterminate pointer.
  Concept *impl = nullptr;

  /// Compile-time list walker over the candidate op classes.
  template <typename... Types> struct ModelDispatch;

  /// Recursive case: test `First`, otherwise defer to the remaining types.
  template <typename First, typename... Rest>
  struct ModelDispatch<First, Rest...> {
    static bool classof(Operation *op) {
      return isa<First>(op) || ModelDispatch<Rest...>::classof(op);
    }
    static Concept *dispatch(Operation *op) {
      return isa<First>(op) ? &Model<First>::instance()
                            : ModelDispatch<Rest...>::dispatch(op);
    }
  };

  /// Terminal case (empty list): nothing matched, so `op` is not a LinalgOp.
  template <typename...> struct ModelDispatch {
    static bool classof(Operation *op) { return false; }
    static Concept *dispatch(Operation *op) {
      llvm_unreachable("Invalid LinalgOp");
    }
  };
};
} // namespace linalg
} // namespace mlir
#endif // MLIR_LINALG_LINALGOPS_H_