blob: 3e9f06ccc0f8881f23fd25c7ee14f64db4eb1144 [file] [log] [blame]
//===- Statements.h - MLIR ML Statement Classes -----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines classes for special kinds of ML Function statements.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_STATEMENTS_H
#define MLIR_IR_STATEMENTS_H
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Statement.h"
#include "mlir/IR/StmtBlock.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
/// Operation statements represent operations inside ML functions.
///
/// An OperationStmt is both an Operation and a Statement.  Its operands
/// (StmtOperand) and results (StmtResult) are reached through
/// llvm::TrailingObjects, with the counts recorded in `numOperands` and
/// `numResults`.  The constructor and destructor are private; instances are
/// obtained through `create`.
class OperationStmt final
    : public Operation,
      public Statement,
      private llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult> {
public:
  /// Create a new OperationStmt with the specific fields.
  static OperationStmt *create(Identifier name, ArrayRef<MLValue *> operands,
                               ArrayRef<Type *> resultTypes,
                               ArrayRef<NamedAttribute> attributes,
                               MLIRContext *context);

  /// Return the context this operation is associated with.
  MLIRContext *getContext() const;

  /// Clone this operation statement.  Implemented out of line.
  OperationStmt *clone() const;

  //===--------------------------------------------------------------------===//
  // Operands
  //===--------------------------------------------------------------------===//

  /// Return the number of operands held by this statement.
  unsigned getNumOperands() const { return numOperands; }

  /// Return the value used by the operand at position `idx`.
  MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
  const MLValue *getOperand(unsigned idx) const {
    return getStmtOperand(idx).get();
  }

  /// Make the operand at position `idx` refer to `value`.
  void setOperand(unsigned idx, MLValue *value) {
    getStmtOperand(idx).set(value);
  }

  // Support non-const operand iteration.
  using operand_iterator = OperandIterator<OperationStmt, MLValue>;

  operand_iterator operand_begin() { return operand_iterator(this, 0); }
  operand_iterator operand_end() {
    return operand_iterator(this, getNumOperands());
  }
  llvm::iterator_range<operand_iterator> getOperands() {
    return {operand_begin(), operand_end()};
  }

  // Support const operand iteration.
  using const_operand_iterator =
      OperandIterator<const OperationStmt, const MLValue>;

  const_operand_iterator operand_begin() const {
    return const_operand_iterator(this, 0);
  }
  const_operand_iterator operand_end() const {
    return const_operand_iterator(this, getNumOperands());
  }
  llvm::iterator_range<const_operand_iterator> getOperands() const {
    return {operand_begin(), operand_end()};
  }

  /// Return the trailing StmtOperand objects as an array view.
  ArrayRef<StmtOperand> getStmtOperands() const {
    return {getTrailingObjects<StmtOperand>(), numOperands};
  }
  MutableArrayRef<StmtOperand> getStmtOperands() {
    return {getTrailingObjects<StmtOperand>(), numOperands};
  }

  /// Return the StmtOperand at position `idx`.
  StmtOperand &getStmtOperand(unsigned idx) { return getStmtOperands()[idx]; }
  const StmtOperand &getStmtOperand(unsigned idx) const {
    return getStmtOperands()[idx];
  }

  /// This drops all operand uses from this instruction, which is an essential
  /// step in breaking cyclic dependences between references when they are to
  /// be deleted.
  void dropAllReferences();

  //===--------------------------------------------------------------------===//
  // Results
  //===--------------------------------------------------------------------===//

  /// Return the number of results produced by this statement.
  unsigned getNumResults() const { return numResults; }

  /// Return the result at position `idx`.  The returned pointer is the
  /// address of the trailing StmtResult object itself.
  MLValue *getResult(unsigned idx) { return &getStmtResult(idx); }
  const MLValue *getResult(unsigned idx) const { return &getStmtResult(idx); }

  // Support non-const result iteration.
  using result_iterator = ResultIterator<OperationStmt, MLValue>;

  result_iterator result_begin() { return result_iterator(this, 0); }
  result_iterator result_end() {
    return result_iterator(this, getNumResults());
  }
  llvm::iterator_range<result_iterator> getResults() {
    return {result_begin(), result_end()};
  }

  // Support const result iteration.
  using const_result_iterator =
      ResultIterator<const OperationStmt, const MLValue>;

  const_result_iterator result_begin() const {
    return const_result_iterator(this, 0);
  }
  const_result_iterator result_end() const {
    return const_result_iterator(this, getNumResults());
  }
  llvm::iterator_range<const_result_iterator> getResults() const {
    return {result_begin(), result_end()};
  }

  /// Return the trailing StmtResult objects as an array view.
  ArrayRef<StmtResult> getStmtResults() const {
    return {getTrailingObjects<StmtResult>(), numResults};
  }
  MutableArrayRef<StmtResult> getStmtResults() {
    return {getTrailingObjects<StmtResult>(), numResults};
  }

  /// Return the StmtResult at position `idx`.
  StmtResult &getStmtResult(unsigned idx) { return getStmtResults()[idx]; }
  const StmtResult &getStmtResult(unsigned idx) const {
    return getStmtResults()[idx];
  }

  //===--------------------------------------------------------------------===//
  // Other
  //===--------------------------------------------------------------------===//

  /// Unlink this statement from its StmtBlock and delete it.
  void eraseFromBlock();

  /// Destroy this statement.  Implemented out of line.  The destructor is
  /// private, so code outside this class cannot invoke it directly.
  void destroy();

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(const Statement *stmt) {
    return stmt->getKind() == Kind::Operation;
  }
  static bool classof(const Operation *op) {
    return op->getOperationKind() == OperationKind::Statement;
  }

private:
  // Number of trailing StmtOperand / StmtResult objects; fixed at construction.
  const unsigned numOperands, numResults;

  OperationStmt(Identifier name, unsigned numOperands, unsigned numResults,
                ArrayRef<NamedAttribute> attributes, MLIRContext *context);
  ~OperationStmt();

  // This stuff is used by the TrailingObjects template.
  friend llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult>;
  size_t numTrailingObjects(OverloadToken<StmtOperand>) const {
    return numOperands;
  }
  size_t numTrailingObjects(OverloadToken<StmtResult>) const {
    return numResults;
  }
};
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public MLValue, public StmtBlock {
public:
// TODO: lower and upper bounds should be affine maps with
// dimension and symbol use lists.
explicit ForStmt(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound, AffineConstantExpr *step,
MLIRContext *context);
~ForStmt() {
// Loop bounds and step are immortal objects and don't need to be deleted.
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
// since child statements need to be destroyed before the MLValue that this
// for stmt represents is destroyed.
clear();
}
/// Deep clone this for stmt.
ForStmt *clone() const;
AffineConstantExpr *getLowerBound() const { return lowerBound; }
AffineConstantExpr *getUpperBound() const { return upperBound; }
AffineConstantExpr *getStep() const { return step; }
using Statement::dump;
using Statement::print;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Statement *stmt) {
return stmt->getKind() == Kind::For;
}
static bool classof(const StmtBlock *block) {
return block->getStmtBlockKind() == StmtBlockKind::For;
}
// For statement represents implicitly represents induction variable by
// inheriting from MLValue class. Whenever you need to refer to the loop
// induction variable, just use the for statement itself.
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::ForStmt;
}
private:
AffineConstantExpr *lowerBound;
AffineConstantExpr *upperBound;
AffineConstantExpr *step;
};
/// An if clause is the block of statements that makes up either the then or
/// the else part of an if statement.
class IfClause : public StmtBlock {
public:
  /// Build a clause belonging to `stmt`, which must not be null.
  explicit IfClause(IfStmt *stmt)
      : StmtBlock(StmtBlockKind::IfClause), ifStmt(stmt) {
    assert(stmt != nullptr && "If clause must have non-null parent");
  }

  ~IfClause() = default;

  /// Returns the if statement that contains this clause.
  IfStmt *getIf() const { return ifStmt; }

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(const StmtBlock *block) {
    return block->getStmtBlockKind() == StmtBlockKind::IfClause;
  }

private:
  // The if statement this clause was created for; set once in the constructor.
  IfStmt *ifStmt;
};
/// If statement restricts execution to a subset of the loop iteration space.
///
/// Every IfStmt has a then clause, allocated in the constructor.  The else
/// clause is optional: it is null until `createElseClause` is called.
class IfStmt : public Statement {
public:
  /// Construct an if statement with an empty then clause and no else clause.
  explicit IfStmt()
      : Statement(Kind::If), thenClause(new IfClause(this)),
        elseClause(nullptr) {}

  // Defined out of line.
  // NOTE(review): presumably releases the clauses allocated with `new` in
  // this class -- confirm against the implementation file.
  ~IfStmt();

  /// Deep clone this IfStmt.
  IfStmt *clone() const;

  /// Return the then clause.
  IfClause *getThenClause() const { return thenClause; }

  /// Return the else clause, or null if none has been created.
  IfClause *getElseClause() const { return elseClause; }

  /// Return true if an else clause has been created.
  bool hasElseClause() const { return elseClause != nullptr; }

  /// Allocate the else clause and return it.  Must be called at most once per
  /// statement: creating a second clause would overwrite the pointer to the
  /// first one and orphan it.
  IfClause *createElseClause() {
    assert(elseClause == nullptr && "already has an else clause!");
    return (elseClause = new IfClause(this));
  }

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(const Statement *stmt) {
    return stmt->getKind() == Kind::If;
  }

private:
  IfClause *thenClause;
  IfClause *elseClause;
  // TODO: Represent IntegerSet condition
};
} // end namespace mlir
#endif // MLIR_IR_STATEMENTS_H