blob: 9cf9c558ba70f401af0247f77e54f4b00f3bc8ef [file] [log] [blame]
//===- Dialect.cpp - Implementation of the linalg dialect and types -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the Linalg dialect types and dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Linalg/IR/LinalgTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/Parser.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/raw_ostream.h"

#include <limits>
using namespace mlir;
using namespace mlir::linalg;
/// Constructs the Linalg dialect and registers its types and operations with
/// `context`.
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
// The three custom types implemented in this file.
addTypes<BufferType, RangeType, ViewType>();
// Operations registered explicitly by class name.
addOperations<ForOp, LoadOp, RangeOp, StoreOp, SliceOp, ViewOp>();
// Defining GET_OP_LIST before including the generated .inc file presumably
// makes it expand to a comma-separated list of op classes (TableGen
// convention; the .inc contents are not visible here — confirm).
addOperations<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
>();
// Same mechanism for the library-call ops, kept as a separate call so each
// generated list is expanded on its own.
addOperations<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
>();
}
/// Uniqued storage behind BufferType: an element type plus a buffer size.
/// A negative size (the Key default is -1) means the size is not a known
/// constant.
struct mlir::linalg::BufferTypeStorage : public TypeStorage {
  /// Payload handed to MLIR's type uniquer to build or look up an instance.
  struct Key {
    Key(Type elementType, int64_t bufferSize = -1)
        : elementType(elementType), bufferSize(bufferSize) {}
    Type elementType;
    int64_t bufferSize;
  };

  /// Typename hook the type uniquer requires to find the key type.
  using KeyTy = Key;

  /// Placement-constructs a fresh storage object in uniquer-owned memory.
  static BufferTypeStorage *construct(TypeStorageAllocator &allocator,
                                      const Key &key) {
    void *memory = allocator.allocate<BufferTypeStorage>();
    return new (memory) BufferTypeStorage(key);
  }

  /// Matches this instance against a lookup key (fields compared pairwise).
  bool operator==(const Key &key) const {
    return key.elementType == elementType && key.bufferSize == bufferSize;
  }

  /// Hash over the same fields that operator== compares.
  static unsigned hashKey(const Key &key) {
    return llvm::hash_combine(key.elementType, key.bufferSize);
  }

  Type getElementType() { return elementType; }

  /// True iff the stored size is non-negative, i.e. a known constant.
  bool hasConstantSize() { return bufferSize >= 0; }

  /// The constant size, or llvm::None when the size is dynamic.
  Optional<int64_t> getBufferSize() {
    if (!hasConstantSize())
      return llvm::None;
    return bufferSize;
  }

private:
  BufferTypeStorage(const Key &key)
      : elementType(key.elementType), bufferSize(key.bufferSize) {}

  Type elementType;
  int64_t bufferSize;
};
/// Returns the uniqued buffer type for `elementType` with size `bufferSize`;
/// a negative size denotes a dynamically sized buffer.
BufferType mlir::linalg::BufferType::get(MLIRContext *context, Type elementType,
                                         int64_t bufferSize) {
  return Base::get(context, LinalgTypes::Buffer, elementType, bufferSize);
}

/// Element type held by the buffer.
Type mlir::linalg::BufferType::getElementType() {
  auto *storage = getImpl();
  return storage->getElementType();
}

/// Whether the buffer size is a known (non-negative) constant.
bool mlir::linalg::BufferType::hasConstantSize() {
  auto *storage = getImpl();
  return storage->hasConstantSize();
}

/// The constant buffer size, or llvm::None if it is dynamic.
Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
  auto *storage = getImpl();
  return storage->getBufferSize();
}
/// Parses a Linalg dialect type from its textual form `spec`.
///
/// Accepted forms:
///   - `range`
///   - `buffer<?xT>` (dynamic size) or `buffer<NxT>`, where N is a decimal
///     integer that fits in a signed 64-bit value
///   - `view<T>`, `view<?xT>`, `view<?x?xT>`, ...; the number of leading `?x`
///     specifiers is the rank (so `view<T>` is 0-D)
/// where `T` is any type accepted by mlir::parseType.
///
/// On failure, an error is emitted at `loc` and a null Type is returned.
Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
                                            Location loc) const {
  StringRef origSpec = spec;
  MLIRContext *context = getContext();

  if (spec == "range")
    return RangeType::get(context);

  // Keep this an `else if` chain: once "buffer" is consumed we must not retry
  // the remainder as "view<...>".
  if (spec.consume_front("buffer")) {
    if (spec.consume_front("<") && spec.consume_back(">")) {
      // A leading '?' denotes a dynamic size, encoded as -1.
      int64_t bufferSize = -1;
      if (!spec.consume_front("?")) {
        unsigned long long parsedBufferSize = 0;
        if (spec.consumeInteger(10, parsedBufferSize)) {
          emitError(loc, "expected buffer size to be an unsigned integer");
          return Type();
        }
        // Values above INT64_MAX would wrap to a negative int64_t and be
        // silently misread as a dynamic (or otherwise non-constant) size.
        if (parsedBufferSize >
            static_cast<unsigned long long>(
                std::numeric_limits<int64_t>::max())) {
          emitError(loc, "buffer size does not fit in a signed 64-bit integer");
          return Type();
        }
        bufferSize = static_cast<int64_t>(parsedBufferSize);
      }
      if (!spec.consume_front("x")) {
        emitError(loc, "missing x in buffer type descrition : ") << spec;
        return Type();
      }
      if (auto t = mlir::parseType(spec, context))
        return (bufferSize == -1 ? BufferType::get(context, t)
                                 : BufferType::get(context, t, bufferSize));
    }
  } else if (spec.consume_front("view")) {
    if (spec.consume_front("<") && spec.consume_back(">")) {
      // Each leading "?x" specifier contributes one dimension to the rank.
      unsigned rank = 0;
      while (spec.consume_front("?")) {
        ++rank;
        if (!spec.consume_front("x")) {
          emitError(loc, "expected a list of '?x' dimension specifiers: ")
              << spec;
          return Type();
        }
      }
      if (auto t = mlir::parseType(spec, context))
        return ViewType::get(context, t, rank);
    }
  }
  return (emitError(loc, "unknown Linalg type: " + origSpec), Type());
}
/// Uniqued storage behind ViewType: an element type plus a rank (the number of
/// dynamic `?` dimensions).
struct mlir::linalg::ViewTypeStorage : public TypeStorage {
  /// Underlying Key type to transport the payload needed to construct a custom
  /// type in a generic way.
  struct Key {
    Key(Type elementType, unsigned rank)
        : elementType(elementType), rank(rank) {}
    Type elementType;
    unsigned rank;
  };
  /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
  using KeyTy = Key;
  /// Construction in the llvm::BumpPtrAllocator given a key.
  static ViewTypeStorage *construct(TypeStorageAllocator &allocator,
                                    const Key &key) {
    return new (allocator.allocate<ViewTypeStorage>()) ViewTypeStorage(key);
  }
  /// Equality against a lookup key; compares the same fields hashKey hashes.
  bool operator==(const Key &key) const {
    return elementType == key.elementType && rank == key.rank;
  }
  /// Hashing for unique'ing.
  static unsigned hashKey(const Key &key) {
    return llvm::hash_combine(key.elementType, key.rank);
  }
  // Accessors are const (they do not mutate) and no longer carry the stray
  // trailing ';' that triggered -Wextra-semi.
  unsigned getRank() const { return rank; }
  Type getElementType() const { return elementType; }

private:
  ViewTypeStorage(const Key &key)
      : elementType(key.elementType), rank(key.rank) {}
  Type elementType;
  unsigned rank;
};
/// Returns the uniqued view type over `elementType` with `rank` dimensions.
ViewType mlir::linalg::ViewType::get(MLIRContext *context, Type elementType,
                                     unsigned rank) {
  return Base::get(context, LinalgTypes::View, elementType, rank);
}

/// Element type of the view.
Type mlir::linalg::ViewType::getElementType() {
  auto *storage = getImpl();
  return storage->getElementType();
}

/// Number of dimensions of the view.
unsigned mlir::linalg::ViewType::getRank() {
  auto *storage = getImpl();
  return storage->getRank();
}
/// BufferType prints as "buffer<Nxelement_type>" when its size is a known
/// constant N, and as "buffer<?xelement_type>" when the size is dynamic. This
/// is the same syntax LinalgDialect::parseType accepts.
static void print(BufferType bt, raw_ostream &os) {
  os << "buffer<";
  if (auto bs = bt.getBufferSize())
    os << bs.getValue();
  else
    os << "?";
  os << "x" << bt.getElementType() << ">";
}
/// RangeType carries no parameters, so it always prints as the bare keyword
/// "range".
static void print(RangeType rt, raw_ostream &os) {
  (void)rt; // Nothing to inspect; the spelling is fixed.
  os << "range";
}
/// ViewType prints as "view<", then one "?x" per dimension, then the element
/// type and ">". For example, a 2-D view prints as:
///
/// ```{.mlir}
/// view<?x?xf32>
/// ```
///
/// and a 0-D view (a.k.a pointer to a scalar value) prints with no "?x" at
/// all:
///
/// ```{.mlir}
/// view<f32>
/// ```
static void print(mlir::linalg::ViewType rt, raw_ostream &os) {
  os << "view<";
  for (unsigned i = 0, e = rt.getRank(); i < e; ++i)
    os << "?x";
  os << rt.getElementType() << ">";
}
/// Prints any Linalg dialect type by dispatching on its kind to the matching
/// `print` overload above; any other kind is a programming error.
void mlir::linalg::LinalgDialect::printType(Type type, raw_ostream &os) const {
  switch (type.getKind()) {
  case LinalgTypes::Buffer:
    return print(type.cast<BufferType>(), os);
  case LinalgTypes::Range:
    return print(type.cast<RangeType>(), os);
  case LinalgTypes::View:
    return print(type.cast<ViewType>(), os);
  default:
    llvm_unreachable("Unhandled Linalg type");
  }
}