Change unranked tensor syntax from tensor<??f32> to tensor<*xf32> per
discussion on the list.
PiperOrigin-RevId: 212838226
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 83b9493..ddcd2aa 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -521,7 +521,7 @@
}
case Type::Kind::UnrankedTensor: {
auto *v = cast<UnrankedTensorType>(type);
- os << "tensor<??";
+ os << "tensor<*x";
printType(v->getElementType());
os << '>';
return;
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 6b1fcd4..b19ddf3 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -114,11 +114,6 @@
return formToken(Token::minus, tokStart);
case '?':
- if (*curPtr == '?') {
- ++curPtr;
- return formToken(Token::questionquestion, tokStart);
- }
-
return formToken(Token::question, tokStart);
case '/':
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index d1413a3..6cfb166 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -177,6 +177,7 @@
// Type parsing.
VectorType *parseVectorType();
+ ParseResult parseXInDimensionList();
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
Type *parseTensorType();
Type *parseMemRefType();
@@ -387,6 +388,23 @@
return VectorType::get(dimensions, elementType);
}
+/// Parse an 'x' token in a dimension list, handling the case where the x is
+/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
+/// token.
+ParseResult Parser::parseXInDimensionList() {
+ if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
+ return emitError("expected 'x' in dimension list");
+
+ // If we had a prefix of 'x', lex the next token immediately after the 'x'.
+ if (getTokenSpelling().size() != 1)
+ state.lex.resetPointer(getTokenSpelling().data() + 1);
+
+ // Consume the 'x'.
+ consumeToken(Token::bare_identifier);
+
+ return ParseSuccess;
+}
+
/// Parse a dimension list of a tensor or memref type. This populates the
/// dimension list, returning -1 for the '?' dimensions.
///
@@ -407,16 +425,8 @@
}
// Make sure we have an 'x' or something like 'xbf32'.
- if (getToken().isNot(Token::bare_identifier) ||
- getTokenSpelling()[0] != 'x')
- return emitError("expected 'x' in dimension list");
-
- // If we had a prefix of 'x', lex the next token immediately after the 'x'.
- if (getTokenSpelling().size() != 1)
- state.lex.resetPointer(getTokenSpelling().data() + 1);
-
- // Consume the 'x'.
- consumeToken(Token::bare_identifier);
+ if (parseXInDimensionList())
+ return ParseFailure;
}
return ParseSuccess;
@@ -425,7 +435,7 @@
/// Parse a tensor type.
///
/// tensor-type ::= `tensor` `<` dimension-list element-type `>`
-/// dimension-list ::= dimension-list-ranked | `??`
+/// dimension-list ::= dimension-list-ranked | `*x`
///
Type *Parser::parseTensorType() {
consumeToken(Token::kw_tensor);
@@ -436,8 +446,13 @@
bool isUnranked;
SmallVector<int, 4> dimensions;
- if (consumeIf(Token::questionquestion)) {
+ if (consumeIf(Token::star)) {
+ // This is an unranked tensor type.
isUnranked = true;
+
+ if (parseXInDimensionList())
+ return nullptr;
+
} else {
isUnranked = false;
if (parseDimensionListRanked(dimensions))
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index d2503a8..4e1cd2e 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -69,7 +69,6 @@
TOK_PUNCTUATION(colon, ":")
TOK_PUNCTUATION(comma, ",")
TOK_PUNCTUATION(question, "?")
-TOK_PUNCTUATION(questionquestion, "??")
TOK_PUNCTUATION(l_paren, "(")
TOK_PUNCTUATION(r_paren, ")")
TOK_PUNCTUATION(l_brace, "{")
diff --git a/test/IR/core-ops.mlir b/test/IR/core-ops.mlir
index bad177f..591f62d 100644
--- a/test/IR/core-ops.mlir
+++ b/test/IR/core-ops.mlir
@@ -132,12 +132,12 @@
return
}
-// CHECK-LABEL: mlfunc @extract_element(%arg0 : tensor<??i32>, %arg1 : tensor<4x4xf32>) -> i32 {
-mlfunc @extract_element(%arg0 : tensor<??i32>, %arg1 : tensor<4x4xf32>) -> i32 {
+// CHECK-LABEL: mlfunc @extract_element(%arg0 : tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
+mlfunc @extract_element(%arg0 : tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
%c0 = "constant"() {value: 0} : () -> affineint
- // CHECK: %0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<??i32>
- %0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<??i32>
+ // CHECK: %0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<*xi32>
+ %0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<*xi32>
// CHECK: %1 = extract_element %arg1[%c0, %c0] : tensor<4x4xf32>
%1 = extract_element %arg1[%c0, %c0] : tensor<4x4xf32>
@@ -146,12 +146,12 @@
}
// CHECK-LABEL: mlfunc @shape_cast(%arg0
-mlfunc @shape_cast(%arg0 : tensor<??f32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<?x?xf32>) {
- // CHECK: %0 = shape_cast %arg0 : tensor<??f32> to tensor<?x?xf32>
- %0 = shape_cast %arg0 : tensor<??f32> to tensor<?x?xf32>
+mlfunc @shape_cast(%arg0 : tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<?x?xf32>) {
+ // CHECK: %0 = shape_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+ %0 = shape_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
- // CHECK: %1 = shape_cast %arg1 : tensor<4x4xf32> to tensor<??f32>
- %1 = shape_cast %arg1 : tensor<4x4xf32> to tensor<??f32>
+ // CHECK: %1 = shape_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
+ %1 = shape_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
// CHECK: %2 = shape_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
%2 = shape_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 8bdfe8b..ec33869 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -57,8 +57,8 @@
// CHECK: extfunc @vectors(vector<1xf32>, vector<2x4xf32>)
extfunc @vectors(vector<1 x f32>, vector<2x4xf32>)
-// CHECK: extfunc @tensors(tensor<??f32>, tensor<??vector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor<i8>)
-extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>,
+// CHECK: extfunc @tensors(tensor<*xf32>, tensor<*xvector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor<i8>)
+extfunc @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
tensor<1x?x4x?x?xi32>, tensor<i8>)
// CHECK: extfunc @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<i8, #map{{[0-9]+}}>)
@@ -394,8 +394,8 @@
// CHECK-LABEL: cfgfunc @typeattr
cfgfunc @typeattr() -> () {
bb0:
-// CHECK: "foo"() {bar: tensor<??f32>} : () -> ()
- "foo"(){bar: tensor<??f32>} : () -> ()
+// CHECK: "foo"() {bar: tensor<*xf32>} : () -> ()
+ "foo"(){bar: tensor<*xf32>} : () -> ()
return
}