blob: f8835f02c268329e9e66e74d6f431a1c12341ef6 [file] [log] [blame]
//===- Function.cpp - MLIR Function Classes -------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/Function.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/Twine.h"
using namespace mlir;
using namespace mlir::detail;
/// Construct function storage with no explicit argument attributes: the
/// argument attribute container is sized from the number of inputs of `type`.
/// The name is uniqued as an Identifier in the context that owns `type`.
FunctionStorage::FunctionStorage(Location location, StringRef name,
                                 FunctionType type,
                                 ArrayRef<NamedAttribute> attrs)
    : name(Identifier::get(name, type.getContext())),
      location(location),
      type(type),
      attrs(attrs),
      argAttrs(type.getNumInputs()),
      body(this) {
}
/// Construct function storage with caller-provided argument attribute lists.
/// The name is uniqued as an Identifier in the context that owns `type`.
FunctionStorage::FunctionStorage(Location location, StringRef name,
                                 FunctionType type,
                                 ArrayRef<NamedAttribute> attrs,
                                 ArrayRef<NamedAttributeList> argAttrs)
    : name(Identifier::get(name, type.getContext())),
      location(location),
      type(type),
      attrs(attrs),
      argAttrs(argAttrs),
      body(this) {
}
/// Return the context of this function, obtained from its function type.
MLIRContext *Function::getContext() {
  auto fnType = getType();
  return fnType.getContext();
}
/// Recover the Module that owns the function list this traits object belongs
/// to, using pointer arithmetic rather than a stored back-pointer.
///
/// NOTE(review): this presumes the traits object is a base subobject of an
/// iplist<FunctionStorage> that is itself a direct member of Module - confirm
/// against the Module class definition.
Module *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
// Compute the byte offset of the function list inside Module by applying the
// member pointer returned from Module::getSublistAccess to a null Module
// pointer and taking the address of the result.
size_t Offset(
size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
// View this traits object as the list it is part of.
iplist<FunctionStorage> *Anchor(static_cast<iplist<FunctionStorage> *>(this));
// Step back from the list's address by the offset to reach the Module.
return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset);
}
/// Trait hook invoked when a function is inserted into a module's function
/// list. Records the owning module on the function; the function must not
/// already belong to a module.
void llvm::ilist_traits<FunctionStorage>::addNodeToList(
    FunctionStorage *function) {
  assert(!function->module && "already in a module!");
  auto *owner = getContainingModule();
  function->module = owner;
}
/// Trait hook invoked when a function is unlinked from a module's function
/// list. Clears the function's module pointer; the function must currently
/// belong to a module.
void llvm::ilist_traits<FunctionStorage>::removeNodeFromList(
    FunctionStorage *function) {
  auto &owner = function->module;
  assert(owner && "not already in a module!");
  owner = nullptr;
}
/// Trait hook invoked when the functions in [first, last) are moved from
/// `otherList` into this list. Keeps each function's module pointer in sync
/// with the list it now lives in.
void llvm::ilist_traits<FunctionStorage>::transferNodesFromList(
    ilist_traits<FunctionStorage> &otherList, function_iterator first,
    function_iterator last) {
  // A move within a single module leaves every module pointer valid.
  if (getContainingModule() == otherList.getContainingModule())
    return;

  // Otherwise detach each function from its old module and attach it to the
  // module that owns this list.
  while (first != last) {
    FunctionStorage *moved = &*first;
    removeNodeFromList(moved);
    addNodeToList(moved);
    ++first;
  }
}
/// Unlink this function from its Module and delete it.
void Function::erase() {
if (auto *module = getModule())
getModule()->functions.erase(impl);
else
delete impl;
}
/// Emit an error about fatal conditions with this function, reporting up to
/// any diagnostic handlers that may be listening. The diagnostic is attached
/// to this function's location. NOTE: This may terminate the containing
/// application, only use when the IR is in an inconsistent state.
InFlightDiagnostic Function::emitError() {
  // Forward to the message-taking overload with an empty message.
  return emitError(Twine());
}

/// Same as above, with `message` as the initial diagnostic text.
InFlightDiagnostic Function::emitError(const Twine &message) {
  auto loc = getLoc();
  return mlir::emitError(loc, message);
}
/// Emit a warning about this function, attached to its location, reporting up
/// to any diagnostic handlers that may be listening.
InFlightDiagnostic Function::emitWarning() {
  // Forward to the message-taking overload with an empty message.
  return emitWarning(Twine());
}

/// Same as above, with `message` as the initial diagnostic text.
InFlightDiagnostic Function::emitWarning(const Twine &message) {
  auto loc = getLoc();
  return mlir::emitWarning(loc, message);
}
/// Emit a remark about this function, attached to its location, reporting up
/// to any diagnostic handlers that may be listening.
InFlightDiagnostic Function::emitRemark() {
  // Forward to the message-taking overload with an empty message.
  return emitRemark(Twine());
}

/// Same as above, with `message` as the initial diagnostic text.
InFlightDiagnostic Function::emitRemark(const Twine &message) {
  auto loc = getLoc();
  return mlir::emitRemark(loc, message);
}
/// Copy this function's attributes and body into `dest`.
///
/// Attributes already present on `dest` are kept; attributes of this function
/// are added on top. When both functions carry an attribute with the same
/// name, the values are asserted to be equal. The body is cloned through
/// `mapper`.
void Function::cloneInto(Function dest, BlockAndValueMapping &mapper) {
  // Seed the merged set with the destination's attributes so they keep their
  // relative order, then layer this function's attributes over them.
  llvm::MapVector<Identifier, Attribute> mergedAttrs;
  for (auto &destAttr : dest.getAttrs())
    mergedAttrs.insert(destAttr);

  for (auto &srcAttr : getAttrs()) {
    auto inserted = mergedAttrs.insert(srcAttr);

    // TODO(riverriddle) Verify that the two functions have compatible
    // attributes.
    (void)inserted;
    assert((inserted.second || inserted.first->second == srcAttr.second) &&
           "the two functions have incompatible attributes");
  }
  dest.setAttrs(mergedAttrs.takeVector());

  // Finally, clone the region holding the function body.
  impl->body.cloneInto(&dest.impl->body, mapper);
}
/// Create a deep copy of this function and all of its blocks, remapping
/// any operands that use values outside of the function using the map that is
/// provided (leaving them alone if no entry is present). Replaces references
/// to cloned sub-values with the corresponding value that is copied, and adds
/// those mappings to the mapper.
///
/// For a function with a body, any argument that already has an entry in
/// `mapper` is dropped from the signature of the clone.
Function Function::clone(BlockAndValueMapping &mapper) {
  const bool external = isExternal();
  const unsigned numArgs = getNumArguments();

  // An argument is carried over to the clone unless the function has a body
  // and the caller pre-mapped that argument. The external check comes first
  // so that arguments of a body-less function are never queried.
  auto keepsArgument = [&](unsigned argIdx) {
    return external || !mapper.contains(getArgument(argIdx));
  };

  // Rebuild the signature without the dropped inputs when there is a body.
  FunctionType clonedType = impl->type;
  if (!external) {
    SmallVector<Type, 4> keptInputs;
    for (unsigned argIdx = 0; argIdx != numArgs; ++argIdx)
      if (keepsArgument(argIdx))
        keptInputs.push_back(clonedType.getInput(argIdx));
    clonedType =
        FunctionType::get(keptInputs, clonedType.getResults(), getContext());
  }

  // Create the new function.
  Function newFunc = Function::create(getLoc(), getName(), clonedType);

  // Transfer argument attributes, compacting indices over dropped arguments.
  unsigned destIdx = 0;
  for (unsigned argIdx = 0; argIdx != numArgs; ++argIdx)
    if (keepsArgument(argIdx))
      newFunc.setArgAttrs(destIdx++, getArgAttrs(argIdx));

  // Clone attributes and body into the new function and hand it back.
  cloneInto(newFunc, mapper);
  return newFunc;
}
/// Create a deep copy of this function, starting from an empty value mapping.
Function Function::clone() {
  BlockAndValueMapping emptyMapping;
  return clone(emptyMapping);
}
//===----------------------------------------------------------------------===//
// Function implementation.
//===----------------------------------------------------------------------===//
/// Add an entry block to an empty function, and set up the block arguments
/// to match the signature of the function.
void Function::addEntryBlock() {
assert(empty() && "function already has an entry block");
auto *entry = new Block();
push_back(entry);
entry->addArguments(impl->type.getInputs());
}
/// Invoke `callback` on operations in this function by delegating to the
/// walk of the function body.
void Function::walk(const std::function<void(Operation *)> &callback) {
  auto &&body = getBody();
  body.walk(callback);
}
//===----------------------------------------------------------------------===//
// Function Operation.
//===----------------------------------------------------------------------===//
/// Populate `result` for a function operation: a "name" string attribute, a
/// "type" type attribute, the additional attributes in `attrs`, and a single
/// region for the body.
void FuncOp::build(Builder *builder, OperationState *result, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs) {
  auto nameAttr = builder->getStringAttr(name);
  auto typeAttr = builder->getTypeAttr(type);

  result->addAttribute("name", nameAttr);
  result->addAttribute("type", typeAttr);
  result->attributes.append(attrs.begin(), attrs.end());

  // The body region is left for the caller to fill in.
  result->addRegion();
}
/// Parsing/Printing methods.

/// Parse a parenthesized, comma-separated function argument list.
///
/// Each argument is either `ssa-id : type` or just `type`, optionally followed
/// by an attribute dictionary. The list must be consistent: either every
/// argument is named or none is.
///
/// On success, `argTypes` and `argAttrs` receive one entry per argument (an
/// empty attribute list when an argument has no dictionary), and `argNames`
/// receives one entry per argument only when the arguments are named.
/// Returns failure if any part of the list, including the closing ')', cannot
/// be parsed.
static ParseResult
parseArgumentList(OpAsmParser *parser, SmallVectorImpl<Type> &argTypes,
                  SmallVectorImpl<OpAsmParser::OperandType> &argNames,
                  SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
  if (parser->parseLParen())
    return failure();

  // The argument list either has to consistently have ssa-id's followed by
  // types, or just be a type list. It isn't ok to sometimes have SSA ID's and
  // sometimes not.
  auto parseArgument = [&]() -> ParseResult {
    llvm::SMLoc loc = parser->getCurrentLocation();

    // Parse argument name if present.
    OpAsmParser::OperandType argument;
    Type argumentType;
    if (succeeded(parser->parseOptionalRegionArgument(argument)) &&
        !argument.name.empty()) {
      // Reject this if the preceding argument was missing a name.
      if (argNames.empty() && !argTypes.empty())
        return parser->emitError(loc,
                                 "expected type instead of SSA identifier");
      argNames.push_back(argument);

      if (parser->parseColonType(argumentType))
        return failure();
    } else if (!argNames.empty()) {
      // Reject this if the preceding argument had a name.
      return parser->emitError(loc, "expected SSA identifier");
    } else if (parser->parseType(argumentType)) {
      return failure();
    }

    // Add the argument type.
    argTypes.push_back(argumentType);

    // Parse any argument attributes.
    SmallVector<NamedAttribute, 2> attrs;
    if (parser->parseOptionalAttributeDict(attrs))
      return failure();
    argAttrs.push_back(attrs);
    return success();
  };

  // Parse the function arguments. An immediate ')' means an empty list.
  if (parser->parseOptionalRParen()) {
    do {
      if (parseArgument())
        return failure();
    } while (succeeded(parser->parseOptionalComma()));

    // A missing ')' must fail the parse instead of being silently ignored,
    // otherwise the caller keeps parsing a malformed signature.
    if (parser->parseRParen())
      return failure();
  }

  return success();
}
/// Parse a function signature: the parenthesized argument list followed by an
/// optional `->` result type list. The resulting function type is stored in
/// `type`; argument names and per-argument attributes are returned through
/// `argNames` and `argAttrs`.
static ParseResult parseFunctionSignature(
    OpAsmParser *parser, FunctionType &type,
    SmallVectorImpl<OpAsmParser::OperandType> &argNames,
    SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
  // Inputs come from the argument list.
  SmallVector<Type, 4> inputTypes;
  if (parseArgumentList(parser, inputTypes, argNames, argAttrs))
    return failure();

  // Results are optional; an absent arrow leaves the list empty.
  SmallVector<Type, 4> resultTypes;
  if (parser->parseOptionalArrowTypeList(resultTypes))
    return failure();

  auto &builder = parser->getBuilder();
  type = builder.getFunctionType(inputTypes, resultTypes);
  return success();
}
/// Parse the custom form of a function operation: the function name, its
/// signature, an optional `attributes` dictionary, and an optional body
/// region. Results are accumulated into `result`.
ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) {
  auto &builder = parser->getBuilder();

  // The name is parsed as a function attribute and then rewritten in place
  // as a string attribute holding the same value.
  FunctionAttr symbolAttr;
  if (parser->parseAttribute(symbolAttr, "name", result->attributes))
    return failure();
  result->attributes.back().second =
      builder.getStringAttr(symbolAttr.getValue());

  // Parse the signature and record the function type.
  FunctionType fnType;
  SmallVector<OpAsmParser::OperandType, 4> entryArgs;
  SmallVector<SmallVector<NamedAttribute, 2>, 4> perArgAttrs;
  if (parseFunctionSignature(parser, fnType, entryArgs, perArgAttrs))
    return failure();
  result->addAttribute("type", builder.getTypeAttr(fnType));

  // A function-level attribute dictionary may follow the `attributes` keyword.
  if (succeeded(parser->parseOptionalKeyword("attributes"))) {
    if (parser->parseOptionalAttributeDict(result->attributes))
      return failure();
  }

  // Store each non-empty argument attribute list as a dictionary attribute
  // under that argument's attribute name.
  SmallString<8> attrNameBuffer;
  for (unsigned argIdx = 0, numArgs = fnType.getNumInputs(); argIdx != numArgs;
       ++argIdx) {
    auto &attrsForArg = perArgAttrs[argIdx];
    if (attrsForArg.empty())
      continue;
    result->addAttribute(getArgAttrName(argIdx, attrNameBuffer),
                         builder.getDictionaryAttr(attrsForArg));
  }

  // Parse the optional body. Argument types are only supplied when the
  // signature named its arguments.
  auto *bodyRegion = result->addRegion();
  if (parser->parseOptionalRegion(
          *bodyRegion, entryArgs,
          entryArgs.empty() ? llvm::None : fnType.getInputs()))
    return failure();

  return success();
}
/// Print the signature of `op`: a parenthesized argument list (with argument
/// names only when the function has a body), each argument followed by its
/// attribute dictionary if any, and then the optional `->` result list.
static void printFunctionSignature(OpAsmPrinter *p, FuncOp op) {
  auto signature = op.getType();
  // External functions have no entry block arguments to label.
  const bool printArgNames = !op.isExternal();

  *p << '(';
  for (unsigned argIdx = 0, numArgs = op.getNumArguments(); argIdx != numArgs;
       ++argIdx) {
    if (argIdx != 0)
      *p << ", ";

    if (printArgNames) {
      p->printOperand(op.getArgument(argIdx));
      *p << ": ";
    }

    // The type comes first, then any attributes attached to the argument.
    p->printType(signature.getInput(argIdx));
    p->printOptionalAttrDict(op.getArgAttrs(argIdx));
  }
  *p << ')';

  p->printOptionalArrowTypeList(signature.getResults());
}
/// Print the custom form of this function operation: `func @name`, the
/// signature, any function attributes not already shown elsewhere, and the
/// body if the function is not external.
void FuncOp::print(OpAsmPrinter *p) {
  *p << "func @" << getName();
  printFunctionSignature(p, *this);

  // Collect the names of the argument attribute entries that exist. The
  // strings are kept in their own container because the elision list below
  // only holds references to them.
  std::vector<SmallString<8>> argAttrNames;
  SmallString<8> nameBuffer;
  for (unsigned argIdx = 0, numArgs = getNumArguments(); argIdx != numArgs;
       ++argIdx) {
    if (getAttr(getArgAttrName(argIdx, nameBuffer)))
      argAttrNames.emplace_back(nameBuffer);
  }

  // "name" and "type" are printed as part of the header, and argument
  // attributes as part of the signature, so none of them are repeated.
  SmallVector<StringRef, 2> elidedAttrs = {"name", "type"};
  elidedAttrs.append(argAttrNames.begin(), argAttrNames.end());

  // Only emit the attribute section when something remains to be printed.
  auto allAttrs = getAttrs();
  if (allAttrs.size() > elidedAttrs.size()) {
    *p << "\n attributes ";
    p->printOptionalAttrDict(allAttrs, elidedAttrs);
  }

  // Print the body if this is not an external function.
  if (!isExternal()) {
    p->printRegion(getBody(), /*printEntryBlockArgs=*/false,
                   /*printBlockTerminators=*/true);
    *p << '\n';
  }
  *p << '\n';
}
/// Verify that the entry block arguments of a non-external function match the
/// inputs of its function type, both in count and in type. External functions
/// are accepted as-is.
LogicalResult FuncOp::verify() {
  // If this function is external there is nothing to do.
  if (isExternal())
    return success();

  Block &entry = front();
  auto signatureTypes = getType().getInputs();

  // The entry block must have exactly one argument per function input.
  if (signatureTypes.size() != entry.getNumArguments())
    return emitOpError("entry block must have ")
           << signatureTypes.size() << " arguments to match function signature";

  // Each entry block argument must have the type of the matching input.
  for (unsigned argIdx = 0, numArgs = entry.getNumArguments();
       argIdx != numArgs; ++argIdx) {
    if (signatureTypes[argIdx] == entry.getArgument(argIdx)->getType())
      continue;
    return emitOpError("type of entry block argument #")
           << argIdx << '(' << entry.getArgument(argIdx)->getType()
           << ") must match the type of the corresponding argument in "
           << "function signature(" << signatureTypes[argIdx] << ')';
  }

  return success();
}
//===----------------------------------------------------------------------===//
// Function Argument Attribute.
//===----------------------------------------------------------------------===//
/// Set the attributes held by the argument at 'index'.
void FuncOp::setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
assert(index < getNumArguments() && "invalid argument number");
SmallString<8> nameOut;
getArgAttrName(index, nameOut);
if (attributes.empty())
return (void)removeAttr(nameOut);
setAttr(nameOut, DictionaryAttr::get(attributes, getContext()));
}
void FuncOp::setArgAttrs(unsigned index, NamedAttributeList attributes) {
assert(index < getNumArguments() && "invalid argument number");
SmallString<8> nameOut;
if (auto newAttr = attributes.getDictionary())
return setAttr(getArgAttrName(index, nameOut), newAttr);
removeAttr(getArgAttrName(index, nameOut));
}
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void FuncOp::setArgAttr(unsigned index, Identifier name, Attribute value) {
auto curAttr = getArgAttrDict(index);
NamedAttributeList attrList(curAttr);
attrList.set(name, value);
// If the attribute changed, then set the new arg attribute list.
if (curAttr != attrList.getDictionary())
setArgAttrs(index, attrList);
}
/// Remove the attribute 'name' from the argument at 'index' and return the
/// result reported by the attribute list. The argument's stored attributes
/// are only rewritten when something was removed.
NamedAttributeList::RemoveResult FuncOp::removeArgAttr(unsigned index,
                                                       Identifier name) {
  // Work on a copy of the argument's attributes.
  NamedAttributeList remainingAttrs(getArgAttrDict(index));
  auto status = remainingAttrs.remove(name);

  if (status != NamedAttributeList::RemoveResult::Removed)
    return status;

  // Write the shortened list back to the argument.
  setArgAttrs(index, remainingAttrs);
  return status;
}
/// Returns the name of the function attribute that holds the argument
/// attributes for argument 'arg': the text "arg" followed by the index. The
/// name is written into 'out', whose previous contents are discarded.
StringRef FuncOp::getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
  out.clear();
  // The Twine is consumed within this one expression, so its temporaries
  // remain alive while the name is rendered.
  return Twine("arg").concat(Twine(arg)).toStringRef(out);
}