Add DialectAsmParser/Printer classes to simplify dialect attribute and type parsing.
These classes are functionally similar to the OpAsmParser/Printer classes and provide hooks for parsing attributes/tokens/types/etc. This change merely sets up the base infrastructure and updates the parser hooks, followups will add hooks as needed to simplify existing handrolled dialect parsers.
This has various different benefits:
*) Attribute/Type parsing is much simpler to define.
*) Dialect attributes/types that contain other attributes/types can now use aliases.
*) It provides a 'spec' with which we may use in the future to auto-generate parsers/printers.
*) Error messages emitted by attribute/type parsers can provide character exact locations rather than "beginning of the string"
PiperOrigin-RevId: 278005322
diff --git a/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 67fccec..d09e815 100644
--- a/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -173,10 +173,10 @@
llvm::Module &getLLVMModule();
/// Parse a type registered to this dialect.
- Type parseType(StringRef tyData, Location loc) const override;
+ Type parseType(DialectAsmParser &parser, Location loc) const override;
/// Print a type registered to this dialect.
- void printType(Type type, raw_ostream &os) const override;
+ void printType(Type type, DialectAsmPrinter &os) const override;
/// Verify a region argument attribute registered to this dialect.
/// Returns failure if the verification failed, success otherwise.
diff --git a/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index 1835073..8888953 100644
--- a/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -37,10 +37,10 @@
static StringRef getDialectNamespace() { return "linalg"; }
/// Parse a type registered to this dialect.
- Type parseType(llvm::StringRef spec, Location loc) const override;
+ Type parseType(DialectAsmParser &parser, Location loc) const override;
/// Print a type registered to this dialect.
- void printType(Type type, llvm::raw_ostream &os) const override;
+ void printType(Type type, DialectAsmPrinter &os) const override;
};
/// A BufferType represents a contiguous block of memory that can be allocated
diff --git a/include/mlir/Dialect/QuantOps/QuantOps.h b/include/mlir/Dialect/QuantOps/QuantOps.h
index 8753cd2..f1ac383 100644
--- a/include/mlir/Dialect/QuantOps/QuantOps.h
+++ b/include/mlir/Dialect/QuantOps/QuantOps.h
@@ -35,10 +35,10 @@
QuantizationDialect(MLIRContext *context);
/// Parse a type registered to this dialect.
- Type parseType(StringRef spec, Location loc) const override;
+ Type parseType(DialectAsmParser &parser, Location loc) const override;
/// Print a type registered to this dialect.
- void printType(Type type, raw_ostream &os) const override;
+ void printType(Type type, DialectAsmPrinter &os) const override;
};
#define GET_OP_CLASSES
diff --git a/include/mlir/Dialect/SPIRV/SPIRVDialect.h b/include/mlir/Dialect/SPIRV/SPIRVDialect.h
index 8e98270..6401eba 100644
--- a/include/mlir/Dialect/SPIRV/SPIRVDialect.h
+++ b/include/mlir/Dialect/SPIRV/SPIRVDialect.h
@@ -46,10 +46,10 @@
static std::string getAttributeName(Decoration decoration);
/// Parses a type registered to this dialect.
- Type parseType(llvm::StringRef spec, Location loc) const override;
+ Type parseType(DialectAsmParser &parser, Location loc) const override;
/// Prints a type registered to this dialect.
- void printType(Type type, llvm::raw_ostream &os) const override;
+ void printType(Type type, DialectAsmPrinter &os) const override;
/// Provides a hook for materializing a constant to this dialect.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
diff --git a/include/mlir/IR/Dialect.h b/include/mlir/IR/Dialect.h
index bf7db91..bd84bee 100644
--- a/include/mlir/IR/Dialect.h
+++ b/include/mlir/IR/Dialect.h
@@ -25,6 +25,8 @@
#include "mlir/IR/OperationSupport.h"
namespace mlir {
+class DialectAsmParser;
+class DialectAsmPrinter;
class DialectInterface;
class OpBuilder;
class Type;
@@ -115,21 +117,21 @@
/// Parse an attribute registered to this dialect. If 'type' is nonnull, it
/// refers to the expected type of the attribute.
- virtual Attribute parseAttribute(StringRef attrData, Type type,
+ virtual Attribute parseAttribute(DialectAsmParser &parser, Type type,
Location loc) const;
/// Print an attribute registered to this dialect. Note: The type of the
/// attribute need not be printed by this method as it is always printed by
/// the caller.
- virtual void printAttribute(Attribute, raw_ostream &) const {
+ virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
llvm_unreachable("dialect has no registered attribute printing hook");
}
/// Parse a type registered to this dialect.
- virtual Type parseType(StringRef tyData, Location loc) const;
+ virtual Type parseType(DialectAsmParser &parser, Location loc) const;
/// Print a type registered to this dialect.
- virtual void printType(Type, raw_ostream &) const {
+ virtual void printType(Type, DialectAsmPrinter &) const {
llvm_unreachable("dialect has no registered type printing hook");
}
diff --git a/include/mlir/IR/DialectImplementation.h b/include/mlir/IR/DialectImplementation.h
new file mode 100644
index 0000000..c662a4c
--- /dev/null
+++ b/include/mlir/IR/DialectImplementation.h
@@ -0,0 +1,139 @@
+//===- DialectImplementation.h ----------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains utilities classes for implementing dialect attributes and
+// types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_DIALECTIMPLEMENTATION_H
+#define MLIR_IR_DIALECTIMPLEMENTATION_H
+
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/SMLoc.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+class Builder;
+
+//===----------------------------------------------------------------------===//
+// DialectAsmPrinter
+//===----------------------------------------------------------------------===//
+
+/// This is a pure-virtual base class that exposes the asmprinter hooks
+/// necessary to implement a custom printAttribute/printType() method on a
+/// dialect.
+class DialectAsmPrinter {
+public:
+ DialectAsmPrinter() {}
+ virtual ~DialectAsmPrinter();
+ virtual raw_ostream &getStream() const = 0;
+
+ /// Print the given attribute to the stream.
+ virtual void printAttribute(Attribute attr) = 0;
+
+ /// Print the given floating point value in a stabilized form that can be
+ /// roundtripped through the IR. This is the companion to the 'parseFloat'
+ /// hook on the DialectAsmParser.
+ virtual void printFloat(const APFloat &value) = 0;
+
+ /// Print the given type to the stream.
+ virtual void printType(Type type) = 0;
+
+private:
+ DialectAsmPrinter(const DialectAsmPrinter &) = delete;
+ void operator=(const DialectAsmPrinter &) = delete;
+};
+
+// Make the implementations convenient to use.
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) {
+ p.printAttribute(attr);
+ return p;
+}
+
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p,
+ const APFloat &value) {
+ p.printFloat(value);
+ return p;
+}
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) {
+ return p << APFloat(value);
+}
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) {
+ return p << APFloat(value);
+}
+
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) {
+ p.printType(type);
+ return p;
+}
+
+// Support printing anything that isn't convertible to one of the above types,
+// even if it isn't exactly one of them. For example, we want to print
+// FunctionType with the Type version above, not have it match this.
+template <typename T, typename std::enable_if<
+ !std::is_convertible<T &, Attribute &>::value &&
+ !std::is_convertible<T &, Type &>::value &&
+ !std::is_convertible<T &, APFloat &>::value &&
+ !llvm::is_one_of<T, double, float>::value,
+ T>::type * = nullptr>
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) {
+ p.getStream() << other;
+ return p;
+}
+
+//===----------------------------------------------------------------------===//
+// DialectAsmParser
+//===----------------------------------------------------------------------===//
+
+/// The DialectAsmParser has methods for interacting with the asm parser:
+/// parsing things from it, emitting errors etc. It has an intentionally
+/// high-level API that is designed to reduce/constrain syntax innovation in
+/// individual attributes or types.
+class DialectAsmParser {
+public:
+ virtual ~DialectAsmParser();
+
+ /// Emit a diagnostic at the specified location and return failure.
+ virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
+ const Twine &message = {}) = 0;
+
+ /// Return a builder which provides useful access to MLIRContext, global
+ /// objects like types and attributes.
+ virtual Builder &getBuilder() const = 0;
+
+ /// Get the location of the next token and store it into the argument. This
+ /// always succeeds.
+ virtual llvm::SMLoc getCurrentLocation() = 0;
+ ParseResult getCurrentLocation(llvm::SMLoc *loc) {
+ *loc = getCurrentLocation();
+ return success();
+ }
+
+ /// Return the location of the original name token.
+ virtual llvm::SMLoc getNameLoc() const = 0;
+
+ /// Returns the full specification of the symbol being parsed. This allows for
+ /// using a separate parser if necessary.
+ virtual StringRef getFullSymbolSpec() const = 0;
+};
+
+} // end namespace mlir
+
+#endif
diff --git a/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d7d3307..39decf9 100644
--- a/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -21,6 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
@@ -1249,7 +1250,9 @@
llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
/// Parse a type registered to this dialect.
-Type LLVMDialect::parseType(StringRef tyData, Location loc) const {
+Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const {
+ StringRef tyData = parser.getFullSymbolSpec();
+
// LLVM is not thread-safe, so lock access to it.
llvm::sys::SmartScopedLock<true> lock(impl->mutex);
@@ -1261,11 +1264,11 @@
}
/// Print a type registered to this dialect.
-void LLVMDialect::printType(Type type, raw_ostream &os) const {
+void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
auto llvmType = type.dyn_cast<LLVMType>();
assert(llvmType && "printing wrong type");
assert(llvmType.getUnderlyingType() && "no underlying LLVM type");
- llvmType.getUnderlyingType()->print(os);
+ llvmType.getUnderlyingType()->print(os.getStream());
}
/// Verify LLVMIR function argument attributes.
diff --git a/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index c09b75e..4a7bcd8 100644
--- a/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -20,9 +20,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/StandardTypes.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
#include "mlir/Support/LLVM.h"
@@ -107,8 +108,9 @@
return getImpl()->getBufferSize();
}
-Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
+Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser,
Location loc) const {
+ StringRef spec = parser.getFullSymbolSpec();
StringRef origSpec = spec;
MLIRContext *context = getContext();
if (spec == "range")
@@ -146,9 +148,8 @@
return (emitError(loc, "unknown Linalg type: " + origSpec), Type());
}
-
/// BufferType prints as "buffer<element_type>".
-static void print(BufferType bt, raw_ostream &os) {
+static void print(BufferType bt, DialectAsmPrinter &os) {
os << "buffer<";
auto bs = bt.getBufferSize();
if (bs) {
@@ -160,9 +161,10 @@
}
/// RangeType prints as just "range".
-static void print(RangeType rt, raw_ostream &os) { os << "range"; }
+static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
-void mlir::linalg::LinalgDialect::printType(Type type, raw_ostream &os) const {
+void mlir::linalg::LinalgDialect::printType(Type type,
+ DialectAsmPrinter &os) const {
switch (type.getKind()) {
default:
llvm_unreachable("Unhandled Linalg type");
diff --git a/lib/Dialect/QuantOps/IR/TypeParser.cpp b/lib/Dialect/QuantOps/IR/TypeParser.cpp
index 726c20c..360c1b5 100644
--- a/lib/Dialect/QuantOps/IR/TypeParser.cpp
+++ b/lib/Dialect/QuantOps/IR/TypeParser.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
@@ -615,9 +616,10 @@
}
/// Parse a type registered to this dialect.
-Type QuantizationDialect::parseType(StringRef spec, Location loc) const {
- TypeParser parser(spec, getContext(), loc);
- Type parsedType = parser.parseType();
+Type QuantizationDialect::parseType(DialectAsmParser &parser,
+ Location loc) const {
+ TypeParser typeParser(parser.getFullSymbolSpec(), getContext(), loc);
+ Type parsedType = typeParser.parseType();
if (parsedType == nullptr) {
// Error.
// TODO(laurenzo): Do something?
@@ -723,19 +725,20 @@
}
/// Print a type registered to this dialect.
-void QuantizationDialect::printType(Type type, raw_ostream &os) const {
+void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
default:
llvm_unreachable("Unhandled quantized type");
case QuantizationTypes::Any:
- printAnyQuantizedType(type.cast<AnyQuantizedType>(), os);
+ printAnyQuantizedType(type.cast<AnyQuantizedType>(), os.getStream());
break;
case QuantizationTypes::UniformQuantized:
- printUniformQuantizedType(type.cast<UniformQuantizedType>(), os);
+ printUniformQuantizedType(type.cast<UniformQuantizedType>(),
+ os.getStream());
break;
case QuantizationTypes::UniformQuantizedPerAxis:
printUniformQuantizedPerAxisType(type.cast<UniformQuantizedPerAxisType>(),
- os);
+ os.getStream());
break;
}
}
diff --git a/lib/Dialect/SPIRV/SPIRVDialect.cpp b/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 96777b1..26d1ff1 100644
--- a/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
@@ -609,7 +610,9 @@
// | pointer-type
// | runtime-array-type
// | struct-type
-Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
+Type SPIRVDialect::parseType(DialectAsmParser &parser, Location loc) const {
+ StringRef spec = parser.getFullSymbolSpec();
+
if (spec.startswith("array"))
return parseArrayType(*this, spec, loc);
if (spec.startswith("image"))
@@ -629,7 +632,7 @@
// Type Printing
//===----------------------------------------------------------------------===//
-static void print(ArrayType type, llvm::raw_ostream &os) {
+static void print(ArrayType type, DialectAsmPrinter &os) {
os << "array<" << type.getNumElements() << " x " << type.getElementType();
if (type.hasLayout()) {
os << " [" << type.getArrayStride() << "]";
@@ -637,16 +640,16 @@
os << ">";
}
-static void print(RuntimeArrayType type, llvm::raw_ostream &os) {
+static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
os << "rtarray<" << type.getElementType() << ">";
}
-static void print(PointerType type, llvm::raw_ostream &os) {
+static void print(PointerType type, DialectAsmPrinter &os) {
os << "ptr<" << type.getPointeeType() << ", "
<< stringifyStorageClass(type.getStorageClass()) << ">";
}
-static void print(ImageType type, llvm::raw_ostream &os) {
+static void print(ImageType type, DialectAsmPrinter &os) {
os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
<< ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
<< stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
@@ -655,7 +658,7 @@
<< stringifyImageFormat(type.getImageFormat()) << ">";
}
-static void print(StructType type, llvm::raw_ostream &os) {
+static void print(StructType type, DialectAsmPrinter &os) {
os << "struct<";
auto printMember = [&](unsigned i) {
os << type.getElementType(i);
@@ -680,7 +683,7 @@
os << ">";
}
-void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
+void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
case TypeKind::Array:
print(type.cast<ArrayType>(), os);
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 0200e98..0e6b788 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
@@ -53,6 +54,8 @@
void OperationName::dump() const { print(llvm::errs()); }
+DialectAsmPrinter::~DialectAsmPrinter() {}
+
OpAsmPrinter::~OpAsmPrinter() {}
//===----------------------------------------------------------------------===//
@@ -391,6 +394,9 @@
: os(printer.os), printerFlags(printer.printerFlags),
state(printer.state) {}
+ /// Returns the output stream of the printer.
+ raw_ostream &getStream() { return os; }
+
template <typename Container, typename UnaryFunctor>
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
mlir::interleaveComma(c, os, each_fn);
@@ -420,6 +426,9 @@
void printLocationInternal(LocationAttr loc, bool pretty = false);
void printDenseElementsAttr(DenseElementsAttr attr);
+ void printDialectAttribute(Attribute attr);
+ void printDialectType(Type type);
+
/// This enum is used to represent the binding strength of the enclosing
/// context that an AffineExprStorage is being printed in, so we can
/// intelligently produce parens.
@@ -715,19 +724,9 @@
}
switch (attr.getKind()) {
- default: {
- auto &dialect = attr.getDialect();
+ default:
+ return printDialectAttribute(attr);
- // Ask the dialect to serialize the attribute to a string.
- std::string attrName;
- {
- llvm::raw_string_ostream attrNameStr(attrName);
- dialect.printAttribute(attr, attrNameStr);
- }
-
- printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
- break;
- }
case StandardAttributes::Opaque: {
auto opaqueAttr = attr.cast<OpaqueAttr>();
printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
@@ -950,19 +949,9 @@
}
switch (type.getKind()) {
- default: {
- auto &dialect = type.getDialect();
+ default:
+ return printDialectType(type);
- // Ask the dialect to serialize the type to a string.
- std::string typeName;
- {
- llvm::raw_string_ostream typeNameStr(typeName);
- dialect.printType(type, typeNameStr);
- }
-
- printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
- return;
- }
case Type::Kind::Opaque: {
auto opaqueTy = type.cast<OpaqueType>();
printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
@@ -1073,6 +1062,65 @@
}
//===----------------------------------------------------------------------===//
+// CustomDialectAsmPrinter
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides the main specialication of the DialectAsmPrinter that is
+/// used to provide support for print attributes and types. This hooks allows
+/// for dialects to hook into the main ModulePrinter.
+struct CustomDialectAsmPrinter : public DialectAsmPrinter {
+public:
+ CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
+ ~CustomDialectAsmPrinter() override {}
+
+ raw_ostream &getStream() const override { return printer.getStream(); }
+
+ /// Print the given attribute to the stream.
+ void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
+
+ /// Print the given floating point value in a stablized form.
+ void printFloat(const APFloat &value) override {
+ printFloatValue(value, getStream());
+ }
+
+ /// Print the given type to the stream.
+ void printType(Type type) override { printer.printType(type); }
+
+ /// The main module printer.
+ ModulePrinter &printer;
+};
+} // end anonymous namespace
+
+void ModulePrinter::printDialectAttribute(Attribute attr) {
+ auto &dialect = attr.getDialect();
+
+ // Ask the dialect to serialize the attribute to a string.
+ std::string attrName;
+ {
+ llvm::raw_string_ostream attrNameStr(attrName);
+ ModulePrinter subPrinter(attrNameStr, printerFlags, state);
+ CustomDialectAsmPrinter printer(subPrinter);
+ dialect.printAttribute(attr, printer);
+ }
+ printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
+}
+
+void ModulePrinter::printDialectType(Type type) {
+ auto &dialect = type.getDialect();
+
+ // Ask the dialect to serialize the type to a string.
+ std::string typeName;
+ {
+ llvm::raw_string_ostream typeNameStr(typeName);
+ ModulePrinter subPrinter(typeNameStr, printerFlags, state);
+ CustomDialectAsmPrinter printer(subPrinter);
+ dialect.printType(type, printer);
+ }
+ printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
+}
+
+//===----------------------------------------------------------------------===//
// Affine expressions and maps
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/Dialect.cpp b/lib/IR/Dialect.cpp
index f8539c0..7882e4f 100644
--- a/lib/IR/Dialect.cpp
+++ b/lib/IR/Dialect.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectHooks.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
@@ -28,6 +29,8 @@
using namespace mlir;
using namespace detail;
+DialectAsmParser::~DialectAsmParser() {}
+
//===----------------------------------------------------------------------===//
// Dialect Registration
//===----------------------------------------------------------------------===//
@@ -99,7 +102,7 @@
}
/// Parse an attribute registered to this dialect.
-Attribute Dialect::parseAttribute(StringRef attrData, Type type,
+Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type,
Location loc) const {
emitError(loc) << "dialect '" << getNamespace()
<< "' provides no attribute parsing hook";
@@ -107,11 +110,11 @@
}
/// Parse a type registered to this dialect.
-Type Dialect::parseType(StringRef tyData, Location loc) const {
+Type Dialect::parseType(DialectAsmParser &parser, Location loc) const {
// If this dialect allows unknown types, then represent this with OpaqueType.
if (allowsUnknownTypes()) {
auto ns = Identifier::get(getNamespace(), getContext());
- return OpaqueType::get(ns, tyData, getContext());
+ return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
}
emitError(loc) << "dialect '" << getNamespace()
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index af7e0b6..a6e0227 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
@@ -51,37 +52,43 @@
class Parser;
//===----------------------------------------------------------------------===//
-// ParserState
+// AliasState
//===----------------------------------------------------------------------===//
-/// This class refers to all of the state maintained globally by the parser,
-/// such as the current lexer position etc. The Parser base class provides
-/// methods to access this.
-class ParserState {
-public:
- ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
- : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()) {}
-
+/// This class contains record of any parsed top-level aliases.
+struct AliasState {
// A map from attribute alias identifier to Attribute.
llvm::StringMap<Attribute> attributeAliasDefinitions;
// A map from type alias identifier to Type.
llvm::StringMap<Type> typeAliasDefinitions;
+};
-private:
+//===----------------------------------------------------------------------===//
+// ParserState
+//===----------------------------------------------------------------------===//
+
+/// This class refers to all of the state maintained globally by the parser,
+/// such as the current lexer position etc.
+struct ParserState {
+ ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
+ AliasState &aliases)
+ : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
+ aliases(aliases) {}
ParserState(const ParserState &) = delete;
void operator=(const ParserState &) = delete;
- friend class Parser;
-
- // The context we're parsing into.
+ /// The context we're parsing into.
MLIRContext *const context;
- // The lexer for the source file we're parsing.
+ /// The lexer for the source file we're parsing.
Lexer lex;
- // This is the next token that hasn't been consumed yet.
+ /// This is the next token that hasn't been consumed yet.
Token curToken;
+
+ /// Any parsed alias state.
+ AliasState &aliases;
};
//===----------------------------------------------------------------------===//
@@ -348,6 +355,55 @@
return success();
}
+//===----------------------------------------------------------------------===//
+// DialectAsmParser
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides the main implementation of the DialectAsmParser that
+/// allows for dialects to parse attributes and types. This allows for dialect
+/// hooking into the main MLIR parsing logic.
+class CustomDialectAsmParser : public DialectAsmParser {
+public:
+ CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
+ : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
+ parser(parser) {}
+ ~CustomDialectAsmParser() override {}
+
+ /// Emit a diagnostic at the specified location and return failure.
+ InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
+ return parser.emitError(loc, message);
+ }
+
+ /// Return a builder which provides useful access to MLIRContext, global
+ /// objects like types and attributes.
+ Builder &getBuilder() const override { return parser.builder; }
+
+ /// Get the location of the next token and store it into the argument. This
+ /// always succeeds.
+ llvm::SMLoc getCurrentLocation() override {
+ return parser.getToken().getLoc();
+ }
+
+ /// Return the location of the original name token.
+ llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+ /// Returns the full specification of the symbol being parsed. This allows
+ /// for using a separate parser if necessary.
+ StringRef getFullSymbolSpec() const override { return fullSpec; }
+
+private:
+ /// The full symbol specification.
+ StringRef fullSpec;
+
+ /// The source location of the dialect symbol.
+ SMLoc nameLoc;
+
+ /// The main parser.
+ Parser &parser;
+};
+} // namespace
+
/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
/// and may be recursive. Return with the 'prettyName' StringRef encompassing
/// the entire pretty name.
@@ -486,8 +542,42 @@
}
// Call into the provided symbol construction function.
- auto encodedLoc = p.getEncodedSourceLocation(loc);
- return createSymbol(dialectName, symbolData, encodedLoc);
+ return createSymbol(dialectName, symbolData, loc);
+}
+
+/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
+/// parsing failed, nullptr is returned. The number of bytes read from the input
+/// string is returned in 'numRead'.
+template <typename T, typename ParserFn>
+static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
+ AliasState &aliasState, ParserFn &&parserFn,
+ size_t *numRead = nullptr) {
+ SourceMgr sourceMgr;
+ auto memBuffer = MemoryBuffer::getMemBuffer(
+ inputStr, /*BufferName=*/"<mlir_parser_buffer>",
+ /*RequiresNullTerminator=*/false);
+ sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
+ ParserState state(sourceMgr, context, aliasState);
+ Parser parser(state);
+
+ Token startTok = parser.getToken();
+ T symbol = parserFn(parser);
+ if (!symbol)
+ return T();
+
+ // If 'numRead' is valid, then provide the number of bytes that were read.
+ Token endTok = parser.getToken();
+ if (numRead) {
+ *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
+ startTok.getLoc().getPointer());
+
+ // Otherwise, ensure that all of the tokens were parsed.
+ } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
+ parser.emitError(endTok.getLoc(),
+ "encountered unexpected tokens after parsing");
+ return T();
+ }
+ return symbol;
}
//===----------------------------------------------------------------------===//
@@ -611,16 +701,24 @@
///
Type Parser::parseExtendedType() {
return parseExtendedSymbol<Type>(
- *this, Token::exclamation_identifier, state.typeAliasDefinitions,
- [&](StringRef dialectName, StringRef symbolData, Location loc) -> Type {
+ *this, Token::exclamation_identifier, state.aliases.typeAliasDefinitions,
+ [&](StringRef dialectName, StringRef symbolData,
+ llvm::SMLoc loc) -> Type {
+ Location encodedLoc = getEncodedSourceLocation(loc);
+
// If we found a registered dialect, then ask it to parse the type.
- if (auto *dialect = state.context->getRegisteredDialect(dialectName))
- return dialect->parseType(symbolData, loc);
+ if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+ return parseSymbol<Type>(
+ symbolData, state.context, state.aliases, [&](Parser &parser) {
+ CustomDialectAsmParser customParser(symbolData, parser);
+ return dialect->parseType(customParser, encodedLoc);
+ });
+ }
// Otherwise, form a new opaque type.
return OpaqueType::getChecked(
Identifier::get(dialectName, state.context), symbolData,
- state.context, loc);
+ state.context, encodedLoc);
});
}
@@ -1217,22 +1315,29 @@
///
Attribute Parser::parseExtendedAttr(Type type) {
Attribute attr = parseExtendedSymbol<Attribute>(
- *this, Token::hash_identifier, state.attributeAliasDefinitions,
+ *this, Token::hash_identifier, state.aliases.attributeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData,
- Location loc) -> Attribute {
+ llvm::SMLoc loc) -> Attribute {
// Parse an optional trailing colon type.
Type attrType = type;
if (consumeIf(Token::colon) && !(attrType = parseType()))
return Attribute();
// If we found a registered dialect, then ask it to parse the attribute.
- if (auto *dialect = state.context->getRegisteredDialect(dialectName))
- return dialect->parseAttribute(symbolData, attrType, loc);
+ if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+ return parseSymbol<Attribute>(
+ symbolData, state.context, state.aliases, [&](Parser &parser) {
+ CustomDialectAsmParser customParser(symbolData, parser);
+ return dialect->parseAttribute(customParser, attrType,
+ getEncodedSourceLocation(loc));
+ });
+ }
// Otherwise, form a new opaque attribute.
return OpaqueAttr::getChecked(
Identifier::get(dialectName, state.context), symbolData,
- attrType ? attrType : NoneType::get(state.context), loc);
+ attrType ? attrType : NoneType::get(state.context),
+ getEncodedSourceLocation(loc));
});
// Ensure that the attribute has the same type as requested.
@@ -4137,7 +4242,7 @@
StringRef aliasName = getTokenSpelling().drop_front();
// Check for redefinitions.
- if (getState().attributeAliasDefinitions.count(aliasName) > 0)
+ if (getState().aliases.attributeAliasDefinitions.count(aliasName) > 0)
return emitError("redefinition of attribute alias id '" + aliasName + "'");
// Make sure this isn't invading the dialect attribute namespace.
@@ -4156,7 +4261,7 @@
if (!attr)
return failure();
- getState().attributeAliasDefinitions[aliasName] = attr;
+ getState().aliases.attributeAliasDefinitions[aliasName] = attr;
return success();
}
@@ -4169,7 +4274,7 @@
StringRef aliasName = getTokenSpelling().drop_front();
// Check for redefinitions.
- if (getState().typeAliasDefinitions.count(aliasName) > 0)
+ if (getState().aliases.typeAliasDefinitions.count(aliasName) > 0)
return emitError("redefinition of type alias id '" + aliasName + "'");
// Make sure this isn't invading the dialect type namespace.
@@ -4190,7 +4295,7 @@
return failure();
// Register this alias with the parser state.
- getState().typeAliasDefinitions.try_emplace(aliasName, aliasedType);
+ getState().aliases.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
return success();
}
@@ -4269,7 +4374,8 @@
OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
- ParserState state(sourceMgr, context);
+ AliasState aliasState;
+ ParserState state(sourceMgr, context, aliasState);
if (ModuleParser(state).parseModule(*module))
return nullptr;
@@ -4334,23 +4440,16 @@
template <typename T, typename ParserFn>
static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
size_t &numRead, ParserFn &&parserFn) {
- SourceMgr sourceMgr;
- auto memBuffer = MemoryBuffer::getMemBuffer(
- inputStr, /*BufferName=*/"<mlir_parser_buffer>",
- /*RequiresNullTerminator=*/false);
- sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
- SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
- ParserState state(sourceMgr, context);
- Parser parser(state);
-
- auto start = parser.getToken().getLoc();
- T symbol = parserFn(parser);
- if (!symbol)
- return T();
-
- auto end = parser.getToken().getLoc();
- numRead = static_cast<size_t>(end.getPointer() - start.getPointer());
- return symbol;
+ AliasState aliasState;
+ return parseSymbol<T>(
+ inputStr, context, aliasState,
+ [&](Parser &parser) {
+ SourceMgrDiagnosticHandler handler(
+ const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
+ parser.getContext());
+ return parserFn(parser);
+ },
+ &numRead);
}
Attribute mlir::parseAttribute(llvm::StringRef attrStr, MLIRContext *context) {