blob: 78810da909d128adc11fa12f67e5116df052b7f5 [file] [log] [blame]
//===- InstVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the base classes for Function's instruction visitors and
// walkers. A visit is an O(1) operation that visits just the node in question. A
// walk visits the node it's called on as well as the node's descendants.
//
// Instruction visitors/walkers are used when you want to perform different
// actions for different kinds of instructions without having to use lots of
// casts and a big switch statement.
//
// To define your own visitor/walker, inherit from these classes, specifying
// your new type for the 'SubClass' template parameter, and "override" visitXXX
// functions in your class. This class is defined in terms of statically
// resolved overloading, not virtual functions.
//
// For example, here is a walker that counts the number of for loops in a
// Function.
//
// /// Declare the class. Note that we derive from InstWalker instantiated
// /// with _our new subclass's_ type.
// struct LoopCounter : public InstWalker<LoopCounter> {
// unsigned numLoops;
// LoopCounter() : numLoops(0) {}
// void visitForInst(ForInst *forInst) { ++numLoops; }
// };
//
// And this class would be used like this:
// LoopCounter lc;
// lc.walk(function);
// numLoops = lc.numLoops;
//
// There are 'walk' methods for OperationInst, ForInst, and
// Function, which recursively process all contained instructions.
//
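// Note that if you don't implement visitXXX for some instruction type,
// the default (no-op) visitXXX method in the base class is invoked instead.
// InstWalker additionally calls visitInstruction on every instruction it
// reaches, before the kind-specific hook.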
//
// The optional second template argument specifies the type that instruction
// visitation functions should return. If you specify this, you *MUST* provide
// an implementation of every visitXXX(InstType *) method, because the default
// implementations in the base classes return void.
//
// Note that these classes are specifically designed as a template to avoid
// virtual function call overhead. Defining and using an InstVisitor is just
// as efficient as having your own switch statement over the instruction
// kind.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_INSTVISITOR_H
#define MLIR_IR_INSTVISITOR_H
#include "mlir/IR/Function.h"
#include "mlir/IR/Instructions.h"
namespace mlir {
/// Base class for instruction visitors. A visit handles exactly one
/// instruction: visit() forwards it to the kind-specific hook on SubClass and
/// does not descend into nested instructions.
template <typename SubClass, typename RetTy = void> class InstVisitor {
  //===--------------------------------------------------------------------===//
  // Interface code - This is the public interface of the InstVisitor that you
  // use to visit instructions.
public:
  /// Dispatches `s` on s->getKind() to SubClass::visitForInst or
  /// SubClass::visitOperationInst and returns whatever that hook returns.
  RetTy visit(Instruction *s) {
    static_assert(std::is_base_of<InstVisitor, SubClass>::value,
                  "Must pass the derived type to this template!");
    // CRTP: call the hooks through the derived type so SubClass's versions
    // are selected statically, without virtual dispatch.
    SubClass *derived = static_cast<SubClass *>(this);
    switch (s->getKind()) {
    case Instruction::Kind::OperationInst:
      return derived->visitOperationInst(cast<OperationInst>(s));
    case Instruction::Kind::For:
      return derived->visitForInst(cast<ForInst>(s));
    }
    // NOTE(review): no return after the switch; every Instruction::Kind is
    // expected to be handled by one of the cases above.
  }

  //===--------------------------------------------------------------------===//
  // Default hooks. These do nothing; SubClass replaces one by declaring a
  // method with the same name (resolved statically, not via `virtual`).
  //
  // visit() calls exactly one of these for the instruction it is given. Since
  // these defaults return void, a SubClass that uses a non-void RetTy must
  // provide both hooks itself.
  void visitForInst(ForInst *forInst) {}
  void visitOperationInst(OperationInst *opInst) {}
};
/// Base class for instruction walkers. A walker can traverse depth first in
/// pre-order or post order. The walk methods without a suffix do a pre-order
/// traversal while those that traverse in post order have a PostOrder suffix.
///
/// For every instruction reached, visitInstruction() is called first. The
/// kind-specific hook (visitForInst / visitOperationInst) is then called
/// either before (pre-order) or after (post order) the instruction's nested
/// blocks are walked.
template <typename SubClass, typename RetTy = void> class InstWalker {
  //===--------------------------------------------------------------------===//
  // Interface code - This is the public interface of the InstWalker used to
  // walk instructions.
public:
  /// Walks, in pre-order, every instruction in the range [Start, End). The
  /// iterator is advanced before the current instruction is walked.
  template <class Iterator> void walk(Iterator Start, Iterator End) {
    while (Start != End) {
      walk(&(*Start++));
    }
  }

  /// Walks, in post order, every instruction in the range [Start, End). The
  /// iterator is advanced before the current instruction is walked.
  template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) {
    while (Start != End) {
      walkPostOrder(&(*Start++));
    }
  }

  // Define walkers for Function and all Function instruction kinds.

  /// Walks, in pre-order, the instructions of every block of `f`, visiting the
  /// blocks front to back.
  void walk(Function *f) {
    for (auto &block : *f)
      static_cast<SubClass *>(this)->walk(block.begin(), block.end());
  }

  /// Walks the instructions of `f` in post order. Blocks are processed from
  /// last to first; the instructions inside each block are still walked front
  /// to back.
  void walkPostOrder(Function *f) {
    for (auto it = f->rbegin(), e = f->rend(); it != e; ++it)
      static_cast<SubClass *>(this)->walkPostOrder(it->begin(), it->end());
  }

  /// Pre-order: calls visitOperationInst on `opInst`, then walks every block
  /// in each of its block lists.
  void walkOpInst(OperationInst *opInst) {
    static_cast<SubClass *>(this)->visitOperationInst(opInst);
    for (auto &blockList : opInst->getBlockLists())
      for (auto &block : blockList)
        static_cast<SubClass *>(this)->walk(block.begin(), block.end());
  }

  /// Post order: walks every block in each of `opInst`'s block lists, then
  /// calls visitOperationInst on `opInst`.
  void walkOpInstPostOrder(OperationInst *opInst) {
    for (auto &blockList : opInst->getBlockLists()) {
      for (auto &block : blockList) {
        // Nested instructions must also be walked in post order (as
        // walkForInstPostOrder does for loop bodies); a pre-order walk here
        // would visit nested operations before their own children.
        static_cast<SubClass *>(this)->walkPostOrder(block.begin(),
                                                     block.end());
      }
    }
    static_cast<SubClass *>(this)->visitOperationInst(opInst);
  }

  /// Pre-order: calls visitForInst on `forInst`, then walks its body.
  void walkForInst(ForInst *forInst) {
    static_cast<SubClass *>(this)->visitForInst(forInst);
    auto *body = forInst->getBody();
    static_cast<SubClass *>(this)->walk(body->begin(), body->end());
  }

  /// Post order: walks the body of `forInst`, then calls visitForInst on it.
  void walkForInstPostOrder(ForInst *forInst) {
    auto *body = forInst->getBody();
    static_cast<SubClass *>(this)->walkPostOrder(body->begin(), body->end());
    static_cast<SubClass *>(this)->visitForInst(forInst);
  }

  /// Walks an instruction and its descendants in pre-order: calls
  /// visitInstruction(s), then dispatches on s->getKind() to walkForInst or
  /// walkOpInst.
  RetTy walk(Instruction *s) {
    static_assert(std::is_base_of<InstWalker, SubClass>::value,
                  "Must pass the derived type to this template!");
    static_cast<SubClass *>(this)->visitInstruction(s);
    switch (s->getKind()) {
    case Instruction::Kind::For:
      return static_cast<SubClass *>(this)->walkForInst(cast<ForInst>(s));
    case Instruction::Kind::OperationInst:
      return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s));
    }
  }

  /// Walks an instruction and its descendants in post order DFS: calls
  /// visitInstruction(s), then dispatches on s->getKind() to
  /// walkForInstPostOrder or walkOpInstPostOrder.
  /// NOTE(review): visitInstruction(s) runs before s's descendants are walked
  /// even here; only the kind-specific hook is invoked in post order.
  RetTy walkPostOrder(Instruction *s) {
    static_assert(std::is_base_of<InstWalker, SubClass>::value,
                  "Must pass the derived type to this template!");
    static_cast<SubClass *>(this)->visitInstruction(s);
    switch (s->getKind()) {
    case Instruction::Kind::For:
      return static_cast<SubClass *>(this)->walkForInstPostOrder(
          cast<ForInst>(s));
    case Instruction::Kind::OperationInst:
      return static_cast<SubClass *>(this)->walkOpInstPostOrder(
          cast<OperationInst>(s));
    }
  }

  //===--------------------------------------------------------------------===//
  // Visitation functions... these functions provide default fallbacks in case
  // the user does not specify what to do for a particular instruction type.
  // The defaults do nothing; SubClass replaces one by declaring a method with
  // the same name (resolved statically, there are no virtual functions).
  // When visiting a specific inst directly during a walk, these methods get
  // called. These are typically O(1) complexity and shouldn't be recursively
  // processing their descendants in some way. When using RetTy, all of these
  // need to be overridden.
  void visitForInst(ForInst *forInst) {}
  void visitOperationInst(OperationInst *opInst) {}
  void visitInstruction(Instruction *inst) {}
};
} // end namespace mlir
#endif // MLIR_IR_INSTVISITOR_H