//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// StmtResult
//===----------------------------------------------------------------------===//
/// Return the result number of this result.
unsigned StmtResult::getResultNumber() const {
// Results are always stored consecutively, so use pointer subtraction to
// figure out what number this is.
return this - &getOwner()->getStmtResults()[0];
}
//===----------------------------------------------------------------------===//
// Statement
//===----------------------------------------------------------------------===//
// This hierarchy has no virtual destructor, so statements are disposed of
// through the destroy() member rather than by deleting a Statement directly.
// A statement must already be unlinked from its block when it dies.
Statement::~Statement() {
  assert(!block && "statement destroyed but still in a block");
}
/// Destroy this statement or one of its subclasses.
void Statement::destroy() {
switch (this->getKind()) {
case Kind::Operation:
cast<OperationStmt>(this)->destroy();
break;
case Kind::For:
delete cast<ForStmt>(this);
break;
case Kind::If:
delete cast<IfStmt>(this);
break;
}
}
/// Return the parent statement reported by the block that contains this
/// statement, or null if this statement is not in a block.
Statement *Statement::getParentStmt() const {
  if (!block)
    return nullptr;
  return block->getParentStmt();
}
/// Return the function reported by the block that contains this statement,
/// or null if this statement is not in a block.
MLFunction *Statement::findFunction() const {
  if (!block)
    return nullptr;
  return block->findFunction();
}
bool Statement::isInnermost() const {
struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
unsigned numNestedLoops;
NestedLoopCounter() : numNestedLoops(0) {}
void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
};
NestedLoopCounter nlc;
nlc.walk(const_cast<Statement *>(this));
return nlc.numNestedLoops == 1;
}
/// Create a copy of this statement by dispatching on its kind to the clone()
/// of the concrete subclass.
///
/// Only operation statements can currently be cloned through this entry
/// point; 'if' and 'for' statements trigger llvm_unreachable.
Statement *Statement::clone() const {
  switch (kind) {
  case Kind::Operation:
    return cast<OperationStmt>(this)->clone();
  case Kind::If:
    // TODO: dispatch to IfStmt::clone() once it is implemented.
    llvm_unreachable("cloning for if's not implemented yet");
  case Kind::For:
    // TODO: dispatch to ForStmt::clone() once generic loop cloning is
    // supported.
    llvm_unreachable("cloning for loops not implemented yet");
  }
  // The switch above is fully covered; this keeps control from ever falling
  // off the end of a non-void function if 'kind' holds an invalid value, and
  // silences the corresponding compiler warning.
  llvm_unreachable("unknown statement kind");
}
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
//===----------------------------------------------------------------------===//
/// Recover the StmtBlock that embeds the statement list these traits belong
/// to.
///
/// The result is computed purely by pointer arithmetic: the address of the
/// list minus the offset of the list member inside StmtBlock.
/// NOTE(review): this assumes the traits object is a base subobject of an
/// iplist<Statement> that is itself stored inside a StmtBlock -- confirm
/// against the StmtBlock declaration.
StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
// Offset of the statement list within StmtBlock, obtained by applying the
// pointer-to-member returned by getSublistAccess() to a null StmtBlock
// pointer and reading the resulting address as an integer.
size_t Offset(
size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
// View this traits object as the list it is part of.
iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
// Step back from the list's address by the member offset to reach the
// start of the enclosing StmtBlock.
return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
Offset);
}
/// List trait hook invoked when a statement is added to a block. Records the
/// containing block on the statement so its block pointer stays up to date.
void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
  assert(stmt->getBlock() == nullptr && "already in a statement block!");
  stmt->block = getContainingBlock();
}
/// List trait hook invoked when a statement is removed from a block. Clears
/// the statement's block pointer so it stays up to date.
void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
    Statement *stmt) {
  assert(stmt->block != nullptr && "not already in a statement block!");
  stmt->block = nullptr;
}
/// List trait hook invoked when the statements in [first, last) are moved
/// from 'otherList' into this list. Re-points each moved statement at the
/// block that now contains it.
void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
    ilist_traits<Statement> &otherList, stmt_iterator first,
    stmt_iterator last) {
  StmtBlock *destBlock = getContainingBlock();

  // A transfer within a single block leaves every block pointer correct, so
  // only a move between different blocks needs any work.
  if (destBlock != otherList.getContainingBlock()) {
    for (stmt_iterator it = first; it != last; ++it)
      it->block = destBlock;
  }
}
/// Remove this statement (and its descendants) from its StmtBlock and delete
/// all of them.
void Statement::eraseFromBlock() {
assert(getBlock() && "Statement has no block");
getBlock()->getStatements().erase(this);
}
//===----------------------------------------------------------------------===//
// OperationStmt
//===----------------------------------------------------------------------===//
/// Allocate and construct a new OperationStmt with the given name, operands,
/// result types and attributes.
///
/// The statement object and its operand and result arrays share a single
/// malloc'ed buffer.
OperationStmt *OperationStmt::create(Identifier name,
                                     ArrayRef<MLValue *> operands,
                                     ArrayRef<Type *> resultTypes,
                                     ArrayRef<NamedAttribute> attributes,
                                     MLIRContext *context) {
  // Size the buffer for the statement plus its trailing operand and result
  // storage.
  void *storage = malloc(totalSizeToAlloc<StmtOperand, StmtResult>(
      operands.size(), resultTypes.size()));

  // Construct the OperationStmt itself at the front of the buffer.
  OperationStmt *op = ::new (storage) OperationStmt(
      name, operands.size(), resultTypes.size(), attributes, context);

  // Construct each operand slot in place, pointing it at the new statement
  // and at the value it uses.
  auto operandSlots = op->getStmtOperands();
  for (unsigned idx = 0, count = operands.size(); idx != count; ++idx)
    new (&operandSlots[idx]) StmtOperand(op, operands[idx]);

  // Construct each result slot in place with its type and owning statement.
  auto resultSlots = op->getStmtResults();
  for (unsigned idx = 0, count = resultTypes.size(); idx != count; ++idx)
    new (&resultSlots[idx]) StmtResult(resultTypes[idx], op);

  return op;
}
/// Create a new OperationStmt with the same name, attributes, operand values
/// and result types as this one. The copy uses the very same operand values
/// as the original.
OperationStmt *OperationStmt::clone() const {
  // TODO(clattner): switch this to iterator logic.

  // Gather the values this statement currently uses as operands.
  SmallVector<MLValue *, 8> clonedOperands;
  for (unsigned idx = 0, count = getNumOperands(); idx != count; ++idx)
    clonedOperands.push_back(getStmtOperand(idx).get());

  // Gather the type of each result.
  SmallVector<Type *, 8> clonedResultTypes;
  for (unsigned idx = 0, count = getNumResults(); idx != count; ++idx)
    clonedResultTypes.push_back(getStmtResult(idx).getType());

  return create(getName(), clonedOperands, clonedResultTypes, getAttrs(),
                getContext());
}
/// Construct the fixed-size part of an operation statement.
///
/// Forwards 'name', 'attributes' and 'context' to the Operation base (with
/// isInstruction set to false), tags the Statement base as Kind::Operation,
/// and records the operand and result counts. The operand and result objects
/// themselves are not constructed here; OperationStmt::create() builds them
/// in place after this constructor has run.
OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
unsigned numResults,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context)
: Operation(name, /*isInstruction=*/false, attributes, context),
Statement(Kind::Operation), numOperands(numOperands),
numResults(numResults) {}
/// Tear down the operand and result objects owned by this statement.
OperationStmt::~OperationStmt() {
  // Operands and results are constructed in place by create(), so their
  // destructors have to be invoked by hand here.
  for (auto &stmtOperand : getStmtOperands())
    stmtOperand.~StmtOperand();
  for (auto &stmtResult : getStmtResults())
    stmtResult.~StmtResult();
}
void OperationStmt::destroy() {
this->~OperationStmt();
free(this);
}
/// Return the context this operation is associated with.
///
/// The context is taken from the type of the first result or, failing that,
/// the first operand. A statement with neither falls back to looking up its
/// enclosing function, which requires the statement to be attached to one.
MLIRContext *OperationStmt::getContext() const {
  // If we have a result or operand type, that is a constant time way to get
  // to the context.
  if (getNumResults())
    return getResult(0)->getType()->getContext();
  if (getNumOperands())
    return getOperand(0)->getType()->getContext();

  // In the very odd case where we have no operands or results, fall back to
  // doing a find. findFunction() returns null for a statement that is not in
  // a block, so diagnose that case instead of silently dereferencing null.
  MLFunction *function = findFunction();
  assert(function && "cannot determine the context of a detached statement "
                     "with no operands or results");
  return function->getContext();
}
/// This drops all operand uses from this statement, which is an essential
/// step in breaking cyclic dependences between references when they are to
/// be deleted.
void OperationStmt::dropAllReferences() {
for (auto &op : getStmtOperands())
op.drop();
}
//===----------------------------------------------------------------------===//
// ForStmt
//===----------------------------------------------------------------------===//
/// Construct a 'for' statement with the given constant lower bound, upper
/// bound and step.
///
/// The statement is initialized as three things at once: a Statement of
/// Kind::For, an MLValue of kind ForStmt whose type is the affine integer
/// type obtained from 'context', and a StmtBlock of kind For.
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
AffineConstantExpr *step, MLIRContext *context)
: Statement(Kind::For),
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
upperBound(upperBound), step(step) {}
/// Create a new 'for' statement with the same bounds and step as this one
/// and a body made of clones of this loop's statements.
///
/// NOTE: each body statement is copied with Statement::clone(), and cloned
/// operation statements keep the operand values of the originals. Uses of
/// this loop inside the body therefore still refer to this loop, not to the
/// returned copy.
ForStmt *ForStmt::clone() const {
  // The loop is itself a value whose type was created from the context, so
  // read the context back from that type. This works even when the loop is
  // not attached to a function, where Statement::findFunction() would return
  // null.
  MLIRContext *context = MLValue::getType()->getContext();

  auto *loop =
      new ForStmt(getLowerBound(), getUpperBound(), getStep(), context);
  for (auto &child : getStatements())
    loop->getStatements().push_back(child.clone());
  return loop;
}
//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//
/// Delete the 'then' clause and, when there is one, the 'else' clause.
IfStmt::~IfStmt() {
  delete thenClause;
  // Deleting a null pointer is a no-op, so an absent else clause needs no
  // separate check.
  delete elseClause;
}
/// Placeholder for cloning an 'if' statement. Cloning is not implemented yet,
/// so reaching this function is treated as a programming error.
IfStmt *IfStmt::clone() const {
llvm_unreachable("cloning for if's not implemented yet");
// Not reached; gives the function a return value on every path.
return nullptr;
}