Add DialectAsmParser/Printer classes to simplify dialect attribute and type parsing.

These classes are functionally similar to the OpAsmParser/Printer classes and provide hooks for parsing attributes/tokens/types/etc. This change merely sets up the base infrastructure and updates the parser hooks, followups will add hooks as needed to simplify existing handrolled dialect parsers.

This has various different benefits:
*) Attribute/Type parsing is much simpler to define.
*) Dialect attributes/types that contain other attributes/types can now use aliases.
*) It provides a 'spec' with which we may use in the future to auto-generate parsers/printers.
*) Error messages emitted by attribute/type parsers can provide character exact locations rather than "beginning of the string"

PiperOrigin-RevId: 278005322
diff --git a/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 67fccec..d09e815 100644
--- a/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -173,10 +173,10 @@
   llvm::Module &getLLVMModule();
 
   /// Parse a type registered to this dialect.
-  Type parseType(StringRef tyData, Location loc) const override;
+  Type parseType(DialectAsmParser &parser, Location loc) const override;
 
   /// Print a type registered to this dialect.
-  void printType(Type type, raw_ostream &os) const override;
+  void printType(Type type, DialectAsmPrinter &os) const override;
 
   /// Verify a region argument attribute registered to this dialect.
   /// Returns failure if the verification failed, success otherwise.
diff --git a/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index 1835073..8888953 100644
--- a/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -37,10 +37,10 @@
   static StringRef getDialectNamespace() { return "linalg"; }
 
   /// Parse a type registered to this dialect.
-  Type parseType(llvm::StringRef spec, Location loc) const override;
+  Type parseType(DialectAsmParser &parser, Location loc) const override;
 
   /// Print a type registered to this dialect.
-  void printType(Type type, llvm::raw_ostream &os) const override;
+  void printType(Type type, DialectAsmPrinter &os) const override;
 };
 
 /// A BufferType represents a contiguous block of memory that can be allocated
diff --git a/include/mlir/Dialect/QuantOps/QuantOps.h b/include/mlir/Dialect/QuantOps/QuantOps.h
index 8753cd2..f1ac383 100644
--- a/include/mlir/Dialect/QuantOps/QuantOps.h
+++ b/include/mlir/Dialect/QuantOps/QuantOps.h
@@ -35,10 +35,10 @@
   QuantizationDialect(MLIRContext *context);
 
   /// Parse a type registered to this dialect.
-  Type parseType(StringRef spec, Location loc) const override;
+  Type parseType(DialectAsmParser &parser, Location loc) const override;
 
   /// Print a type registered to this dialect.
-  void printType(Type type, raw_ostream &os) const override;
+  void printType(Type type, DialectAsmPrinter &os) const override;
 };
 
 #define GET_OP_CLASSES
diff --git a/include/mlir/Dialect/SPIRV/SPIRVDialect.h b/include/mlir/Dialect/SPIRV/SPIRVDialect.h
index 8e98270..6401eba 100644
--- a/include/mlir/Dialect/SPIRV/SPIRVDialect.h
+++ b/include/mlir/Dialect/SPIRV/SPIRVDialect.h
@@ -46,10 +46,10 @@
   static std::string getAttributeName(Decoration decoration);
 
   /// Parses a type registered to this dialect.
-  Type parseType(llvm::StringRef spec, Location loc) const override;
+  Type parseType(DialectAsmParser &parser, Location loc) const override;
 
   /// Prints a type registered to this dialect.
-  void printType(Type type, llvm::raw_ostream &os) const override;
+  void printType(Type type, DialectAsmPrinter &os) const override;
 
   /// Provides a hook for materializing a constant to this dialect.
   Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
diff --git a/include/mlir/IR/Dialect.h b/include/mlir/IR/Dialect.h
index bf7db91..bd84bee 100644
--- a/include/mlir/IR/Dialect.h
+++ b/include/mlir/IR/Dialect.h
@@ -25,6 +25,8 @@
 #include "mlir/IR/OperationSupport.h"
 
 namespace mlir {
+class DialectAsmParser;
+class DialectAsmPrinter;
 class DialectInterface;
 class OpBuilder;
 class Type;
@@ -115,21 +117,21 @@
 
   /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
   /// refers to the expected type of the attribute.
-  virtual Attribute parseAttribute(StringRef attrData, Type type,
+  virtual Attribute parseAttribute(DialectAsmParser &parser, Type type,
                                    Location loc) const;
 
   /// Print an attribute registered to this dialect. Note: The type of the
   /// attribute need not be printed by this method as it is always printed by
   /// the caller.
-  virtual void printAttribute(Attribute, raw_ostream &) const {
+  virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
     llvm_unreachable("dialect has no registered attribute printing hook");
   }
 
   /// Parse a type registered to this dialect.
-  virtual Type parseType(StringRef tyData, Location loc) const;
+  virtual Type parseType(DialectAsmParser &parser, Location loc) const;
 
   /// Print a type registered to this dialect.
-  virtual void printType(Type, raw_ostream &) const {
+  virtual void printType(Type, DialectAsmPrinter &) const {
     llvm_unreachable("dialect has no registered type printing hook");
   }
 
diff --git a/include/mlir/IR/DialectImplementation.h b/include/mlir/IR/DialectImplementation.h
new file mode 100644
index 0000000..c662a4c
--- /dev/null
+++ b/include/mlir/IR/DialectImplementation.h
@@ -0,0 +1,139 @@
+//===- DialectImplementation.h ----------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains utilities classes for implementing dialect attributes and
+// types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_DIALECTIMPLEMENTATION_H
+#define MLIR_IR_DIALECTIMPLEMENTATION_H
+
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/SMLoc.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+class Builder;
+
+//===----------------------------------------------------------------------===//
+// DialectAsmPrinter
+//===----------------------------------------------------------------------===//
+
+/// This is a pure-virtual base class that exposes the asmprinter hooks
+/// necessary to implement a custom printAttribute/printType() method on a
+/// dialect.
+class DialectAsmPrinter {
+public:
+  DialectAsmPrinter() {}
+  virtual ~DialectAsmPrinter();
+  virtual raw_ostream &getStream() const = 0;
+
+  /// Print the given attribute to the stream.
+  virtual void printAttribute(Attribute attr) = 0;
+
+  /// Print the given floating point value in a stabilized form that can be
+  /// roundtripped through the IR. This is the companion to the 'parseFloat'
+  /// hook on the DialectAsmParser.
+  virtual void printFloat(const APFloat &value) = 0;
+
+  /// Print the given type to the stream.
+  virtual void printType(Type type) = 0;
+
+private:
+  DialectAsmPrinter(const DialectAsmPrinter &) = delete;
+  void operator=(const DialectAsmPrinter &) = delete;
+};
+
+// Make the implementations convenient to use.
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) {
+  p.printAttribute(attr);
+  return p;
+}
+
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p,
+                                     const APFloat &value) {
+  p.printFloat(value);
+  return p;
+}
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) {
+  return p << APFloat(value);
+}
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) {
+  return p << APFloat(value);
+}
+
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) {
+  p.printType(type);
+  return p;
+}
+
+// Support printing anything that isn't convertible to one of the above types,
+// even if it isn't exactly one of them.  For example, we want to print
+// FunctionType with the Type version above, not have it match this.
+template <typename T, typename std::enable_if<
+                          !std::is_convertible<T &, Attribute &>::value &&
+                              !std::is_convertible<T &, Type &>::value &&
+                              !std::is_convertible<T &, APFloat &>::value &&
+                              !llvm::is_one_of<T, double, float>::value,
+                          T>::type * = nullptr>
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) {
+  p.getStream() << other;
+  return p;
+}
+
+//===----------------------------------------------------------------------===//
+// DialectAsmParser
+//===----------------------------------------------------------------------===//
+
+/// The DialectAsmParser has methods for interacting with the asm parser:
+/// parsing things from it, emitting errors etc.  It has an intentionally
+/// high-level API that is designed to reduce/constrain syntax innovation in
+/// individual attributes or types.
+class DialectAsmParser {
+public:
+  virtual ~DialectAsmParser();
+
+  /// Emit a diagnostic at the specified location and return failure.
+  virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
+                                       const Twine &message = {}) = 0;
+
+  /// Return a builder which provides useful access to MLIRContext, global
+  /// objects like types and attributes.
+  virtual Builder &getBuilder() const = 0;
+
+  /// Get the location of the next token and store it into the argument.  This
+  /// always succeeds.
+  virtual llvm::SMLoc getCurrentLocation() = 0;
+  ParseResult getCurrentLocation(llvm::SMLoc *loc) {
+    *loc = getCurrentLocation();
+    return success();
+  }
+
+  /// Return the location of the original name token.
+  virtual llvm::SMLoc getNameLoc() const = 0;
+
+  /// Returns the full specification of the symbol being parsed. This allows for
+  /// using a separate parser if necessary.
+  virtual StringRef getFullSymbolSpec() const = 0;
+};
+
+} // end namespace mlir
+
+#endif
diff --git a/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d7d3307..39decf9 100644
--- a/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -21,6 +21,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
@@ -1249,7 +1250,9 @@
 llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
 
 /// Parse a type registered to this dialect.
-Type LLVMDialect::parseType(StringRef tyData, Location loc) const {
+Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const {
+  StringRef tyData = parser.getFullSymbolSpec();
+
   // LLVM is not thread-safe, so lock access to it.
   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
 
@@ -1261,11 +1264,11 @@
 }
 
 /// Print a type registered to this dialect.
-void LLVMDialect::printType(Type type, raw_ostream &os) const {
+void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
   auto llvmType = type.dyn_cast<LLVMType>();
   assert(llvmType && "printing wrong type");
   assert(llvmType.getUnderlyingType() && "no underlying LLVM type");
-  llvmType.getUnderlyingType()->print(os);
+  llvmType.getUnderlyingType()->print(os.getStream());
 }
 
 /// Verify LLVMIR function argument attributes.
diff --git a/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index c09b75e..4a7bcd8 100644
--- a/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -20,9 +20,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/Parser.h"
 #include "mlir/Support/LLVM.h"
 
@@ -107,8 +108,9 @@
   return getImpl()->getBufferSize();
 }
 
-Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
+Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser,
                                             Location loc) const {
+  StringRef spec = parser.getFullSymbolSpec();
   StringRef origSpec = spec;
   MLIRContext *context = getContext();
   if (spec == "range")
@@ -146,9 +148,8 @@
   return (emitError(loc, "unknown Linalg type: " + origSpec), Type());
 }
 
-
 /// BufferType prints as "buffer<element_type>".
-static void print(BufferType bt, raw_ostream &os) {
+static void print(BufferType bt, DialectAsmPrinter &os) {
   os << "buffer<";
   auto bs = bt.getBufferSize();
   if (bs) {
@@ -160,9 +161,10 @@
 }
 
 /// RangeType prints as just "range".
-static void print(RangeType rt, raw_ostream &os) { os << "range"; }
+static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
 
-void mlir::linalg::LinalgDialect::printType(Type type, raw_ostream &os) const {
+void mlir::linalg::LinalgDialect::printType(Type type,
+                                            DialectAsmPrinter &os) const {
   switch (type.getKind()) {
   default:
     llvm_unreachable("Unhandled Linalg type");
diff --git a/lib/Dialect/QuantOps/IR/TypeParser.cpp b/lib/Dialect/QuantOps/IR/TypeParser.cpp
index 726c20c..360c1b5 100644
--- a/lib/Dialect/QuantOps/IR/TypeParser.cpp
+++ b/lib/Dialect/QuantOps/IR/TypeParser.cpp
@@ -17,6 +17,7 @@
 
 #include "mlir/Dialect/QuantOps/QuantOps.h"
 #include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
@@ -615,9 +616,10 @@
 }
 
 /// Parse a type registered to this dialect.
-Type QuantizationDialect::parseType(StringRef spec, Location loc) const {
-  TypeParser parser(spec, getContext(), loc);
-  Type parsedType = parser.parseType();
+Type QuantizationDialect::parseType(DialectAsmParser &parser,
+                                    Location loc) const {
+  TypeParser typeParser(parser.getFullSymbolSpec(), getContext(), loc);
+  Type parsedType = typeParser.parseType();
   if (parsedType == nullptr) {
     // Error.
     // TODO(laurenzo): Do something?
@@ -723,19 +725,20 @@
 }
 
 /// Print a type registered to this dialect.
-void QuantizationDialect::printType(Type type, raw_ostream &os) const {
+void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
   switch (type.getKind()) {
   default:
     llvm_unreachable("Unhandled quantized type");
   case QuantizationTypes::Any:
-    printAnyQuantizedType(type.cast<AnyQuantizedType>(), os);
+    printAnyQuantizedType(type.cast<AnyQuantizedType>(), os.getStream());
     break;
   case QuantizationTypes::UniformQuantized:
-    printUniformQuantizedType(type.cast<UniformQuantizedType>(), os);
+    printUniformQuantizedType(type.cast<UniformQuantizedType>(),
+                              os.getStream());
     break;
   case QuantizationTypes::UniformQuantizedPerAxis:
     printUniformQuantizedPerAxisType(type.cast<UniformQuantizedPerAxisType>(),
-                                     os);
+                                     os.getStream());
     break;
   }
 }
diff --git a/lib/Dialect/SPIRV/SPIRVDialect.cpp b/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 96777b1..26d1ff1 100644
--- a/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Parser.h"
@@ -609,7 +610,9 @@
 //              | pointer-type
 //              | runtime-array-type
 //              | struct-type
-Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
+Type SPIRVDialect::parseType(DialectAsmParser &parser, Location loc) const {
+  StringRef spec = parser.getFullSymbolSpec();
+
   if (spec.startswith("array"))
     return parseArrayType(*this, spec, loc);
   if (spec.startswith("image"))
@@ -629,7 +632,7 @@
 // Type Printing
 //===----------------------------------------------------------------------===//
 
-static void print(ArrayType type, llvm::raw_ostream &os) {
+static void print(ArrayType type, DialectAsmPrinter &os) {
   os << "array<" << type.getNumElements() << " x " << type.getElementType();
   if (type.hasLayout()) {
     os << " [" << type.getArrayStride() << "]";
@@ -637,16 +640,16 @@
   os << ">";
 }
 
-static void print(RuntimeArrayType type, llvm::raw_ostream &os) {
+static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
   os << "rtarray<" << type.getElementType() << ">";
 }
 
-static void print(PointerType type, llvm::raw_ostream &os) {
+static void print(PointerType type, DialectAsmPrinter &os) {
   os << "ptr<" << type.getPointeeType() << ", "
      << stringifyStorageClass(type.getStorageClass()) << ">";
 }
 
-static void print(ImageType type, llvm::raw_ostream &os) {
+static void print(ImageType type, DialectAsmPrinter &os) {
   os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
      << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
      << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
@@ -655,7 +658,7 @@
      << stringifyImageFormat(type.getImageFormat()) << ">";
 }
 
-static void print(StructType type, llvm::raw_ostream &os) {
+static void print(StructType type, DialectAsmPrinter &os) {
   os << "struct<";
   auto printMember = [&](unsigned i) {
     os << type.getElementType(i);
@@ -680,7 +683,7 @@
   os << ">";
 }
 
-void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
+void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   switch (type.getKind()) {
   case TypeKind::Array:
     print(type.cast<ArrayType>(), os);
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 0200e98..0e6b788 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
@@ -53,6 +54,8 @@
 
 void OperationName::dump() const { print(llvm::errs()); }
 
+DialectAsmPrinter::~DialectAsmPrinter() {}
+
 OpAsmPrinter::~OpAsmPrinter() {}
 
 //===----------------------------------------------------------------------===//
@@ -391,6 +394,9 @@
       : os(printer.os), printerFlags(printer.printerFlags),
         state(printer.state) {}
 
+  /// Returns the output stream of the printer.
+  raw_ostream &getStream() { return os; }
+
   template <typename Container, typename UnaryFunctor>
   inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
     mlir::interleaveComma(c, os, each_fn);
@@ -420,6 +426,9 @@
   void printLocationInternal(LocationAttr loc, bool pretty = false);
   void printDenseElementsAttr(DenseElementsAttr attr);
 
+  void printDialectAttribute(Attribute attr);
+  void printDialectType(Type type);
+
   /// This enum is used to represent the binding strength of the enclosing
   /// context that an AffineExprStorage is being printed in, so we can
   /// intelligently produce parens.
@@ -715,19 +724,9 @@
   }
 
   switch (attr.getKind()) {
-  default: {
-    auto &dialect = attr.getDialect();
+  default:
+    return printDialectAttribute(attr);
 
-    // Ask the dialect to serialize the attribute to a string.
-    std::string attrName;
-    {
-      llvm::raw_string_ostream attrNameStr(attrName);
-      dialect.printAttribute(attr, attrNameStr);
-    }
-
-    printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
-    break;
-  }
   case StandardAttributes::Opaque: {
     auto opaqueAttr = attr.cast<OpaqueAttr>();
     printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
@@ -950,19 +949,9 @@
   }
 
   switch (type.getKind()) {
-  default: {
-    auto &dialect = type.getDialect();
+  default:
+    return printDialectType(type);
 
-    // Ask the dialect to serialize the type to a string.
-    std::string typeName;
-    {
-      llvm::raw_string_ostream typeNameStr(typeName);
-      dialect.printType(type, typeNameStr);
-    }
-
-    printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
-    return;
-  }
   case Type::Kind::Opaque: {
     auto opaqueTy = type.cast<OpaqueType>();
     printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
@@ -1073,6 +1062,65 @@
 }
 
 //===----------------------------------------------------------------------===//
+// CustomDialectAsmPrinter
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides the main specialication of the DialectAsmPrinter that is
+/// used to provide support for print attributes and types. This hooks allows
+/// for dialects to hook into the main ModulePrinter.
+struct CustomDialectAsmPrinter : public DialectAsmPrinter {
+public:
+  CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
+  ~CustomDialectAsmPrinter() override {}
+
+  raw_ostream &getStream() const override { return printer.getStream(); }
+
+  /// Print the given attribute to the stream.
+  void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
+
+  /// Print the given floating point value in a stablized form.
+  void printFloat(const APFloat &value) override {
+    printFloatValue(value, getStream());
+  }
+
+  /// Print the given type to the stream.
+  void printType(Type type) override { printer.printType(type); }
+
+  /// The main module printer.
+  ModulePrinter &printer;
+};
+} // end anonymous namespace
+
+void ModulePrinter::printDialectAttribute(Attribute attr) {
+  auto &dialect = attr.getDialect();
+
+  // Ask the dialect to serialize the attribute to a string.
+  std::string attrName;
+  {
+    llvm::raw_string_ostream attrNameStr(attrName);
+    ModulePrinter subPrinter(attrNameStr, printerFlags, state);
+    CustomDialectAsmPrinter printer(subPrinter);
+    dialect.printAttribute(attr, printer);
+  }
+  printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
+}
+
+void ModulePrinter::printDialectType(Type type) {
+  auto &dialect = type.getDialect();
+
+  // Ask the dialect to serialize the type to a string.
+  std::string typeName;
+  {
+    llvm::raw_string_ostream typeNameStr(typeName);
+    ModulePrinter subPrinter(typeNameStr, printerFlags, state);
+    CustomDialectAsmPrinter printer(subPrinter);
+    dialect.printType(type, printer);
+  }
+  printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
+}
+
+//===----------------------------------------------------------------------===//
 // Affine expressions and maps
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/IR/Dialect.cpp b/lib/IR/Dialect.cpp
index f8539c0..7882e4f 100644
--- a/lib/IR/Dialect.cpp
+++ b/lib/IR/Dialect.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectHooks.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
@@ -28,6 +29,8 @@
 using namespace mlir;
 using namespace detail;
 
+DialectAsmParser::~DialectAsmParser() {}
+
 //===----------------------------------------------------------------------===//
 // Dialect Registration
 //===----------------------------------------------------------------------===//
@@ -99,7 +102,7 @@
 }
 
 /// Parse an attribute registered to this dialect.
-Attribute Dialect::parseAttribute(StringRef attrData, Type type,
+Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type,
                                   Location loc) const {
   emitError(loc) << "dialect '" << getNamespace()
                  << "' provides no attribute parsing hook";
@@ -107,11 +110,11 @@
 }
 
 /// Parse a type registered to this dialect.
-Type Dialect::parseType(StringRef tyData, Location loc) const {
+Type Dialect::parseType(DialectAsmParser &parser, Location loc) const {
   // If this dialect allows unknown types, then represent this with OpaqueType.
   if (allowsUnknownTypes()) {
     auto ns = Identifier::get(getNamespace(), getContext());
-    return OpaqueType::get(ns, tyData, getContext());
+    return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
   }
 
   emitError(loc) << "dialect '" << getNamespace()
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index af7e0b6..a6e0227 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
@@ -51,37 +52,43 @@
 class Parser;
 
 //===----------------------------------------------------------------------===//
-// ParserState
+// AliasState
 //===----------------------------------------------------------------------===//
 
-/// This class refers to all of the state maintained globally by the parser,
-/// such as the current lexer position etc. The Parser base class provides
-/// methods to access this.
-class ParserState {
-public:
-  ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
-      : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()) {}
-
+/// This class contains record of any parsed top-level aliases.
+struct AliasState {
   // A map from attribute alias identifier to Attribute.
   llvm::StringMap<Attribute> attributeAliasDefinitions;
 
   // A map from type alias identifier to Type.
   llvm::StringMap<Type> typeAliasDefinitions;
+};
 
-private:
+//===----------------------------------------------------------------------===//
+// ParserState
+//===----------------------------------------------------------------------===//
+
+/// This class refers to all of the state maintained globally by the parser,
+/// such as the current lexer position etc.
+struct ParserState {
+  ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
+              AliasState &aliases)
+      : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
+        aliases(aliases) {}
   ParserState(const ParserState &) = delete;
   void operator=(const ParserState &) = delete;
 
-  friend class Parser;
-
-  // The context we're parsing into.
+  /// The context we're parsing into.
   MLIRContext *const context;
 
-  // The lexer for the source file we're parsing.
+  /// The lexer for the source file we're parsing.
   Lexer lex;
 
-  // This is the next token that hasn't been consumed yet.
+  /// This is the next token that hasn't been consumed yet.
   Token curToken;
+
+  /// Any parsed alias state.
+  AliasState &aliases;
 };
 
 //===----------------------------------------------------------------------===//
@@ -348,6 +355,55 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DialectAsmParser
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides the main implementation of the DialectAsmParser that
+/// allows for dialects to parse attributes and types. This allows for dialect
+/// hooking into the main MLIR parsing logic.
+class CustomDialectAsmParser : public DialectAsmParser {
+public:
+  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
+      : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
+        parser(parser) {}
+  ~CustomDialectAsmParser() override {}
+
+  /// Emit a diagnostic at the specified location and return failure.
+  InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
+    return parser.emitError(loc, message);
+  }
+
+  /// Return a builder which provides useful access to MLIRContext, global
+  /// objects like types and attributes.
+  Builder &getBuilder() const override { return parser.builder; }
+
+  /// Get the location of the next token and store it into the argument.  This
+  /// always succeeds.
+  llvm::SMLoc getCurrentLocation() override {
+    return parser.getToken().getLoc();
+  }
+
+  /// Return the location of the original name token.
+  llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+  /// Returns the full specification of the symbol being parsed. This allows
+  /// for using a separate parser if necessary.
+  StringRef getFullSymbolSpec() const override { return fullSpec; }
+
+private:
+  /// The full symbol specification.
+  StringRef fullSpec;
+
+  /// The source location of the dialect symbol.
+  SMLoc nameLoc;
+
+  /// The main parser.
+  Parser &parser;
+};
+} // namespace
+
 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
 /// and may be recursive.  Return with the 'prettyName' StringRef encompassing
 /// the entire pretty name.
@@ -486,8 +542,42 @@
   }
 
   // Call into the provided symbol construction function.
-  auto encodedLoc = p.getEncodedSourceLocation(loc);
-  return createSymbol(dialectName, symbolData, encodedLoc);
+  return createSymbol(dialectName, symbolData, loc);
+}
+
+/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
+/// parsing failed, nullptr is returned. The number of bytes read from the input
+/// string is returned in 'numRead'.
+template <typename T, typename ParserFn>
+static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
+                     AliasState &aliasState, ParserFn &&parserFn,
+                     size_t *numRead = nullptr) {
+  SourceMgr sourceMgr;
+  auto memBuffer = MemoryBuffer::getMemBuffer(
+      inputStr, /*BufferName=*/"<mlir_parser_buffer>",
+      /*RequiresNullTerminator=*/false);
+  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
+  ParserState state(sourceMgr, context, aliasState);
+  Parser parser(state);
+
+  Token startTok = parser.getToken();
+  T symbol = parserFn(parser);
+  if (!symbol)
+    return T();
+
+  // If 'numRead' is valid, then provide the number of bytes that were read.
+  Token endTok = parser.getToken();
+  if (numRead) {
+    *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
+                                   startTok.getLoc().getPointer());
+
+    // Otherwise, ensure that all of the tokens were parsed.
+  } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
+    parser.emitError(endTok.getLoc(),
+                     "encountered unexpected tokens after parsing");
+    return T();
+  }
+  return symbol;
 }
 
 //===----------------------------------------------------------------------===//
@@ -611,16 +701,24 @@
 ///
 Type Parser::parseExtendedType() {
   return parseExtendedSymbol<Type>(
-      *this, Token::exclamation_identifier, state.typeAliasDefinitions,
-      [&](StringRef dialectName, StringRef symbolData, Location loc) -> Type {
+      *this, Token::exclamation_identifier, state.aliases.typeAliasDefinitions,
+      [&](StringRef dialectName, StringRef symbolData,
+          llvm::SMLoc loc) -> Type {
+        Location encodedLoc = getEncodedSourceLocation(loc);
+
         // If we found a registered dialect, then ask it to parse the type.
-        if (auto *dialect = state.context->getRegisteredDialect(dialectName))
-          return dialect->parseType(symbolData, loc);
+        if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+          return parseSymbol<Type>(
+              symbolData, state.context, state.aliases, [&](Parser &parser) {
+                CustomDialectAsmParser customParser(symbolData, parser);
+                return dialect->parseType(customParser, encodedLoc);
+              });
+        }
 
         // Otherwise, form a new opaque type.
         return OpaqueType::getChecked(
             Identifier::get(dialectName, state.context), symbolData,
-            state.context, loc);
+            state.context, encodedLoc);
       });
 }
 
@@ -1217,22 +1315,29 @@
 ///
 Attribute Parser::parseExtendedAttr(Type type) {
   Attribute attr = parseExtendedSymbol<Attribute>(
-      *this, Token::hash_identifier, state.attributeAliasDefinitions,
+      *this, Token::hash_identifier, state.aliases.attributeAliasDefinitions,
       [&](StringRef dialectName, StringRef symbolData,
-          Location loc) -> Attribute {
+          llvm::SMLoc loc) -> Attribute {
         // Parse an optional trailing colon type.
         Type attrType = type;
         if (consumeIf(Token::colon) && !(attrType = parseType()))
           return Attribute();
 
         // If we found a registered dialect, then ask it to parse the attribute.
-        if (auto *dialect = state.context->getRegisteredDialect(dialectName))
-          return dialect->parseAttribute(symbolData, attrType, loc);
+        if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+          return parseSymbol<Attribute>(
+              symbolData, state.context, state.aliases, [&](Parser &parser) {
+                CustomDialectAsmParser customParser(symbolData, parser);
+                return dialect->parseAttribute(customParser, attrType,
+                                               getEncodedSourceLocation(loc));
+              });
+        }
 
         // Otherwise, form a new opaque attribute.
         return OpaqueAttr::getChecked(
             Identifier::get(dialectName, state.context), symbolData,
-            attrType ? attrType : NoneType::get(state.context), loc);
+            attrType ? attrType : NoneType::get(state.context),
+            getEncodedSourceLocation(loc));
       });
 
   // Ensure that the attribute has the same type as requested.
@@ -4137,7 +4242,7 @@
   StringRef aliasName = getTokenSpelling().drop_front();
 
   // Check for redefinitions.
-  if (getState().attributeAliasDefinitions.count(aliasName) > 0)
+  if (getState().aliases.attributeAliasDefinitions.count(aliasName) > 0)
     return emitError("redefinition of attribute alias id '" + aliasName + "'");
 
   // Make sure this isn't invading the dialect attribute namespace.
@@ -4156,7 +4261,7 @@
   if (!attr)
     return failure();
 
-  getState().attributeAliasDefinitions[aliasName] = attr;
+  getState().aliases.attributeAliasDefinitions[aliasName] = attr;
   return success();
 }
 
@@ -4169,7 +4274,7 @@
   StringRef aliasName = getTokenSpelling().drop_front();
 
   // Check for redefinitions.
-  if (getState().typeAliasDefinitions.count(aliasName) > 0)
+  if (getState().aliases.typeAliasDefinitions.count(aliasName) > 0)
     return emitError("redefinition of type alias id '" + aliasName + "'");
 
   // Make sure this isn't invading the dialect type namespace.
@@ -4190,7 +4295,7 @@
     return failure();
 
   // Register this alias with the parser state.
-  getState().typeAliasDefinitions.try_emplace(aliasName, aliasedType);
+  getState().aliases.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
   return success();
 }
 
@@ -4269,7 +4374,8 @@
   OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
       sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
 
-  ParserState state(sourceMgr, context);
+  AliasState aliasState;
+  ParserState state(sourceMgr, context, aliasState);
   if (ModuleParser(state).parseModule(*module))
     return nullptr;
 
@@ -4334,23 +4440,16 @@
 template <typename T, typename ParserFn>
 static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
                      size_t &numRead, ParserFn &&parserFn) {
-  SourceMgr sourceMgr;
-  auto memBuffer = MemoryBuffer::getMemBuffer(
-      inputStr, /*BufferName=*/"<mlir_parser_buffer>",
-      /*RequiresNullTerminator=*/false);
-  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
-  SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
-  ParserState state(sourceMgr, context);
-  Parser parser(state);
-
-  auto start = parser.getToken().getLoc();
-  T symbol = parserFn(parser);
-  if (!symbol)
-    return T();
-
-  auto end = parser.getToken().getLoc();
-  numRead = static_cast<size_t>(end.getPointer() - start.getPointer());
-  return symbol;
+  AliasState aliasState;
+  return parseSymbol<T>(
+      inputStr, context, aliasState,
+      [&](Parser &parser) {
+        SourceMgrDiagnosticHandler handler(
+            const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
+            parser.getContext());
+        return parserFn(parser);
+      },
+      &numRead);
 }
 
 Attribute mlir::parseAttribute(llvm::StringRef attrStr, MLIRContext *context) {