| //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| // |
| // This file implements the MLIR AsmPrinter class, which is used to implement |
| // the various print() methods on the core IR objects. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/CFGFunction.h" |
| #include "mlir/IR/IntegerSet.h" |
| #include "mlir/IR/MLFunction.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/OperationSet.h" |
| #include "mlir/IR/StandardOps.h" |
| #include "mlir/IR/Statements.h" |
| #include "mlir/IR/StmtVisitor.h" |
| #include "mlir/IR/Types.h" |
| #include "mlir/Support/STLExtras.h" |
| #include "llvm/ADT/APFloat.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/SmallString.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringSet.h" |
| using namespace mlir; |
| |
| void Identifier::print(raw_ostream &os) const { os << str(); } |
| |
| void Identifier::dump() const { print(llvm::errs()); } |
| |
| OpAsmPrinter::~OpAsmPrinter() {} |
| |
| //===----------------------------------------------------------------------===// |
| // ModuleState |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class ModuleState { |
| public: |
| /// This is the operation set for the current context if it is knowable (a |
| /// context could be determined), otherwise this is null. |
| OperationSet *const operationSet; |
| |
| explicit ModuleState(MLIRContext *context) |
| : operationSet(context ? &OperationSet::get(context) : nullptr) {} |
| |
| // Initializes module state, populating affine map state. |
| void initialize(const Module *module); |
| |
| int getAffineMapId(const AffineMap *affineMap) const { |
| auto it = affineMapIds.find(affineMap); |
| if (it == affineMapIds.end()) { |
| return -1; |
| } |
| return it->second; |
| } |
| |
| ArrayRef<const AffineMap *> getAffineMapIds() const { return affineMapsById; } |
| |
| int getIntegerSetId(const IntegerSet *integerSet) const { |
| auto it = integerSetIds.find(integerSet); |
| if (it == integerSetIds.end()) { |
| return -1; |
| } |
| return it->second; |
| } |
| |
| ArrayRef<const IntegerSet *> getIntegerSetIds() const { |
| return integerSetsById; |
| } |
| |
| private: |
| void recordAffineMapReference(const AffineMap *affineMap) { |
| if (affineMapIds.count(affineMap) == 0) { |
| affineMapIds[affineMap] = affineMapsById.size(); |
| affineMapsById.push_back(affineMap); |
| } |
| } |
| |
| void recordIntegerSetReference(const IntegerSet *integerSet) { |
| if (integerSetIds.count(integerSet) == 0) { |
| integerSetIds[integerSet] = integerSetsById.size(); |
| integerSetsById.push_back(integerSet); |
| } |
| } |
| |
| // Return true if this map could be printed using the shorthand form. |
| static bool hasShorthandForm(const AffineMap *boundMap) { |
| if (boundMap->isSingleConstant()) |
| return true; |
| |
| // Check if the affine map is single dim id or single symbol identity - |
| // (i)->(i) or ()[s]->(i) |
| return boundMap->getNumInputs() == 1 && boundMap->getNumResults() == 1 && |
| (isa<AffineDimExpr>(boundMap->getResult(0)) || |
| isa<AffineSymbolExpr>(boundMap->getResult(0))); |
| } |
| |
| // Visit functions. |
| void visitFunction(const Function *fn); |
| void visitExtFunction(const ExtFunction *fn); |
| void visitCFGFunction(const CFGFunction *fn); |
| void visitMLFunction(const MLFunction *fn); |
| void visitStatement(const Statement *stmt); |
| void visitForStmt(const ForStmt *forStmt); |
| void visitIfStmt(const IfStmt *ifStmt); |
| void visitOperationStmt(const OperationStmt *opStmt); |
| void visitType(const Type *type); |
| void visitAttribute(const Attribute *attr); |
| void visitOperation(const Operation *op); |
| |
| DenseMap<const AffineMap *, int> affineMapIds; |
| std::vector<const AffineMap *> affineMapsById; |
| |
| DenseMap<const IntegerSet *, int> integerSetIds; |
| std::vector<const IntegerSet *> integerSetsById; |
| }; |
| } // end anonymous namespace |
| |
| // TODO Support visiting other types/instructions when implemented. |
| void ModuleState::visitType(const Type *type) { |
| if (auto *funcType = dyn_cast<FunctionType>(type)) { |
| // Visit input and result types for functions. |
| for (auto *input : funcType->getInputs()) |
| visitType(input); |
| for (auto *result : funcType->getResults()) |
| visitType(result); |
| } else if (auto *memref = dyn_cast<MemRefType>(type)) { |
| // Visit affine maps in memref type. |
| for (auto *map : memref->getAffineMaps()) { |
| recordAffineMapReference(map); |
| } |
| } |
| } |
| |
| void ModuleState::visitAttribute(const Attribute *attr) { |
| if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) { |
| recordAffineMapReference(mapAttr->getValue()); |
| } else if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) { |
| for (auto elt : arrayAttr->getValue()) { |
| visitAttribute(elt); |
| } |
| } |
| } |
| |
| void ModuleState::visitOperation(const Operation *op) { |
| // Visit all the types used in the operation. |
| for (auto *operand : op->getOperands()) |
| visitType(operand->getType()); |
| for (auto *result : op->getResults()) |
| visitType(result->getType()); |
| |
| // Visit each of the attributes. |
| for (auto elt : op->getAttrs()) |
| visitAttribute(elt.second); |
| } |
| |
| void ModuleState::visitExtFunction(const ExtFunction *fn) { |
| visitType(fn->getType()); |
| } |
| |
| void ModuleState::visitCFGFunction(const CFGFunction *fn) { |
| visitType(fn->getType()); |
| for (auto &block : *fn) { |
| for (auto &op : block.getOperations()) { |
| visitOperation(&op); |
| } |
| } |
| } |
| |
| void ModuleState::visitIfStmt(const IfStmt *ifStmt) { |
| recordIntegerSetReference(ifStmt->getIntegerSet()); |
| for (auto &childStmt : *ifStmt->getThen()) |
| visitStatement(&childStmt); |
| if (ifStmt->hasElse()) |
| for (auto &childStmt : *ifStmt->getElse()) |
| visitStatement(&childStmt); |
| } |
| |
| void ModuleState::visitForStmt(const ForStmt *forStmt) { |
| AffineMap *lbMap = forStmt->getLowerBoundMap(); |
| if (!hasShorthandForm(lbMap)) |
| recordAffineMapReference(lbMap); |
| |
| AffineMap *ubMap = forStmt->getUpperBoundMap(); |
| if (!hasShorthandForm(ubMap)) |
| recordAffineMapReference(ubMap); |
| |
| for (auto &childStmt : *forStmt) |
| visitStatement(&childStmt); |
| } |
| |
| void ModuleState::visitOperationStmt(const OperationStmt *opStmt) { |
| for (auto attr : opStmt->getAttrs()) |
| visitAttribute(attr.second); |
| } |
| |
| void ModuleState::visitStatement(const Statement *stmt) { |
| switch (stmt->getKind()) { |
| case Statement::Kind::If: |
| return visitIfStmt(cast<IfStmt>(stmt)); |
| case Statement::Kind::For: |
| return visitForStmt(cast<ForStmt>(stmt)); |
| case Statement::Kind::Operation: |
| return visitOperationStmt(cast<OperationStmt>(stmt)); |
| default: |
| return; |
| } |
| } |
| |
| void ModuleState::visitMLFunction(const MLFunction *fn) { |
| visitType(fn->getType()); |
| for (auto &stmt : *fn) { |
| ModuleState::visitStatement(&stmt); |
| } |
| } |
| |
| void ModuleState::visitFunction(const Function *fn) { |
| switch (fn->getKind()) { |
| case Function::Kind::ExtFunc: |
| return visitExtFunction(cast<ExtFunction>(fn)); |
| case Function::Kind::CFGFunc: |
| return visitCFGFunction(cast<CFGFunction>(fn)); |
| case Function::Kind::MLFunc: |
| return visitMLFunction(cast<MLFunction>(fn)); |
| } |
| } |
| |
| // Initializes module state, populating affine map and integer set state. |
| void ModuleState::initialize(const Module *module) { |
| for (auto &fn : *module) { |
| visitFunction(&fn); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ModulePrinter |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class ModulePrinter { |
| public: |
| ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {} |
| explicit ModulePrinter(const ModulePrinter &printer) |
| : os(printer.os), state(printer.state) {} |
| |
| template <typename Container, typename UnaryFunctor> |
| inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { |
| interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; }); |
| } |
| |
| void print(const Module *module); |
| void printFunctionReference(const Function *func); |
| void printAttribute(const Attribute *attr); |
| void printType(const Type *type); |
| void print(const Function *fn); |
| void print(const ExtFunction *fn); |
| void print(const CFGFunction *fn); |
| void print(const MLFunction *fn); |
| |
| void printAffineMap(const AffineMap *map); |
| void printAffineExpr(const AffineExpr *expr); |
| void printAffineConstraint(const AffineExpr *expr, bool isEq); |
| void printIntegerSet(const IntegerSet *set); |
| |
| protected: |
| raw_ostream &os; |
| ModuleState &state; |
| |
| void printFunctionSignature(const Function *fn); |
| void printFunctionResultType(const FunctionType *type); |
| void printAffineMapId(int affineMapId) const; |
| void printAffineMapReference(const AffineMap *affineMap); |
| void printIntegerSetId(int integerSetId) const; |
| void printIntegerSetReference(const IntegerSet *integerSet); |
| |
| /// This enum is used to represent the binding stength of the enclosing |
| /// context that an AffineExpr is being printed in, so we can intelligently |
| /// produce parens. |
| enum class BindingStrength { |
| Weak, // + and - |
| Strong, // All other binary operators. |
| }; |
| void printAffineExprInternal(const AffineExpr *expr, |
| BindingStrength enclosingTightness); |
| }; |
| } // end anonymous namespace |
| |
| // Prints function with initialized module state. |
| void ModulePrinter::print(const Function *fn) { |
| switch (fn->getKind()) { |
| case Function::Kind::ExtFunc: |
| return print(cast<ExtFunction>(fn)); |
| case Function::Kind::CFGFunc: |
| return print(cast<CFGFunction>(fn)); |
| case Function::Kind::MLFunc: |
| return print(cast<MLFunction>(fn)); |
| } |
| } |
| |
| // Prints affine map identifier. |
| void ModulePrinter::printAffineMapId(int affineMapId) const { |
| os << "#map" << affineMapId; |
| } |
| |
| void ModulePrinter::printAffineMapReference(const AffineMap *affineMap) { |
| int mapId = state.getAffineMapId(affineMap); |
| if (mapId >= 0) { |
| // Map will be printed at top of module so print reference to its id. |
| printAffineMapId(mapId); |
| } else { |
| // Map not in module state so print inline. |
| affineMap->print(os); |
| } |
| } |
| |
| // Prints integer set identifier. |
| void ModulePrinter::printIntegerSetId(int integerSetId) const { |
| os << "@@set" << integerSetId; |
| } |
| |
| void ModulePrinter::printIntegerSetReference(const IntegerSet *integerSet) { |
| int setId; |
| if ((setId = state.getIntegerSetId(integerSet)) >= 0) { |
| // The set will be printed at top of module; so print reference to its id. |
| printIntegerSetId(setId); |
| } else { |
| // Set not in module state so print inline. |
| integerSet->print(os); |
| } |
| } |
| |
| void ModulePrinter::print(const Module *module) { |
| for (const auto &map : state.getAffineMapIds()) { |
| printAffineMapId(state.getAffineMapId(map)); |
| os << " = "; |
| map->print(os); |
| os << '\n'; |
| } |
| for (const auto &set : state.getIntegerSetIds()) { |
| printIntegerSetId(state.getIntegerSetId(set)); |
| os << " = "; |
| set->print(os); |
| os << '\n'; |
| } |
| for (auto const &fn : *module) |
| print(&fn); |
| } |
| |
| /// Print a floating point value in a way that the parser will be able to |
| /// round-trip losslessly. |
| static void printFloatValue(double value, raw_ostream &os) { |
| APFloat apValue(value); |
| |
| // We would like to output the FP constant value in exponential notation, |
| // but we cannot do this if doing so will lose precision. Check here to |
| // make sure that we only output it in exponential format if we can parse |
| // the value back and get the same value. |
| bool isInf = apValue.isInfinity(); |
| bool isNaN = apValue.isNaN(); |
| if (!isInf && !isNaN) { |
| SmallString<128> strValue; |
| apValue.toString(strValue, 6, 0, false); |
| |
| // Check to make sure that the stringized number is not some string like |
| // "Inf" or NaN, that atof will accept, but the lexer will not. Check |
| // that the string matches the "[-+]?[0-9]" regex. |
| assert(((strValue[0] >= '0' && strValue[0] <= '9') || |
| ((strValue[0] == '-' || strValue[0] == '+') && |
| (strValue[1] >= '0' && strValue[1] <= '9'))) && |
| "[-+]?[0-9] regex does not match!"); |
| // Reparse stringized version! |
| if (APFloat(APFloat::IEEEdouble(), strValue).convertToDouble() == value) { |
| os << strValue; |
| return; |
| } |
| } |
| |
| // Otherwise, print it in a hexadecimal form. Convert it to an integer so we |
| // can print it out using integer math. |
| union { |
| double doubleValue; |
| uint64_t integerValue; |
| }; |
| doubleValue = value; |
| os << "0x"; |
| // Print out 16 nibbles worth of hex digit. |
| for (unsigned i = 0; i != 16; ++i) { |
| os << llvm::hexdigit(integerValue >> 60); |
| integerValue <<= 4; |
| } |
| } |
| |
| void ModulePrinter::printFunctionReference(const Function *func) { |
| os << '@' << func->getName(); |
| } |
| |
| void ModulePrinter::printAttribute(const Attribute *attr) { |
| switch (attr->getKind()) { |
| case Attribute::Kind::Bool: |
| os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false"); |
| break; |
| case Attribute::Kind::Integer: |
| os << cast<IntegerAttr>(attr)->getValue(); |
| break; |
| case Attribute::Kind::Float: |
| printFloatValue(cast<FloatAttr>(attr)->getValue(), os); |
| break; |
| case Attribute::Kind::String: |
| os << '"'; |
| printEscapedString(cast<StringAttr>(attr)->getValue(), os); |
| os << '"'; |
| break; |
| case Attribute::Kind::Array: |
| os << '['; |
| interleaveComma(cast<ArrayAttr>(attr)->getValue(), |
| [&](Attribute *attr) { printAttribute(attr); }); |
| os << ']'; |
| break; |
| case Attribute::Kind::AffineMap: |
| printAffineMapReference(cast<AffineMapAttr>(attr)->getValue()); |
| break; |
| case Attribute::Kind::Type: |
| printType(cast<TypeAttr>(attr)->getValue()); |
| break; |
| case Attribute::Kind::Function: { |
| auto *function = cast<FunctionAttr>(attr)->getValue(); |
| if (!function) { |
| os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>"; |
| } else { |
| printFunctionReference(function); |
| os << " : "; |
| printType(function->getType()); |
| } |
| break; |
| } |
| } |
| } |
| |
| void ModulePrinter::printType(const Type *type) { |
| switch (type->getKind()) { |
| case Type::Kind::AffineInt: |
| os << "affineint"; |
| return; |
| case Type::Kind::BF16: |
| os << "bf16"; |
| return; |
| case Type::Kind::F16: |
| os << "f16"; |
| return; |
| case Type::Kind::F32: |
| os << "f32"; |
| return; |
| case Type::Kind::F64: |
| os << "f64"; |
| return; |
| case Type::Kind::TFControl: |
| os << "tf_control"; |
| return; |
| case Type::Kind::TFString: |
| os << "tf_string"; |
| return; |
| |
| case Type::Kind::Integer: { |
| auto *integer = cast<IntegerType>(type); |
| os << 'i' << integer->getWidth(); |
| return; |
| } |
| case Type::Kind::Function: { |
| auto *func = cast<FunctionType>(type); |
| os << '('; |
| interleaveComma(func->getInputs(), [&](Type *type) { printType(type); }); |
| os << ") -> "; |
| auto results = func->getResults(); |
| if (results.size() == 1) |
| os << *results[0]; |
| else { |
| os << '('; |
| interleaveComma(results, [&](Type *type) { printType(type); }); |
| os << ')'; |
| } |
| return; |
| } |
| case Type::Kind::Vector: { |
| auto *v = cast<VectorType>(type); |
| os << "vector<"; |
| for (auto dim : v->getShape()) |
| os << dim << 'x'; |
| os << *v->getElementType() << '>'; |
| return; |
| } |
| case Type::Kind::RankedTensor: { |
| auto *v = cast<RankedTensorType>(type); |
| os << "tensor<"; |
| for (auto dim : v->getShape()) { |
| if (dim < 0) |
| os << '?'; |
| else |
| os << dim; |
| os << 'x'; |
| } |
| os << *v->getElementType() << '>'; |
| return; |
| } |
| case Type::Kind::UnrankedTensor: { |
| auto *v = cast<UnrankedTensorType>(type); |
| os << "tensor<*x"; |
| printType(v->getElementType()); |
| os << '>'; |
| return; |
| } |
| case Type::Kind::MemRef: { |
| auto *v = cast<MemRefType>(type); |
| os << "memref<"; |
| for (auto dim : v->getShape()) { |
| if (dim < 0) |
| os << '?'; |
| else |
| os << dim; |
| os << 'x'; |
| } |
| printType(v->getElementType()); |
| for (auto map : v->getAffineMaps()) { |
| os << ", "; |
| printAffineMapReference(map); |
| } |
| // Only print the memory space if it is the non-default one. |
| if (v->getMemorySpace()) |
| os << ", " << v->getMemorySpace(); |
| os << '>'; |
| return; |
| } |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Affine expressions and maps |
| //===----------------------------------------------------------------------===// |
| |
| void ModulePrinter::printAffineExpr(const AffineExpr *expr) { |
| printAffineExprInternal(expr, BindingStrength::Weak); |
| } |
| |
| void ModulePrinter::printAffineExprInternal( |
| const AffineExpr *expr, BindingStrength enclosingTightness) { |
| const char *binopSpelling = nullptr; |
| switch (expr->getKind()) { |
| case AffineExpr::Kind::SymbolId: |
| os << 's' << cast<AffineSymbolExpr>(expr)->getPosition(); |
| return; |
| case AffineExpr::Kind::DimId: |
| os << 'd' << cast<AffineDimExpr>(expr)->getPosition(); |
| return; |
| case AffineExpr::Kind::Constant: |
| os << cast<AffineConstantExpr>(expr)->getValue(); |
| return; |
| case AffineExpr::Kind::Add: |
| binopSpelling = " + "; |
| break; |
| case AffineExpr::Kind::Mul: |
| binopSpelling = " * "; |
| break; |
| case AffineExpr::Kind::FloorDiv: |
| binopSpelling = " floordiv "; |
| break; |
| case AffineExpr::Kind::CeilDiv: |
| binopSpelling = " ceildiv "; |
| break; |
| case AffineExpr::Kind::Mod: |
| binopSpelling = " mod "; |
| break; |
| } |
| |
| auto *binOp = cast<AffineBinaryOpExpr>(expr); |
| |
| // Handle tightly binding binary operators. |
| if (binOp->getKind() != AffineExpr::Kind::Add) { |
| if (enclosingTightness == BindingStrength::Strong) |
| os << '('; |
| |
| printAffineExprInternal(binOp->getLHS(), BindingStrength::Strong); |
| os << binopSpelling; |
| printAffineExprInternal(binOp->getRHS(), BindingStrength::Strong); |
| |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| return; |
| } |
| |
| // Print out special "pretty" forms for add. |
| if (enclosingTightness == BindingStrength::Strong) |
| os << '('; |
| |
| // Pretty print addition to a product that has a negative operand as a |
| // subtraction. |
| if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(binOp->getRHS())) { |
| if (rhs->getKind() == AffineExpr::Kind::Mul) { |
| if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) { |
| if (rrhs->getValue() == -1) { |
| printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak); |
| os << " - "; |
| printAffineExprInternal(rhs->getLHS(), BindingStrength::Weak); |
| |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| return; |
| } |
| |
| if (rrhs->getValue() < -1) { |
| printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak); |
| os << " - "; |
| printAffineExprInternal(rhs->getLHS(), BindingStrength::Strong); |
| os << " * " << -rrhs->getValue(); |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| return; |
| } |
| } |
| } |
| } |
| |
| // Pretty print addition to a negative number as a subtraction. |
| if (auto *rhs = dyn_cast<AffineConstantExpr>(binOp->getRHS())) { |
| if (rhs->getValue() < 0) { |
| printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak); |
| os << " - " << -rhs->getValue(); |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| return; |
| } |
| } |
| |
| printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak); |
| os << " + "; |
| printAffineExprInternal(binOp->getRHS(), BindingStrength::Weak); |
| |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| } |
| |
| void ModulePrinter::printAffineConstraint(const AffineExpr *expr, bool isEq) { |
| printAffineExprInternal(expr, BindingStrength::Weak); |
| isEq ? os << " == 0" : os << " >= 0"; |
| } |
| |
| void ModulePrinter::printAffineMap(const AffineMap *map) { |
| // Dimension identifiers. |
| os << '('; |
| for (int i = 0; i < (int)map->getNumDims() - 1; ++i) |
| os << 'd' << i << ", "; |
| if (map->getNumDims() >= 1) |
| os << 'd' << map->getNumDims() - 1; |
| os << ')'; |
| |
| // Symbolic identifiers. |
| if (map->getNumSymbols() != 0) { |
| os << '['; |
| for (unsigned i = 0; i < map->getNumSymbols() - 1; ++i) |
| os << 's' << i << ", "; |
| if (map->getNumSymbols() >= 1) |
| os << 's' << map->getNumSymbols() - 1; |
| os << ']'; |
| } |
| |
| // AffineMap should have at least one result. |
| assert(!map->getResults().empty()); |
| // Result affine expressions. |
| os << " -> ("; |
| interleaveComma(map->getResults(), |
| [&](AffineExpr *expr) { printAffineExpr(expr); }); |
| os << ')'; |
| |
| if (!map->isBounded()) { |
| return; |
| } |
| |
| // Print range sizes for bounded affine maps. |
| os << " size ("; |
| interleaveComma(map->getRangeSizes(), |
| [&](AffineExpr *expr) { printAffineExpr(expr); }); |
| os << ')'; |
| } |
| |
| void ModulePrinter::printIntegerSet(const IntegerSet *set) { |
| // Dimension identifiers. |
| os << '('; |
| for (unsigned i = 1; i < set->getNumDims(); ++i) |
| os << 'd' << i - 1 << ", "; |
| if (set->getNumDims() >= 1) |
| os << 'd' << set->getNumDims() - 1; |
| os << ')'; |
| |
| // Symbolic identifiers. |
| if (set->getNumSymbols() != 0) { |
| os << '['; |
| for (unsigned i = 0; i < set->getNumSymbols() - 1; ++i) |
| os << 's' << i << ", "; |
| if (set->getNumSymbols() >= 1) |
| os << 's' << set->getNumSymbols() - 1; |
| os << ']'; |
| } |
| |
| // Print constraints. |
| os << " : ("; |
| auto numConstraints = set->getNumConstraints(); |
| for (int i = 1; i < numConstraints; ++i) { |
| printAffineConstraint(set->getConstraint(i - 1), set->isEq(i - 1)); |
| os << ", "; |
| } |
| if (numConstraints >= 1) |
| printAffineConstraint(set->getConstraint(numConstraints - 1), |
| set->isEq(numConstraints - 1)); |
| os << ')'; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Function printing |
| //===----------------------------------------------------------------------===// |
| |
| void ModulePrinter::printFunctionResultType(const FunctionType *type) { |
| switch (type->getResults().size()) { |
| case 0: |
| break; |
| case 1: |
| os << " -> "; |
| printType(type->getResults()[0]); |
| break; |
| default: |
| os << " -> ("; |
| interleaveComma(type->getResults(), |
| [&](Type *eltType) { printType(eltType); }); |
| os << ')'; |
| break; |
| } |
| } |
| |
| void ModulePrinter::printFunctionSignature(const Function *fn) { |
| auto type = fn->getType(); |
| |
| os << "@" << fn->getName() << '('; |
| interleaveComma(type->getInputs(), |
| [&](Type *eltType) { printType(eltType); }); |
| os << ')'; |
| |
| printFunctionResultType(type); |
| } |
| |
| void ModulePrinter::print(const ExtFunction *fn) { |
| os << "extfunc "; |
| printFunctionSignature(fn); |
| os << '\n'; |
| } |
| |
namespace {

// FunctionPrinter contains common functionality for printing
// CFG and ML functions.
//
// It implements the OpAsmPrinter hooks used by custom operation printers and
// assigns printable names to SSA values. Values with an "interesting" name
// (constants, function arguments, for-loop induction variables) get a textual
// name that is made unique against all other such names. Every other value
// gets a sequential number.
class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
public:
  FunctionPrinter(const ModulePrinter &other) : ModulePrinter(other) {}

  void printOperation(const Operation *op);
  void printDefaultOp(const Operation *op);

  // Implement OpAsmPrinter.
  raw_ostream &getStream() const { return os; }
  void printType(const Type *type) { ModulePrinter::printType(type); }
  void printAttribute(const Attribute *attr) {
    ModulePrinter::printAttribute(attr);
  }
  // Affine maps and integer sets are printed as references to their
  // module-level alias when the module state recorded one, inline otherwise.
  void printAffineMap(const AffineMap *map) {
    return ModulePrinter::printAffineMapReference(map);
  }
  void printIntegerSet(const IntegerSet *set) {
    return ModulePrinter::printIntegerSetReference(set);
  }
  void printAffineExpr(const AffineExpr *expr) {
    return ModulePrinter::printAffineExpr(expr);
  }
  void printFunctionReference(const Function *func) {
    return ModulePrinter::printFunctionReference(func);
  }

  void printOperand(const SSAValue *value) { printValueID(value); }

  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<const char *> elidedAttrs = {}) override;

  // Marker stored in valueIDs for values whose name is kept in valueNames
  // instead of being a plain number.
  enum { nameSentinel = ~0U };

protected:
  // Assigns a printable identity to `value`. Each value may be numbered only
  // once (asserted). Presumably called by the CFG/ML function printers before
  // printing a function body; the callers are not visible here.
  void numberValueID(const SSAValue *value) {
    assert(!valueIDs.count(value) && "Value numbered multiple times");

    // raw_svector_ostream writes straight into specialNameBuffer, so an empty
    // buffer below means no special name was chosen.
    SmallString<32> specialNameBuffer;
    llvm::raw_svector_ostream specialName(specialNameBuffer);

    // Give constants special names: i1 constants become "true"/"false",
    // other integer constants "c<value>_<type>", affine integer constants
    // "c<value>", function-valued constants "f", and other constants "cst".
    if (auto *op = value->getDefiningOperation()) {
      if (auto intOp = op->getAs<ConstantIntOp>()) {
        // i1 constants get special names.
        if (intOp->getType()->isInteger(1)) {
          specialName << (intOp->getValue() ? "true" : "false");
        } else {
          specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
        }
      } else if (auto intOp = op->getAs<ConstantAffineIntOp>()) {
        specialName << 'c' << intOp->getValue();
      } else if (auto constant = op->getAs<ConstantOp>()) {
        if (isa<FunctionAttr>(constant->getValue()))
          specialName << 'f';
        else
          specialName << "cst";
      }
    }

    if (specialNameBuffer.empty()) {
      switch (value->getKind()) {
      case SSAValueKind::BBArgument:
        // If this is an argument to the function, give it an 'arg' name.
        // Only arguments of the function's entry block qualify.
        if (auto *bb = cast<BBArgument>(value)->getOwner())
          if (auto *fn = bb->getFunction())
            if (&fn->front() == bb) {
              specialName << "arg" << nextArgumentID++;
              break;
            }
        // Otherwise number it normally.
        LLVM_FALLTHROUGH;
      case SSAValueKind::InstResult:
      case SSAValueKind::StmtResult:
        // This is an uninteresting result, give it a boring number and be
        // done with it.
        valueIDs[value] = nextValueID++;
        return;
      case SSAValueKind::MLFuncArgument:
        specialName << "arg" << nextArgumentID++;
        break;
      case SSAValueKind::ForStmt:
        // For-loop induction variables are named i0, i1, ...
        specialName << 'i' << nextLoopID++;
        break;
      }
    }

    // Ok, this value had an interesting name. Remember it with a sentinel.
    valueIDs[value] = nameSentinel;

    // Remember that we've used this name, checking to see if we had a conflict.
    auto insertRes = usedNames.insert(specialName.str());
    if (insertRes.second) {
      // If this is the first use of the name, then we're successful!
      // The stored StringRef points at the key owned by usedNames.
      valueNames[value] = insertRes.first->first();
      return;
    }

    // Otherwise, we had a conflict - probe until we find a unique name. This
    // is guaranteed to terminate (and usually in a single iteration) because it
    // generates new names by incrementing nextConflictID.
    while (1) {
      std::string probeName =
          specialName.str().str() + "_" + llvm::utostr(nextConflictID++);
      insertRes = usedNames.insert(probeName);
      if (insertRes.second) {
        // If this is the first use of the name, then we're successful!
        valueNames[value] = insertRes.first->first();
        return;
      }
    }
  }

  // Prints `value` as "%<number>" or "%<name>". A result of a multi-result
  // instruction/statement is printed as its first result's identity followed
  // by "#<result number>" (the suffix is suppressed when printResultNo is
  // false). A value that was never numbered prints "<<INVALID SSA VALUE>>".
  void printValueID(const SSAValue *value, bool printResultNo = true) const {
    int resultNo = -1;
    auto lookupValue = value;

    // If this is a reference to the result of a multi-result instruction or
    // statement, print out the # identifier and make sure to map our lookup
    // to the first result of the instruction.
    if (auto *result = dyn_cast<InstResult>(value)) {
      if (result->getOwner()->getNumResults() != 1) {
        resultNo = result->getResultNumber();
        lookupValue = result->getOwner()->getResult(0);
      }
    } else if (auto *result = dyn_cast<StmtResult>(value)) {
      if (result->getOwner()->getNumResults() != 1) {
        resultNo = result->getResultNumber();
        lookupValue = result->getOwner()->getResult(0);
      }
    }

    auto it = valueIDs.find(lookupValue);
    if (it == valueIDs.end()) {
      os << "<<INVALID SSA VALUE>>";
      return;
    }

    os << '%';
    if (it->second != nameSentinel) {
      os << it->second;
    } else {
      auto nameIt = valueNames.find(lookupValue);
      assert(nameIt != valueNames.end() && "Didn't have a name entry?");
      os << nameIt->second;
    }

    if (resultNo != -1 && printResultNo)
      os << '#' << resultNo;
  }

private:
  /// This is the value ID for each SSA value in the current function. If the
  /// stored ID is nameSentinel (~0U), the value's name is in valueNames.
  DenseMap<const SSAValue *, unsigned> valueIDs;
  /// Textual names of values; the StringRefs point at keys owned by usedNames.
  DenseMap<const SSAValue *, StringRef> valueNames;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  llvm::StringSet<> usedNames;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the ID to assign to the next induction variable.
  unsigned nextLoopID = 0;
  /// This is the next ID to assign to a function argument: both MLFunction
  /// arguments and arguments of a CFG function's entry block use it.
  unsigned nextArgumentID = 0;

  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;
};
} // end anonymous namespace
| |
| void FunctionPrinter::printOptionalAttrDict( |
| ArrayRef<NamedAttribute> attrs, ArrayRef<const char *> elidedAttrs) { |
| // If there are no attributes, then there is nothing to be done. |
| if (attrs.empty()) |
| return; |
| |
| // Filter out any attributes that shouldn't be included. |
| SmallVector<NamedAttribute, 8> filteredAttrs; |
| for (auto attr : attrs) { |
| auto attrName = attr.first.strref(); |
| // Never print attributes that start with a colon. These are internal |
| // attributes that represent location or other internal metadata. |
| if (attrName.startswith(":")) |
| continue; |
| |
| // If the caller has requested that this attribute be ignored, then drop it. |
| bool ignore = false; |
| for (const char *elide : elidedAttrs) |
| ignore |= attrName == StringRef(elide); |
| |
| // Otherwise add it to our filteredAttrs list. |
| if (!ignore) |
| filteredAttrs.push_back(attr); |
| } |
| |
| // If there are no attributes left to print after filtering, then we're done. |
| if (filteredAttrs.empty()) |
| return; |
| |
| // Otherwise, print them all out in braces. |
| os << " {"; |
| interleaveComma(filteredAttrs, [&](NamedAttribute attr) { |
| os << attr.first << ": "; |
| printAttribute(attr.second); |
| }); |
| os << '}'; |
| } |
| |
| void FunctionPrinter::printOperation(const Operation *op) { |
| if (op->getNumResults()) { |
| printValueID(op->getResult(0), /*printResultNo=*/false); |
| os << " = "; |
| } |
| |
| // Check to see if this is a known operation. If so, use the registered |
| // custom printer hook. |
| if (auto *opInfo = state.operationSet->lookup(op->getName())) { |
| opInfo->printAssembly(op, this); |
| return; |
| } |
| |
| // Otherwise use the standard verbose printing approach. |
| printDefaultOp(op); |
| } |
| |
| void FunctionPrinter::printDefaultOp(const Operation *op) { |
| os << '"'; |
| printEscapedString(op->getName(), os); |
| os << "\"("; |
| |
| interleaveComma(op->getOperands(), |
| [&](const SSAValue *value) { printValueID(value); }); |
| |
| os << ')'; |
| auto attrs = op->getAttrs(); |
| printOptionalAttrDict(attrs); |
| |
| // Print the type signature of the operation. |
| os << " : ("; |
| interleaveComma(op->getOperands(), |
| [&](const SSAValue *value) { printType(value->getType()); }); |
| os << ") -> "; |
| |
| if (op->getNumResults() == 1) { |
| printType(op->getResult(0)->getType()); |
| } else { |
| os << '('; |
| interleaveComma(op->getResults(), [&](const SSAValue *result) { |
| printType(result->getType()); |
| }); |
| os << ')'; |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CFG Function printing |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class CFGFunctionPrinter : public FunctionPrinter { |
| public: |
| CFGFunctionPrinter(const CFGFunction *function, const ModulePrinter &other); |
| |
| const CFGFunction *getFunction() const { return function; } |
| |
| void print(); |
| void print(const BasicBlock *block); |
| |
| void print(const Instruction *inst); |
| void print(const OperationInst *inst); |
| void print(const ReturnInst *inst); |
| void print(const BranchInst *inst); |
| void print(const CondBranchInst *inst); |
| |
| unsigned getBBID(const BasicBlock *block) { |
| auto it = basicBlockIDs.find(block); |
| assert(it != basicBlockIDs.end() && "Block not in this function?"); |
| return it->second; |
| } |
| |
| private: |
| const CFGFunction *function; |
| DenseMap<const BasicBlock *, unsigned> basicBlockIDs; |
| |
| void numberValuesInBlock(const BasicBlock *block); |
| }; |
| } // end anonymous namespace |
| |
| CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function, |
| const ModulePrinter &other) |
| : FunctionPrinter(other), function(function) { |
| // Each basic block gets a unique ID per function. |
| unsigned blockID = 0; |
| for (auto &block : *function) { |
| basicBlockIDs[&block] = blockID++; |
| numberValuesInBlock(&block); |
| } |
| } |
| |
| /// Number all of the SSA values in the specified basic block. |
| void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) { |
| for (auto *arg : block->getArguments()) { |
| numberValueID(arg); |
| } |
| for (auto &op : *block) { |
| // We number instruction that have results, and we only number the first |
| // result. |
| if (op.getNumResults() != 0) |
| numberValueID(op.getResult(0)); |
| } |
| |
| // Terminators do not define values. |
| } |
| |
| void CFGFunctionPrinter::print() { |
| os << "cfgfunc "; |
| printFunctionSignature(getFunction()); |
| os << " {\n"; |
| |
| for (auto &block : *function) |
| print(&block); |
| os << "}\n\n"; |
| } |
| |
| void CFGFunctionPrinter::print(const BasicBlock *block) { |
| os << "bb" << getBBID(block); |
| |
| if (!block->args_empty()) { |
| os << '('; |
| interleaveComma(block->getArguments(), [&](const BBArgument *arg) { |
| printValueID(arg); |
| os << ": "; |
| printType(arg->getType()); |
| }); |
| os << ')'; |
| } |
| os << ':'; |
| |
| // Print out some context information about the predecessors of this block. |
| if (!block->getFunction()) { |
| os << "\t// block is not in a function!"; |
| } else if (block->hasNoPredecessors()) { |
| // Don't print "no predecessors" for the entry block. |
| if (block != &block->getFunction()->front()) |
| os << "\t// no predecessors"; |
| } else if (auto *pred = block->getSinglePredecessor()) { |
| os << "\t// pred: bb" << getBBID(pred); |
| } else { |
| // We want to print the predecessors in increasing numeric order, not in |
| // whatever order the use-list is in, so gather and sort them. |
| SmallVector<unsigned, 4> predIDs; |
| for (auto *pred : block->getPredecessors()) |
| predIDs.push_back(getBBID(pred)); |
| llvm::array_pod_sort(predIDs.begin(), predIDs.end()); |
| |
| os << "\t// " << predIDs.size() << " preds: "; |
| |
| interleaveComma(predIDs, [&](unsigned predID) { os << "bb" << predID; }); |
| } |
| os << '\n'; |
| |
| for (auto &inst : block->getOperations()) { |
| os << " "; |
| print(&inst); |
| os << '\n'; |
| } |
| |
| os << " "; |
| print(block->getTerminator()); |
| os << '\n'; |
| } |
| |
| void CFGFunctionPrinter::print(const Instruction *inst) { |
| if (!inst) { |
| os << "<<null instruction>>\n"; |
| return; |
| } |
| switch (inst->getKind()) { |
| case Instruction::Kind::Operation: |
| return print(cast<OperationInst>(inst)); |
| case TerminatorInst::Kind::Branch: |
| return print(cast<BranchInst>(inst)); |
| case TerminatorInst::Kind::CondBranch: |
| return print(cast<CondBranchInst>(inst)); |
| case TerminatorInst::Kind::Return: |
| return print(cast<ReturnInst>(inst)); |
| } |
| } |
| |
| void CFGFunctionPrinter::print(const OperationInst *inst) { |
| printOperation(inst); |
| } |
| |
| void CFGFunctionPrinter::print(const BranchInst *inst) { |
| os << "br bb" << getBBID(inst->getDest()); |
| |
| if (inst->getNumOperands() != 0) { |
| os << '('; |
| interleaveComma(inst->getOperands(), |
| [&](const CFGValue *operand) { printValueID(operand); }); |
| os << ") : "; |
| interleaveComma(inst->getOperands(), [&](const CFGValue *operand) { |
| printType(operand->getType()); |
| }); |
| } |
| } |
| |
| void CFGFunctionPrinter::print(const CondBranchInst *inst) { |
| os << "cond_br "; |
| printValueID(inst->getCondition()); |
| |
| os << ", bb" << getBBID(inst->getTrueDest()); |
| if (inst->getNumTrueOperands() != 0) { |
| os << '('; |
| interleaveComma(inst->getTrueOperands(), |
| [&](const CFGValue *operand) { printValueID(operand); }); |
| os << " : "; |
| interleaveComma(inst->getTrueOperands(), [&](const CFGValue *operand) { |
| printType(operand->getType()); |
| }); |
| os << ")"; |
| } |
| |
| os << ", bb" << getBBID(inst->getFalseDest()); |
| if (inst->getNumFalseOperands() != 0) { |
| os << '('; |
| interleaveComma(inst->getFalseOperands(), |
| [&](const CFGValue *operand) { printValueID(operand); }); |
| os << " : "; |
| interleaveComma(inst->getFalseOperands(), [&](const CFGValue *operand) { |
| printType(operand->getType()); |
| }); |
| os << ")"; |
| } |
| } |
| |
| void CFGFunctionPrinter::print(const ReturnInst *inst) { |
| os << "return"; |
| |
| if (inst->getNumOperands() == 0) |
| return; |
| |
| os << ' '; |
| interleaveComma(inst->getOperands(), |
| [&](const CFGValue *operand) { printValueID(operand); }); |
| os << " : "; |
| interleaveComma(inst->getOperands(), [&](const CFGValue *operand) { |
| printType(operand->getType()); |
| }); |
| } |
| |
| void ModulePrinter::print(const CFGFunction *fn) { |
| CFGFunctionPrinter(fn, *this).print(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ML Function printing |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class MLFunctionPrinter : public FunctionPrinter { |
| public: |
| MLFunctionPrinter(const MLFunction *function, const ModulePrinter &other); |
| |
| const MLFunction *getFunction() const { return function; } |
| |
| // Prints ML function. |
| void print(); |
| |
| // Prints ML function signature. |
| void printFunctionSignature(); |
| |
| // Methods to print ML function statements. |
| void print(const Statement *stmt); |
| void print(const OperationStmt *stmt); |
| void print(const ForStmt *stmt); |
| void print(const IfStmt *stmt); |
| void print(const StmtBlock *block); |
| |
| // Print loop bounds. |
| void printDimAndSymbolList(ArrayRef<StmtOperand> ops, unsigned numDims); |
| void printBound(AffineBound bound, const char *prefix); |
| |
| // Number of spaces used for indenting nested statements. |
| const static unsigned indentWidth = 2; |
| |
| private: |
| void numberValues(); |
| |
| const MLFunction *function; |
| int numSpaces; |
| }; |
| } // end anonymous namespace |
| |
| MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function, |
| const ModulePrinter &other) |
| : FunctionPrinter(other), function(function), numSpaces(0) { |
| assert(function && "Cannot print nullptr function"); |
| numberValues(); |
| } |
| |
| /// Number all of the SSA values in this ML function. |
| void MLFunctionPrinter::numberValues() { |
| // Numbers ML function arguments. |
| for (auto *arg : function->getArguments()) |
| numberValueID(arg); |
| |
| // Walks ML function statements and numbers for statements and |
| // the first result of the operation statements. |
| struct NumberValuesPass : public StmtWalker<NumberValuesPass> { |
| NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {} |
| void visitOperationStmt(OperationStmt *stmt) { |
| if (stmt->getNumResults() != 0) |
| printer->numberValueID(stmt->getResult(0)); |
| } |
| void visitForStmt(ForStmt *stmt) { printer->numberValueID(stmt); } |
| MLFunctionPrinter *printer; |
| }; |
| |
| NumberValuesPass pass(this); |
| // TODO: it'd be cleaner to have constant visitor instead of using const_cast. |
| pass.walk(const_cast<MLFunction *>(function)); |
| } |
| |
| void MLFunctionPrinter::print() { |
| os << "mlfunc "; |
| printFunctionSignature(); |
| os << " {\n"; |
| print(function); |
| os << "}\n\n"; |
| } |
| |
| void MLFunctionPrinter::printFunctionSignature() { |
| auto type = function->getType(); |
| |
| os << "@" << function->getName() << '('; |
| |
| for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) { |
| if (i > 0) |
| os << ", "; |
| auto *arg = function->getArgument(i); |
| printOperand(arg); |
| os << " : "; |
| printType(arg->getType()); |
| } |
| os << ")"; |
| printFunctionResultType(type); |
| } |
| |
| void MLFunctionPrinter::print(const StmtBlock *block) { |
| numSpaces += indentWidth; |
| for (auto &stmt : block->getStatements()) { |
| print(&stmt); |
| os << "\n"; |
| } |
| numSpaces -= indentWidth; |
| } |
| |
| void MLFunctionPrinter::print(const Statement *stmt) { |
| switch (stmt->getKind()) { |
| case Statement::Kind::Operation: |
| return print(cast<OperationStmt>(stmt)); |
| case Statement::Kind::For: |
| return print(cast<ForStmt>(stmt)); |
| case Statement::Kind::If: |
| return print(cast<IfStmt>(stmt)); |
| } |
| } |
| |
| void MLFunctionPrinter::print(const OperationStmt *stmt) { |
| os.indent(numSpaces); |
| printOperation(stmt); |
| } |
| |
| void MLFunctionPrinter::print(const ForStmt *stmt) { |
| os.indent(numSpaces) << "for "; |
| printOperand(stmt); |
| os << " = "; |
| printBound(stmt->getLowerBound(), "max"); |
| os << " to "; |
| printBound(stmt->getUpperBound(), "min"); |
| |
| if (stmt->getStep() != 1) |
| os << " step " << stmt->getStep(); |
| |
| os << " {\n"; |
| print(static_cast<const StmtBlock *>(stmt)); |
| os.indent(numSpaces) << "}"; |
| } |
| |
| void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<StmtOperand> ops, |
| unsigned numDims) { |
| auto printComma = [&]() { os << ", "; }; |
| os << '('; |
| interleave(ops.begin(), ops.begin() + numDims, |
| [&](const StmtOperand &v) { printOperand(v.get()); }, printComma); |
| os << ')'; |
| |
| if (numDims < ops.size()) { |
| os << '['; |
| interleave(ops.begin() + numDims, ops.end(), |
| [&](const StmtOperand &v) { printOperand(v.get()); }, |
| printComma); |
| os << ']'; |
| } |
| } |
| |
| void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) { |
| AffineMap *map = bound.getMap(); |
| |
| // Check if this bound should be printed using short-hand notation. |
| if (map->getNumResults() == 1) { |
| AffineExpr *expr = map->getResult(0); |
| |
| // Print constant bound. |
| if (auto *constExpr = dyn_cast<AffineConstantExpr>(expr)) { |
| os << constExpr->getValue(); |
| return; |
| } |
| |
| // Print bound that consists of a single SSA id. |
| if (isa<AffineDimExpr>(expr) || isa<AffineSymbolExpr>(expr)) { |
| printOperand(bound.getOperand(0)); |
| return; |
| } |
| } else { |
| // Map has multiple results. Print 'min' or 'max' prefix. |
| os << prefix << ' '; |
| } |
| |
| // Print the map and the operands. |
| printAffineMapReference(map); |
| printDimAndSymbolList(bound.getStmtOperands(), map->getNumDims()); |
| } |
| |
| void MLFunctionPrinter::print(const IfStmt *stmt) { |
| os.indent(numSpaces) << "if "; |
| IntegerSet *set = stmt->getIntegerSet(); |
| printIntegerSetReference(set); |
| printDimAndSymbolList(stmt->getStmtOperands(), set->getNumDims()); |
| os << " {\n"; |
| print(stmt->getThen()); |
| os.indent(numSpaces) << "}"; |
| if (stmt->hasElse()) { |
| os << " else {\n"; |
| print(stmt->getElse()); |
| os.indent(numSpaces) << "}"; |
| } |
| } |
| |
| void ModulePrinter::print(const MLFunction *fn) { |
| MLFunctionPrinter(fn, *this).print(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // print and dump methods |
| //===----------------------------------------------------------------------===// |
| |
| void Attribute::print(raw_ostream &os) const { |
| ModuleState state(/*no context is known*/ nullptr); |
| ModulePrinter(os, state).printAttribute(this); |
| } |
| |
| void Attribute::dump() const { print(llvm::errs()); } |
| |
| void Type::print(raw_ostream &os) const { |
| ModuleState state(getContext()); |
| ModulePrinter(os, state).printType(this); |
| } |
| |
| void Type::dump() const { print(llvm::errs()); } |
| |
| void AffineMap::dump() const { |
| print(llvm::errs()); |
| llvm::errs() << "\n"; |
| } |
| |
| void AffineExpr::dump() const { |
| print(llvm::errs()); |
| llvm::errs() << "\n"; |
| } |
| |
| void IntegerSet::dump() const { |
| print(llvm::errs()); |
| llvm::errs() << "\n"; |
| } |
| |
| void AffineExpr::print(raw_ostream &os) const { |
| ModuleState state(/*no context is known*/ nullptr); |
| ModulePrinter(os, state).printAffineExpr(this); |
| } |
| |
| void AffineMap::print(raw_ostream &os) const { |
| ModuleState state(/*no context is known*/ nullptr); |
| ModulePrinter(os, state).printAffineMap(this); |
| } |
| |
| void IntegerSet::print(raw_ostream &os) const { |
| ModuleState state(/*no context is known*/ nullptr); |
| ModulePrinter(os, state).printIntegerSet(this); |
| } |
| |
| void SSAValue::print(raw_ostream &os) const { |
| switch (getKind()) { |
| case SSAValueKind::BBArgument: |
| // TODO: Improve this. |
| os << "<bb argument>\n"; |
| return; |
| case SSAValueKind::InstResult: |
| return getDefiningInst()->print(os); |
| case SSAValueKind::MLFuncArgument: |
| // TODO: Improve this. |
| os << "<function argument>\n"; |
| return; |
| case SSAValueKind::StmtResult: |
| return getDefiningStmt()->print(os); |
| case SSAValueKind::ForStmt: |
| return cast<ForStmt>(this)->print(os); |
| } |
| } |
| |
| void SSAValue::dump() const { print(llvm::errs()); } |
| |
| void Instruction::print(raw_ostream &os) const { |
| if (!getFunction()) { |
| os << "<<UNLINKED INSTRUCTION>>\n"; |
| return; |
| } |
| ModuleState state(getFunction()->getContext()); |
| ModulePrinter modulePrinter(os, state); |
| CFGFunctionPrinter(getFunction(), modulePrinter).print(this); |
| } |
| |
| void Instruction::dump() const { |
| print(llvm::errs()); |
| llvm::errs() << "\n"; |
| } |
| |
| void BasicBlock::print(raw_ostream &os) const { |
| if (!getFunction()) { |
| os << "<<UNLINKED BLOCK>>\n"; |
| return; |
| } |
| ModuleState state(getFunction()->getContext()); |
| ModulePrinter modulePrinter(os, state); |
| CFGFunctionPrinter(getFunction(), modulePrinter).print(this); |
| } |
| |
| void BasicBlock::dump() const { print(llvm::errs()); } |
| |
| void Statement::print(raw_ostream &os) const { |
| MLFunction *function = findFunction(); |
| if (!function) { |
| os << "<<UNLINKED STATEMENT>>\n"; |
| return; |
| } |
| |
| ModuleState state(function->getContext()); |
| ModulePrinter modulePrinter(os, state); |
| MLFunctionPrinter(function, modulePrinter).print(this); |
| } |
| |
| void Statement::dump() const { print(llvm::errs()); } |
| |
| void StmtBlock::printBlock(raw_ostream &os) const { |
| MLFunction *function = findFunction(); |
| ModuleState state(function->getContext()); |
| ModulePrinter modulePrinter(os, state); |
| MLFunctionPrinter(function, modulePrinter).print(this); |
| } |
| |
| void StmtBlock::dumpBlock() const { printBlock(llvm::errs()); } |
| |
| void Function::print(raw_ostream &os) const { |
| ModuleState state(getContext()); |
| ModulePrinter(os, state).print(this); |
| } |
| |
| void Function::dump() const { print(llvm::errs()); } |
| |
| void Module::print(raw_ostream &os) const { |
| ModuleState state(getContext()); |
| state.initialize(this); |
| ModulePrinter(os, state).print(this); |
| } |
| |
| void Module::dump() const { print(llvm::errs()); } |