blob: 4a7bcd87d5846c729f55fc41b778dbdd1cd8b7f8 [file] [log] [blame]
//===- Dialect.cpp - Implementation of the linalg dialect and types -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the Linalg dialect types and dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::linalg;
/// Constructs the Linalg dialect within `context` and registers the dialect's
/// types (BufferType, RangeType) and its operations.
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
// Types implemented in this file.
addTypes<BufferType, RangeType>();
// Operations listed in LinalgOps.cpp.inc. Defining GET_OP_LIST before the
// include selects the op list section of that (presumably generated) file,
// which expands directly into the template argument list below.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
>();
// Library operations come from a separate .inc file and are registered with
// a second, independent addOperations call.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
>();
}
/// Uniqued storage backing linalg::BufferType.
///
/// A buffer type is identified by its element type plus a buffer size. A
/// negative size (the Key default is -1) means the buffer has no constant
/// size; see hasConstantSize().
struct mlir::linalg::BufferTypeStorage : public TypeStorage {
  /// Everything needed to build (and unique) one buffer type instance.
  struct Key {
    Key(Type elementType, int64_t bufferSize = -1)
        : elementType(elementType), bufferSize(bufferSize) {}
    Type elementType;
    int64_t bufferSize;
  };

  /// Typename hook that MLIR's type uniquer looks up on the storage class.
  using KeyTy = Key;

  /// Placement-constructs a new storage object for `key` in memory obtained
  /// from `allocator`.
  static BufferTypeStorage *construct(TypeStorageAllocator &allocator,
                                      const Key &key) {
    BufferTypeStorage *memory = allocator.allocate<BufferTypeStorage>();
    return new (memory) BufferTypeStorage(key);
  }

  /// True when this storage was created from a key with the same element
  /// type and the same buffer size.
  bool operator==(const Key &key) const {
    return key.elementType == elementType && key.bufferSize == bufferSize;
  }

  /// Hash of a key, combining both of its fields.
  static unsigned hashKey(const Key &key) {
    llvm::hash_code combined =
        llvm::hash_combine(key.elementType, key.bufferSize);
    return combined;
  }

  /// Element type of the buffer.
  Type getElementType() { return elementType; }

  /// A buffer has a constant size exactly when the stored size is >= 0.
  bool hasConstantSize() { return bufferSize >= 0; }

  /// The constant buffer size, or llvm::None for a non-constant size.
  Optional<int64_t> getBufferSize() {
    if (!hasConstantSize())
      return llvm::None;
    return bufferSize;
  }

private:
  BufferTypeStorage(const Key &key)
      : elementType(key.elementType), bufferSize(key.bufferSize) {}

  Type elementType;
  int64_t bufferSize;
};
/// Returns the buffer type of kind LinalgTypes::Buffer in `context` whose
/// storage key is (`elementType`, `bufferSize`).
BufferType mlir::linalg::BufferType::get(MLIRContext *context, Type elementType,
                                         int64_t bufferSize) {
  return Base::get(context, LinalgTypes::Buffer,
                   /*elementType=*/elementType,
                   /*bufferSize=*/bufferSize);
}
/// Returns the element type held in this buffer type's storage.
Type mlir::linalg::BufferType::getElementType() {
  auto *storage = getImpl();
  return storage->getElementType();
}
/// Returns true if the buffer's stored size is a constant (non-negative).
bool mlir::linalg::BufferType::hasConstantSize() {
  auto *storage = getImpl();
  return storage->hasConstantSize();
}
/// Returns the constant buffer size, or llvm::None when the size is not
/// constant.
Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
  auto *storage = getImpl();
  return storage->getBufferSize();
}
/// Parses a Linalg dialect type from the symbol body returned by
/// `parser.getFullSymbolSpec()`.
///
/// Accepted forms:
///   range
///   buffer<?xelement_type>   -- buffer without a constant size
///   buffer<Nxelement_type>   -- buffer of constant size N (N >= 0)
///
/// On malformed input an error is emitted at `loc` and a null Type is
/// returned; no type is ever returned after an error has been reported.
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser,
                                            Location loc) const {
  StringRef spec = parser.getFullSymbolSpec();
  StringRef origSpec = spec;
  MLIRContext *context = getContext();

  if (spec == "range")
    return RangeType::get(context);

  if (spec.consume_front("buffer") && spec.consume_front("<") &&
      spec.consume_back(">")) {
    // The size precedes the first 'x'; everything after it is the element
    // type, which may itself contain 'x' (e.g. a vector type).
    StringRef sizeSpec, typeSpec;
    std::tie(sizeSpec, typeSpec) = spec.split('x');
    if (typeSpec.empty()) {
      emitError(loc, "expected 'x' followed by element type");
      return Type();
    }

    // '?' denotes a non-constant size, encoded as -1 in the storage key.
    int64_t bufferSize = -1;
    if (!sizeSpec.consume_front("?")) {
      // consumeInteger accepts a leading '-' for signed result types. Reject
      // negative values explicitly: "-1" would otherwise silently alias '?',
      // and other negatives would create types that print as '?'.
      if (sizeSpec.consumeInteger(10, bufferSize) || bufferSize < 0) {
        emitError(loc, "expected buffer size to be an unsigned integer");
        return Type();
      }
    }
    if (!sizeSpec.empty()) {
      // Previously this error was emitted but parsing continued and a type
      // was still returned; bail out so the caller sees the failure.
      emitError(loc, "unexpected token '") << sizeSpec << "'";
      return Type();
    }

    typeSpec = typeSpec.trim();
    Type elementType = mlir::parseType(typeSpec, context);
    if (!elementType) {
      emitError(loc, "invalid type specification: '") << typeSpec << "'";
      return Type();
    }
    return bufferSize == -1
               ? BufferType::get(context, elementType)
               : BufferType::get(context, elementType, bufferSize);
  }

  emitError(loc, "unknown Linalg type: " + origSpec);
  return Type();
}
/// Prints a BufferType as "buffer<Nxelement_type>" when the buffer has a
/// constant size N, and as "buffer<?xelement_type>" otherwise.
static void print(BufferType bt, DialectAsmPrinter &os) {
  os << "buffer<";
  if (Optional<int64_t> size = bt.getBufferSize())
    os << *size;
  else
    os << "?";
  os << "x";
  os << bt.getElementType();
  os << ">";
}
/// RangeType prints as just "range".
static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
/// Prints a Linalg dialect type by dispatching on its kind to the matching
/// file-local `print` overload. Any other kind is a programming error.
void mlir::linalg::LinalgDialect::printType(Type type,
                                            DialectAsmPrinter &os) const {
  switch (type.getKind()) {
  case LinalgTypes::Range:
    print(type.cast<RangeType>(), os);
    return;
  case LinalgTypes::Buffer:
    print(type.cast<BufferType>(), os);
    return;
  default:
    llvm_unreachable("Unhandled Linalg type");
  }
}