Return an error when parseType doesnt parse the entire string passed
PiperOrigin-RevId: 255505300
diff --git a/include/mlir/Parser.h b/include/mlir/Parser.h
index 85c183a..a2673ca 100644
--- a/include/mlir/Parser.h
+++ b/include/mlir/Parser.h
@@ -59,7 +59,8 @@
/// This parses a single MLIR type to an MLIR context if it was valid. If not,
/// an error message is emitted through a new SourceMgrDiagnosticHandler
/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
-/// `typeStr`.
+/// `typeStr`. If the passed `typeStr` has additional tokens that were not part
+/// of the type, an error is emitted.
// TODO(ntv) Improve diagnostic reporting.
Type parseType(llvm::StringRef typeStr, MLIRContext *context);
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 3e0f4e8..3d4204c 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -4205,5 +4205,19 @@
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
ParserState state(sourceMgr, context);
- return Parser(state).parseType();
+ Parser parser(state);
+ auto start = parser.getToken().getLoc();
+ auto ty = parser.parseType();
+ if (!ty)
+ return Type();
+
+ auto end = parser.getToken().getLoc();
+ auto read = end.getPointer() - start.getPointer();
+ // Make sure that the parsing of type consumes the entire string
+ if (static_cast<size_t>(read) < typeStr.size()) {
+ parser.emitError("unexpected additional tokens: '")
+ << typeStr.substr(read) << "' after parsing type: " << ty;
+ return Type();
+ }
+ return ty;
}
diff --git a/lib/SPIRV/SPIRVDialect.cpp b/lib/SPIRV/SPIRVDialect.cpp
index 46ea021..816d673 100644
--- a/lib/SPIRV/SPIRVDialect.cpp
+++ b/lib/SPIRV/SPIRVDialect.cpp
@@ -80,7 +80,7 @@
static Type parseAndVerifyTypeImpl(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto *context = dialect.getContext();
- auto type = mlir::parseType(spec, context);
+ auto type = mlir::parseType(spec.trim(), context);
if (!type) {
emitError(loc, "cannot parse type: ") << spec;
return Type();
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index a7b3cb2..f55fc5a 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -1050,6 +1050,11 @@
// -----
+// expected-error @+1 {{cannot parse type: i32 f32}}
+func @bad_tuple(!spv.ptr<i32 f32, Uniform>)
+
+// -----
+
func @invalid_region_dominance() {
"foo.region"() ({
// expected-error @+1 {{operand #0 does not dominate this use}}