blob: fbb5cded1c92d2eac8d1cbf704596b9ee152f4b4 [file] [log] [blame]
//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines convenience types for working with standard operations
// in the MLIR operation set.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_STANDARDOPS_OPS_H
#define MLIR_STANDARDOPS_OPS_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
class AffineMap;
class Builder;
namespace detail {
/// A custom binary operation printer that omits the "std." prefix from the
/// operation names.
///
/// \param op the operation to print.
/// \param p  the assembly printer that receives the output.
void printStandardBinaryOp(Operation *op, OpAsmPrinter *p);
} // namespace detail
/// Dialect class for the standard operations declared in this header.
class StandardOpsDialect : public Dialect {
public:
  /// Constructs the dialect for the given context. The constructor is
  /// `explicit` so that an `MLIRContext *` is never converted to a dialect
  /// object implicitly.
  explicit StandardOpsDialect(MLIRContext *context);
};
#define GET_OP_CLASSES
#include "mlir/StandardOps/Ops.h.inc"
/// The "alloc" operation allocates a region of memory, as specified by its
/// memref type. For example:
///
/// %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
///
/// The optional list of dimension operands are bound to the dynamic dimensions
/// specified in its memref type. In the example below, the ssa value '%d' is
/// bound to the second dimension of the memref (which is dynamic).
///
/// %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1>
///
/// The optional list of symbol operands are bound to the symbols of the
/// memrefs affine map. In the example below, the ssa value '%s' is bound to
/// the symbol 's0' in the affine map specified in the allocs memref type.
///
/// %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
///
/// This operation returns a single ssa value of memref type, which can be used
/// by subsequent load and store operations.
class AllocOp
    : public Op<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.alloc"; }

  /// Returns the type of the single result, cast to MemRefType.
  MemRefType getType() {
    auto resultType = getResult()->getType();
    return resultType.cast<MemRefType>();
  }

  // Hooks to customize behavior of this op: construction, verification,
  // custom assembly parsing/printing and canonicalization.
  static void build(Builder *builder, OperationState *result,
                    MemRefType memrefType, ArrayRef<Value *> operands = {});
  LogicalResult verify();
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context);
};
/// The "br" operation represents a branch operation in a function.
/// The operation takes variable number of operands and produces no results.
/// The operand number and types for each successor must match the
/// arguments of the block successor. For example:
///
/// ^bb2:
/// %2 = call @someFn()
/// br ^bb3(%2 : tensor<*xf32>)
/// ^bb3(%3: tensor<*xf32>):
///
class BranchOp : public Op<BranchOp, OpTrait::VariadicOperands,
                           OpTrait::ZeroResult, OpTrait::IsTerminator> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.br"; }

  /// Build hook taking the destination block and the operands passed to it.
  static void build(Builder *builder, OperationState *result, Block *dest,
                    ArrayRef<Value *> operands = {});

  // Hooks to customize the assembly form of this op.
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);

  /// Return the block this branch jumps to.
  Block *getDest();

  /// Change the block this branch jumps to.
  void setDest(Block *block);

  /// Erase the operand at 'index' from the operand list.
  void eraseOperand(unsigned index);
};
/// The "call" operation represents a direct call to a function. The operands
/// and result types of the call must match the specified function type. The
/// callee is encoded as a function attribute named "callee".
///
/// %31 = call @my_add(%0, %1)
/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
class CallOp
    : public Op<CallOp, OpTrait::VariadicOperands, OpTrait::VariadicResults> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.call"; }

  /// Build hook taking the called function and the call arguments.
  static void build(Builder *builder, OperationState *result, Function *callee,
                    ArrayRef<Value *> operands);

  /// Returns the function held by the "callee" attribute.
  Function *getCallee() {
    auto calleeAttr = getAttrOfType<FunctionAttr>("callee");
    return calleeAttr.getValue();
  }

  // Every operand of a direct call is an argument, so the argument iterators
  // coincide with the operand iterators.
  operand_iterator arg_operand_begin() { return operand_begin(); }
  operand_iterator arg_operand_end() { return operand_end(); }

  /// Get the argument operands to the called function.
  operand_range getArgOperands() {
    operand_iterator firstArg = arg_operand_begin();
    operand_iterator lastArg = arg_operand_end();
    return {firstArg, lastArg};
  }

  // Hooks to customize behavior of this op.
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();
};
/// The "call_indirect" operation represents an indirect call to a value of
/// function type. Functions are first class types in MLIR, and may be passed
/// as arguments and merged together with block arguments. The operands
/// and result types of the call must match the specified function type.
///
/// %31 = call_indirect %15(%0, %1)
/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
///
class CallIndirectOp : public Op<CallIndirectOp, OpTrait::VariadicOperands,
                                 OpTrait::VariadicResults> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.call_indirect"; }

  /// Build hook taking the callee value and the call arguments.
  static void build(Builder *builder, OperationState *result, Value *callee,
                    ArrayRef<Value *> operands);

  /// Returns the callee, which is stored as operand #0.
  Value *getCallee() { return getOperand(0); }

  /// Iterator to the first call argument: operand #0 is the callee, so the
  /// arguments start one position past the beginning of the operand list.
  operand_iterator arg_operand_begin() {
    operand_iterator firstArg = operand_begin();
    ++firstArg;
    return firstArg;
  }

  /// Iterator one past the last call argument.
  operand_iterator arg_operand_end() { return operand_end(); }

  /// Get the argument operands to the called function.
  operand_range getArgOperands() {
    operand_iterator firstArg = arg_operand_begin();
    operand_iterator lastArg = arg_operand_end();
    return {firstArg, lastArg};
  }

  // Hooks to customize behavior of this op.
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();
  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context);
};
/// The predicate indicates the type of the comparison to perform:
/// (in)equality; (un)signed less/greater than (or equal to).
enum class CmpIPredicate {
  // Smallest valid predicate value; it aliases the first real predicate.
  FirstValidValue = 0,

  // Equality and inequality.
  EQ = FirstValidValue,
  NE,

  // Orderings that treat the operands as signed integers.
  SLT,
  SLE,
  SGT,
  SGE,

  // Orderings that treat the operands as unsigned integers.
  ULT,
  ULE,
  UGT,
  UGE,

  // Sentinel: one past the last predicate, i.e. the number of predicates.
  NumPredicates
};
/// The "cmpi" operation compares its two operands according to the integer
/// comparison rules and the predicate specified by the respective attribute.
/// The predicate defines the type of comparison: (in)equality, (un)signed
/// less/greater than (or equal to). The operands must have the same type, and
/// this type must be an integer type, a vector or a tensor thereof. The result
/// is an i1, or a vector/tensor thereof having the same shape as the inputs.
/// Since integers are signless, the predicate also explicitly indicates
/// whether to interpret the operands as signed or unsigned integers for
/// less/greater than comparisons. For the sake of readability by humans,
/// custom assembly form for the operation uses a string-typed attribute for
/// the predicate. The value of this attribute corresponds to lower-cased name
/// of the predicate constant, e.g., "slt" means "signed less than". The string
/// representation of the attribute is merely a syntactic sugar and is converted
/// to an integer attribute by the parser.
///
/// %r1 = cmpi "eq" %0, %1 : i32
/// %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64>
/// %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
class CmpIOp
    : public Op<CmpIOp, OpTrait::OperandsAreIntegerLike,
                OpTrait::SameTypeOperands, OpTrait::NOperands<2>::Impl,
                OpTrait::OneResult, OpTrait::ResultsAreBoolLike,
                OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.cmpi"; }

  /// Name of the integer attribute that stores the comparison predicate.
  static StringRef getPredicateAttrName() { return "predicate"; }

  /// Reads the integer predicate attribute and converts its value to the
  /// corresponding CmpIPredicate enumerator.
  CmpIPredicate getPredicate() {
    auto predicateAttr = getAttrOfType<IntegerAttr>(getPredicateAttrName());
    return static_cast<CmpIPredicate>(predicateAttr.getInt());
  }

  /// Maps the textual predicate name used by the custom assembly form to a
  /// predicate value.
  static CmpIPredicate getPredicateByName(StringRef name);

  /// Build hook taking the predicate and the two compared values.
  static void build(Builder *builder, OperationState *result, CmpIPredicate,
                    Value *lhs, Value *rhs);

  // Hooks to customize behavior of this op.
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();
  Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
};
/// The "cond_br" operation represents a conditional branch operation in a
/// function. The operation takes variable number of operands and produces
/// no results. The operand number and types for each successor must match the
/// arguments of the block successor. For example:
///
/// ^bb0:
/// %0 = extract_element %arg0[] : tensor<i1>
/// cond_br %0, ^bb1, ^bb2
/// ^bb1:
/// ...
/// ^bb2:
/// ...
///
class CondBranchOp : public Op<CondBranchOp, OpTrait::AtLeastNOperands<1>::Impl,
                               OpTrait::ZeroResult, OpTrait::IsTerminator> {
  // Positions of the two successors in the dests list.
  enum { trueIndex = 0, falseIndex = 1 };

  /// The operands list of a conditional branch operation is laid out as
  /// follows:
  ///   { condition, [true_operands], [false_operands] }

public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.cond_br"; }

  /// Build hook taking the condition and, for each of the two successors, the
  /// destination block together with the operands passed to it.
  static void build(Builder *builder, OperationState *result, Value *condition,
                    Block *trueDest, ArrayRef<Value *> trueOperands,
                    Block *falseDest, ArrayRef<Value *> falseOperands);

  // Hooks to customize behavior of this op.
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();
  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context);

  /// Returns the branch condition, which is operand #0.
  Value *getCondition() { return getOperand(0); }

  /// Return the destination if the condition is true.
  Block *getTrueDest();

  /// Return the destination if the condition is false.
  Block *getFalseDest();

  //===--- Operands forwarded to the 'true' destination -----------------===//

  /// Number of operands forwarded to the 'true' destination.
  unsigned getNumTrueOperands();

  /// Returns the `idx`-th operand of the 'true' destination; `idx` is
  /// relative to the start of the true operand list.
  Value *getTrueOperand(unsigned idx) {
    assert(idx < getNumTrueOperands());
    unsigned operandPos = getTrueDestOperandIndex() + idx;
    return getOperand(operandPos);
  }

  /// Replaces the `idx`-th operand of the 'true' destination with `value`.
  void setTrueOperand(unsigned idx, Value *value) {
    assert(idx < getNumTrueOperands());
    unsigned operandPos = getTrueDestOperandIndex() + idx;
    setOperand(operandPos, value);
  }

  operand_iterator true_operand_begin() {
    return operand_begin() + getTrueDestOperandIndex();
  }
  operand_iterator true_operand_end() {
    operand_iterator firstTrueOperand = true_operand_begin();
    return firstTrueOperand + getNumTrueOperands();
  }

  /// Range covering all operands of the 'true' destination.
  operand_range getTrueOperands() {
    operand_iterator first = true_operand_begin();
    operand_iterator last = true_operand_end();
    return {first, last};
  }

  /// Erase the operand at 'index' from the true operand list.
  void eraseTrueOperand(unsigned index);

  //===--- Operands forwarded to the 'false' destination ----------------===//

  /// Number of operands forwarded to the 'false' destination.
  unsigned getNumFalseOperands();

  /// Returns the `idx`-th operand of the 'false' destination; `idx` is
  /// relative to the start of the false operand list.
  Value *getFalseOperand(unsigned idx) {
    assert(idx < getNumFalseOperands());
    unsigned operandPos = getFalseDestOperandIndex() + idx;
    return getOperand(operandPos);
  }

  /// Replaces the `idx`-th operand of the 'false' destination with `value`.
  void setFalseOperand(unsigned idx, Value *value) {
    assert(idx < getNumFalseOperands());
    unsigned operandPos = getFalseDestOperandIndex() + idx;
    setOperand(operandPos, value);
  }

  // The false operands start right where the true operands end.
  operand_iterator false_operand_begin() { return true_operand_end(); }
  operand_iterator false_operand_end() {
    operand_iterator firstFalseOperand = false_operand_begin();
    return firstFalseOperand + getNumFalseOperands();
  }

  /// Range covering all operands of the 'false' destination.
  operand_range getFalseOperands() {
    operand_iterator first = false_operand_begin();
    operand_iterator last = false_operand_end();
    return {first, last};
  }

  /// Erase the operand at 'index' from the false operand list.
  void eraseFalseOperand(unsigned index);

private:
  /// Get the index of the first true destination operand. Operand #0 is the
  /// condition, so the true operands begin at position 1.
  unsigned getTrueDestOperandIndex() { return 1; }

  /// Get the index of the first false destination operand, which directly
  /// follows the last true destination operand.
  unsigned getFalseDestOperandIndex() {
    return getTrueDestOperandIndex() + getNumTrueOperands();
  }
};
/// This is a refinement of the "constant" op for the case where it is
/// returning a float value of FloatType.
///
/// %1 = "std.constant"(){value: 42.0} : bf16
///
class ConstantFloatOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// Builds a constant float op producing a float of the specified type.
  static void build(Builder *builder, OperationState *result,
                    const APFloat &value, FloatType type);

  /// Returns the floating point value held by the "value" attribute.
  APFloat getValue() {
    auto valueAttr = getAttrOfType<FloatAttr>("value");
    return valueAttr.getValue();
  }

  /// Classification hook used to decide whether `op` is an instance of this
  /// refinement.
  static bool isClassFor(Operation *op);
};
/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of IntegerType.
///
/// %1 = "std.constant"(){value: 42} : i32
///
class ConstantIntOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// Build a constant int op producing an integer of the specified width.
  static void build(Builder *builder, OperationState *result, int64_t value,
                    unsigned width);

  /// Build a constant int op producing an integer with the specified type,
  /// which must be an integer type.
  static void build(Builder *builder, OperationState *result, int64_t value,
                    Type type);

  /// Returns the integer held by the "value" attribute.
  int64_t getValue() {
    auto valueAttr = getAttrOfType<IntegerAttr>("value");
    return valueAttr.getInt();
  }

  /// Classification hook used to decide whether `op` is an instance of this
  /// refinement.
  static bool isClassFor(Operation *op);
};
/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of Index type.
///
/// %1 = "std.constant"(){value: 99} : () -> index
///
class ConstantIndexOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// Build a constant int op producing an index.
  static void build(Builder *builder, OperationState *result, int64_t value);

  /// Returns the integer held by the "value" attribute.
  int64_t getValue() {
    auto valueAttr = getAttrOfType<IntegerAttr>("value");
    return valueAttr.getInt();
  }

  /// Classification hook used to decide whether `op` is an instance of this
  /// refinement.
  static bool isClassFor(Operation *op);
};
/// The "dealloc" operation frees the region of memory referenced by a memref
/// which was originally created by the "alloc" operation.
/// The "dealloc" operation should not be called on memrefs which alias an
/// alloc'd memref (i.e. memrefs returned by the "view" and "reshape"
/// operations).
///
/// %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
///
/// dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
///
class DeallocOp
    : public Op<DeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.dealloc"; }

  /// Returns the memref being freed, i.e. the single operand of this op.
  Value *getMemRef() { return getOperand(); }

  /// Replaces the single operand of this op with `value`.
  void setMemRef(Value *value) { setOperand(value); }

  // Hooks to customize behavior of this op.
  static void build(Builder *builder, OperationState *result, Value *memref);
  LogicalResult verify();
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context);
};
/// The "dim" operation takes a memref or tensor operand and returns an
/// "index". It requires a single integer attribute named "index". It
/// returns the size of the specified dimension. For example:
///
/// %1 = dim %0, 2 : tensor<?x?x?xf32>
///
class DimOp : public Op<DimOp, OpTrait::OneOperand, OpTrait::OneResult,
                        OpTrait::HasNoSideEffect> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.dim"; }

  /// Build hook taking the inspected memref or tensor value and the number of
  /// the dimension to query.
  static void build(Builder *builder, OperationState *result,
                    Value *memrefOrTensor, unsigned index);

  /// This returns the dimension number that the 'dim' is inspecting, read
  /// from the "index" attribute as a zero-extended integer.
  unsigned getIndex() {
    auto indexAttr = getAttrOfType<IntegerAttr>("index");
    return indexAttr.getValue().getZExtValue();
  }

  /// Constant folding hook.
  Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);

  // Hooks to customize behavior of this op.
  LogicalResult verify();
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
};
// DmaStartOp starts a non-blocking DMA operation that transfers data from a
// source memref to a destination memref. The source and destination memref need
// not be of the same dimensionality, but need to have the same elemental type.
// The operands include the source and destination memref's each followed by its
// indices, size of the data transfer in terms of the number of elements (of the
// elemental type of the memref), a tag memref with its indices, and optionally
// at the end, a stride and a number_of_elements_per_stride arguments. The tag
// location is used by a DmaWaitOp to check for completion. The indices of the
// source memref, destination memref, and the tag memref have the same
// restrictions as any load/store. The optional stride arguments should be of
// 'index' type, and specify a stride for the slower memory space (memory space
// with a lower memory space id), transferring chunks of
// number_of_elements_per_stride every stride until %num_elements are
// transferred. Either both or no stride arguments should be specified.
//
// For example, a DmaStartOp operation that transfers 256 elements of a memref
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
// 1 at indices [%k, %l], would be specified as follows:
//
// %num_elements = constant 256
// %idx = constant 0 : index
// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 2>
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
// memref<40 x 128 x f32, (d0, d1) -> (d0, d1), 0>,
// memref<2 x 1024 x f32, (d0, d1) -> (d0, d1), 1>,
// memref<1 x i32, (d0) -> (d0), 2>
//
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
// transfer %num_elt_per_stride elements every %stride elements apart from
// memory space 0 until %num_elements are transferred.
//
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
// %num_elt_per_stride :
//
// TODO(mlir-team): add additional operands to allow source and destination
// striding, and multiple stride levels.
// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
class DmaStartOp
    : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.dma_start"; }

  /// Build hook. `stride` and `elementsPerStride` are optional and default to
  /// null.
  static void build(Builder *builder, OperationState *result, Value *srcMemRef,
                    ArrayRef<Value *> srcIndices, Value *destMemRef,
                    ArrayRef<Value *> destIndices, Value *numElements,
                    Value *tagMemRef, ArrayRef<Value *> tagIndices,
                    Value *stride = nullptr,
                    Value *elementsPerStride = nullptr);

  // Operand layout assumed by the accessors below:
  //   { src memref, [src indices], dst memref, [dst indices], num elements,
  //     tag memref, [tag indices], (stride, num elements per stride)? }

  //===--- Source memref ------------------------------------------------===//

  /// Returns the source memref value of this DMA operation (operand #0).
  Value *getSrcMemRef() { return getOperand(0); }

  /// Returns the rank (number of indices) of the source MemRefType.
  unsigned getSrcMemRefRank() {
    auto srcType = getSrcMemRef()->getType().cast<MemRefType>();
    return srcType.getRank();
  }

  /// Returns the memory space of the source MemRefType.
  unsigned getSrcMemorySpace() {
    auto srcType = getSrcMemRef()->getType().cast<MemRefType>();
    return srcType.getMemorySpace();
  }

  /// Returns the source memref indices for this DMA operation. They directly
  /// follow the source memref in the operand list.
  operand_range getSrcIndices() {
    auto srcIndicesBegin = getOperation()->operand_begin() + 1;
    return {srcIndicesBegin, srcIndicesBegin + getSrcMemRefRank()};
  }

  //===--- Destination memref -------------------------------------------===//

  /// Returns the destination memref value of this DMA operation, which comes
  /// right after the source indices.
  Value *getDstMemRef() {
    unsigned dstMemRefPos = 1 + getSrcMemRefRank();
    return getOperand(dstMemRefPos);
  }

  /// Returns the rank (number of indices) of the destination MemRefType.
  unsigned getDstMemRefRank() {
    auto dstType = getDstMemRef()->getType().cast<MemRefType>();
    return dstType.getRank();
  }

  /// Returns the memory space of the destination MemRefType.
  unsigned getDstMemorySpace() {
    auto dstType = getDstMemRef()->getType().cast<MemRefType>();
    return dstType.getMemorySpace();
  }

  /// Returns the destination memref indices for this DMA operation. They
  /// directly follow the destination memref in the operand list.
  operand_range getDstIndices() {
    auto dstIndicesBegin =
        getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1;
    return {dstIndicesBegin, dstIndicesBegin + getDstMemRefRank()};
  }

  //===--- Transfer size and tag ----------------------------------------===//

  /// Returns the number of elements being transferred by this DMA operation.
  Value *getNumElements() {
    unsigned numElementsPos = 1 + getSrcMemRefRank() + 1 + getDstMemRefRank();
    return getOperand(numElementsPos);
  }

  /// Returns the tag memref for this DMA operation, located right after the
  /// number-of-elements operand.
  Value *getTagMemRef() {
    unsigned tagMemRefPos =
        1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1;
    return getOperand(tagMemRefPos);
  }

  /// Returns the rank (number of indices) of the tag MemRefType.
  unsigned getTagMemRefRank() {
    auto tagType = getTagMemRef()->getType().cast<MemRefType>();
    return tagType.getRank();
  }

  /// Returns the tag memref indices for this DMA operation.
  operand_range getTagIndices() {
    unsigned tagIndexStartPos =
        1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
    auto tagIndicesBegin = getOperation()->operand_begin() + tagIndexStartPos;
    return {tagIndicesBegin, tagIndicesBegin + getTagMemRefRank()};
  }

  //===--- Memory hierarchy queries -------------------------------------===//
  // These queries assume that a lower memory space number denotes a slower
  // memory space.

  /// Returns true if the destination memory space number is greater than the
  /// source one, i.e. this is a DMA from a slower memory space to a faster
  /// one.
  bool isDestMemorySpaceFaster() {
    return getSrcMemorySpace() < getDstMemorySpace();
  }

  /// Returns true if the source memory space number is greater than the
  /// destination one, i.e. this is a DMA from a faster memory space to a
  /// slower one.
  bool isSrcMemorySpaceFaster() {
    return getDstMemorySpace() < getSrcMemorySpace();
  }

  /// Given a DMA start operation, returns the operand position of either the
  /// source or destination memref depending on the one that is at the higher
  /// level of the memory hierarchy. Asserts failure if neither is true.
  unsigned getFasterMemPos() {
    assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
    if (isSrcMemorySpaceFaster())
      return 0;
    return getSrcMemRefRank() + 1;
  }

  //===--- Hooks ----------------------------------------------------------===//

  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();
  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context);

  //===--- Optional stride operands ---------------------------------------===//

  /// Returns true if the op carries operands beyond the memrefs, their
  /// indices and the element count, i.e. the trailing stride operands.
  bool isStrided() {
    unsigned numOperandsWithoutStride = 1 + getSrcMemRefRank() + 1 +
                                        getDstMemRefRank() + 1 + 1 +
                                        getTagMemRefRank();
    return getNumOperands() != numOperandsWithoutStride;
  }

  /// Returns the stride operand (second to last operand), or null if the op
  /// is not strided.
  Value *getStride() {
    return isStrided() ? getOperand(getNumOperands() - 2) : nullptr;
  }

  /// Returns the number-of-elements-per-stride operand (last operand), or
  /// null if the op is not strided.
  Value *getNumElementsPerStride() {
    return isStrided() ? getOperand(getNumOperands() - 1) : nullptr;
  }
};
// DmaWaitOp blocks until the completion of a DMA operation associated with the
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
// with the same restrictions as any load/store index. %num_elements is the
// number of elements associated with the DMA operation. For example:
//
// dma_start %src[%i], %dst[%k], %num_elements, %tag[%index] :
// memref<2048 x f32, (d0) -> (d0), 0>,
// memref<256 x f32, (d0) -> (d0), 1>,
// memref<1 x i32, (d0) -> (d0), 2>
// ...
// ...
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
//
class DmaWaitOp
    : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.dma_wait"; }

  /// Build hook taking the tag memref, its indices and the element count.
  static void build(Builder *builder, OperationState *result, Value *tagMemRef,
                    ArrayRef<Value *> tagIndices, Value *numElements);

  // Operand layout assumed by the accessors below:
  //   { tag memref, [tag indices], num elements }

  /// Returns the tag memref associated with the DMA operation being waited
  /// on (operand #0).
  Value *getTagMemRef() { return getOperand(0); }

  /// Returns the rank (number of indices) of the tag memref.
  unsigned getTagMemRefRank() {
    auto tagType = getTagMemRef()->getType().cast<MemRefType>();
    return tagType.getRank();
  }

  /// Returns the tag memref indices for this DMA operation. They directly
  /// follow the tag memref in the operand list.
  operand_range getTagIndices() {
    auto tagIndicesBegin = getOperation()->operand_begin() + 1;
    return {tagIndicesBegin, tagIndicesBegin + getTagMemRefRank()};
  }

  /// Returns the number of elements transferred in the associated DMA
  /// operation; this operand comes right after the tag indices.
  Value *getNumElements() {
    unsigned numElementsPos = 1 + getTagMemRefRank();
    return getOperand(numElementsPos);
  }

  // Hooks to customize behavior of this op.
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context);
};
/// The "extract_element" op reads a tensor or vector and returns one element
/// from it specified by an index list. The output of extract is a new value
/// with the same type as the elements of the tensor or vector. The arity of
/// indices matches the rank of the accessed value (i.e., if a tensor is of rank
/// 3, then 3 indices are required for the extract). The indices should all be
/// of index type.
///
/// For example:
///
/// %3 = extract_element %0[%1, %2] : vector<4x4xi32>
///
class ExtractElementOp
    : public Op<ExtractElementOp, OpTrait::VariadicOperands, OpTrait::OneResult,
                OpTrait::HasNoSideEffect> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.extract_element"; }

  /// Build hook taking the accessed aggregate and the element indices.
  static void build(Builder *builder, OperationState *result, Value *aggregate,
                    ArrayRef<Value *> indices = {});

  /// Returns the value being read from (operand #0).
  Value *getAggregate() { return getOperand(0); }

  /// Returns the index operands, i.e. every operand after the aggregate.
  operand_range getIndices() {
    auto firstIndex = getOperation()->operand_begin() + 1;
    return {firstIndex, getOperation()->operand_end()};
  }

  // Hooks to customize behavior of this op.
  LogicalResult verify();
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
};
/// The "load" op reads an element from a memref specified by an index list. The
/// output of load is a new value with the same type as the elements of the
/// memref. The arity of indices is the rank of the memref (i.e., if the memref
/// loaded from is of rank 3, then 3 indices are required for the load following
/// the memref identifier). For example:
///
/// %3 = load %0[%1, %1] : memref<4x4xi32>
///
class LoadOp
    : public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.load"; }

  /// Build hook taking the memref to read and the element indices.
  static void build(Builder *builder, OperationState *result, Value *memref,
                    ArrayRef<Value *> indices = {});

  /// Returns the memref being read (operand #0).
  Value *getMemRef() { return getOperand(0); }

  /// Replaces the memref operand with `value`.
  void setMemRef(Value *value) { setOperand(0, value); }

  /// Returns the type of the memref operand, cast to MemRefType.
  MemRefType getMemRefType() {
    auto memrefType = getMemRef()->getType();
    return memrefType.cast<MemRefType>();
  }

  /// Returns the index operands, i.e. every operand after the memref.
  operand_range getIndices() {
    auto firstIndex = getOperation()->operand_begin() + 1;
    return {firstIndex, getOperation()->operand_end()};
  }

  // Hooks to customize behavior of this op.
  LogicalResult verify();
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context);
};
/// The "memref_cast" operation converts a memref from one type to an equivalent
/// type with a compatible shape. The source and destination types are
/// compatible when both are memref types with the same element type, affine
/// mappings, address space, and rank but where the individual dimensions may
/// add or remove constant dimensions from the memref type.
///
/// If the cast converts any dimensions from an unknown to a known size, then it
/// acts as an assertion that fails at runtime if the dynamic dimensions
/// disagree with the resultant destination size.
///
/// Assert that the input dynamic shape matches the destination static shape.
/// %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
/// Erase static shape information, replacing it with dynamic information.
/// %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
///
class MemRefCastOp : public CastOp<MemRefCastOp> {
public:
  using CastOp::CastOp;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.memref_cast"; }

  /// Return true if `a` and `b` are valid operand and result pairs for
  /// the operation.
  static bool areCastCompatible(Type a, Type b);

  /// The result of a memref_cast is always a memref, so the result type is
  /// returned as a MemRefType.
  MemRefType getType() {
    auto resultType = getResult()->getType();
    return resultType.cast<MemRefType>();
  }

  // Hooks to customize behavior of this op.
  void print(OpAsmPrinter *p);
  LogicalResult verify();
};
/// The "return" operation represents a return operation within a function.
/// The operation takes variable number of operands and produces no results.
/// The operand number and types must match the signature of the function
/// that contains the operation. For example:
///
/// func @foo() : (i32, f8) {
/// ...
/// return %0, %1 : i32, f8
///
class ReturnOp : public Op<ReturnOp, OpTrait::VariadicOperands,
                           OpTrait::ZeroResult, OpTrait::IsTerminator> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.return"; }

  /// Build hook taking the returned values; the list defaults to empty.
  static void build(Builder *builder, OperationState *result,
                    ArrayRef<Value *> results = {});

  // Hooks to customize behavior of this op.
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();
};
/// The "select" operation chooses one value based on a binary condition
/// supplied as its first operand. If the value of the first operand is 1, the
/// second operand is chosen, otherwise the third operand is chosen. The second
/// and the third operand must have the same type. The operation applies
/// elementwise to vectors and tensors. The shape of all arguments must be
/// identical. For example, the maximum operation is obtained by combining
/// "select" with "cmpi" as follows.
///
/// %2 = cmpi "sgt" %0, %1 : i32 // %2 is i1
/// %3 = select %2, %0, %1 : i32
///
class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
                           OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.select"; }

  /// Build hook taking the condition and the two candidate values.
  static void build(Builder *builder, OperationState *result, Value *condition,
                    Value *trueValue, Value *falseValue);

  // Operand accessors: { condition, true value, false value }.
  Value *getCondition() { return getOperand(0); }
  Value *getTrueValue() { return getOperand(1); }
  Value *getFalseValue() { return getOperand(2); }

  // Hooks to customize behavior of this op.
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  LogicalResult verify();

  /// Folding hook.
  Value *fold();
};
/// The "store" op writes an element to a memref specified by an index list.
/// The arity of indices is the rank of the memref (i.e. if the memref being
/// stored to is of rank 3, then 3 indices are required for the store following
/// the memref identifier). The store operation does not produce a result.
///
/// In the following example, the ssa value '%v' is stored in memref '%A' at
/// indices [%i, %j]:
///
/// store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
///
class StoreOp
    : public Op<StoreOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
  using Op::Op;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.store"; }

  /// Build hook taking the stored value, the target memref and the element
  /// indices.
  static void build(Builder *builder, OperationState *result,
                    Value *valueToStore, Value *memref,
                    ArrayRef<Value *> indices = {});

  // Operand layout: { value to store, memref, [indices] }.

  /// Returns the value being written (operand #0).
  Value *getValueToStore() { return getOperand(0); }

  /// Returns the memref being written to (operand #1).
  Value *getMemRef() { return getOperand(1); }

  /// Replaces the memref operand with `value`.
  void setMemRef(Value *value) { setOperand(1, value); }

  /// Returns the type of the memref operand, cast to MemRefType.
  MemRefType getMemRefType() {
    auto memrefType = getMemRef()->getType();
    return memrefType.cast<MemRefType>();
  }

  /// Returns the index operands, i.e. every operand after the stored value
  /// and the memref.
  operand_range getIndices() {
    auto firstIndex = getOperation()->operand_begin() + 2;
    return {firstIndex, getOperation()->operand_end()};
  }

  // Hooks to customize behavior of this op.
  LogicalResult verify();
  static bool parse(OpAsmParser *parser, OperationState *result);
  void print(OpAsmPrinter *p);
  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context);
};
/// The "tensor_cast" operation converts a tensor from one type to an equivalent
/// type without changing any data elements. The source and destination types
/// must both be tensor types with the same element type. If both are ranked
/// then the rank should be the same and static dimensions should match. The
/// operation is invalid if converting to a mismatching constant dimension.
///
/// Convert from unknown rank to rank 2 with unknown dimension sizes.
/// %2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
///
class TensorCastOp : public CastOp<TensorCastOp> {
public:
  using CastOp::CastOp;

  /// Name under which this operation is registered.
  static StringRef getOperationName() { return "std.tensor_cast"; }

  /// Return true if `a` and `b` are valid operand and result pairs for
  /// the operation.
  static bool areCastCompatible(Type a, Type b);

  /// The result of a tensor_cast is always a tensor, so the result type is
  /// returned as a TensorType.
  TensorType getType() {
    auto resultType = getResult()->getType();
    return resultType.cast<TensorType>();
  }

  // Hooks to customize behavior of this op.
  void print(OpAsmPrinter *p);
  LogicalResult verify();
};
/// Prints dimension and symbol list.
///
/// `begin` and `end` delimit the operands to print and `p` is the printer that
/// receives the output.
/// NOTE(review): `numDims` presumably gives how many of the leading operands
/// are dimensions, the remaining ones being symbols — confirm against the
/// definition.
void printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end, unsigned numDims,
OpAsmPrinter *p);
/// Parses dimension and symbol list and returns true if parsing failed.
///
/// `operands` and `numDims` are taken by non-const reference and act as
/// outputs.
/// NOTE(review): presumably `operands` receives the parsed dimension and
/// symbol values and `numDims` the number of dimensions among them — confirm
/// against the definition.
bool parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims);
} // end namespace mlir
#endif // MLIR_STANDARDOPS_OPS_H