Enable arithmetics for index types.
Arithmetic and comparison instructions are necessary to implement, e.g.,
control flow when lowering MLFunctions to CFGFunctions. (While it is possible
to replace some of the arithmetics by affine_apply instructions for loop
bounds, it is still necessary for loop bounds checking, steps, if-conditions,
non-trivial memref subscripts, etc.) Furthermore, working with indirect
accesses in, e.g., lookup tables for large embeddings, may require operating on
tensors of indexes. For example, the equivalents to C code "LUT[Index[i]]" or
"ResultIndex[i] = i + j" where i, j are loop induction variables require the
arithmetics on indices as well as the possibility to operate on tensors
thereof. Allow arithmetic and comparison operations to apply to index types by
declaring them integer-like. Allow tensors whose element type is index for
indirection purposes.
The absence of vectors with "index" element type is explicitly tested, but the
only justification for this restriction in the CL introducing the test is
"because we don't need them". Do NOT enable vectors of index types, although
it makes vector and tensor types inconsistent with respect to allowed element
types.
PiperOrigin-RevId: 220614055
diff --git a/g3doc/LangRef.md b/g3doc/LangRef.md
index 52485cf..afcacdf 100644
--- a/g3doc/LangRef.md
+++ b/g3doc/LangRef.md
@@ -218,9 +218,8 @@
### Dimensions and Symbols {#dimensions-and-symbols}
Dimensions and symbols are the two kinds of identifiers that can appear in the
-polyhedral structures, and are always of '[index](#other-types)' type.
-Dimensions are declared in parentheses and symbols are declared in square
-brackets.
+polyhedral structures, and are always of '[index](#index-type)' type. Dimensions
+are declared in parentheses and symbols are declared in square brackets.
Examples:
@@ -303,7 +302,7 @@
second argument is always positive, its results are always positive in our
usage. The `integer-literal` operand for ceildiv, floordiv, and mod is always
expected to be positive. `bare-id` is an identifier which must have type
-[index](#other-types). The precedence of operations in an affine expression are
+[index](#index-type). The precedence of operations in an affine expression are
ordered from highest to lowest in the order: (1) parenthesization, (2) negation,
(3) modulo, multiplication, floordiv, and ceildiv, and (4) addition and
subtraction. All of these operators associate from left to right.
@@ -513,6 +512,7 @@
``` {.ebnf}
type ::= integer-type
+ | index-type
| float-type
| other-type
| vector-type
@@ -553,6 +553,22 @@
TODO: Need to decide on a representation for quantized integers
[[initial thoughts](Rationale.md#quantized-integer-operations)].
+### Index Type {#index-type}
+
+The `index` type is a signless integer whose size is equal to the natural
+machine word of the target ([rationale](Rationale.md#signless-types)) and is
+used by the affine constructs in MLIR.
+
+Syntax:
+
+``` {.ebnf}
+// Target word-sized integer.
+index-type ::= `index`
+```
+
+**Rationale:** integers of platform-specific bit widths are practical to express
+sizes, dimensionalities and subscripts.
+
### Floating Point Types {#floating-point-types}
Syntax:
@@ -571,18 +587,11 @@
MLIR supports some special purpose types:
``` {.ebnf}
-// Target word-sized integer.
-other-type ::= `index`
-
// TensorFlow specific types (TODO: the rest ref data types)
other-type ::= `tf_control` | `tf_resource` | `tf_variant` | `tf_string`
`tf_complex64` | `tf_complex128` | `tf_f32ref`
```
-The `index` type is a signless integer whose size is equal to the natural
-machine word of the target [[rationale](Rationale.md#signless-types)] and is
-used by the affine constructs in MLIR.
-
`tf_control` is used in TensorFlow graphs to represent
[control dependence edges](https://docs.google.com/document/d/1Iey7MfrAlBWd0nrHNdnVKvIKRoo8XHsWG5g5pi1iDV4/edit?ts=5b5a0a9f#heading=h.1dv5wuya469j).
@@ -622,7 +631,7 @@
``` {.ebnf}
tensor-type ::= `tensor` `<` dimension-list vector-element-type `>`
-tensor-memref-element-type ::= vector-element-type | vector-type
+tensor-memref-element-type ::= vector-element-type | vector-type | index-type
// memref requires a known rank, but tensor does not.
dimension-list ::= dimension-list-ranked | `*` `x`
@@ -1101,7 +1110,7 @@
The `for` statement in an ML Function represents an affine loop nest, defining
an SSA value for its induction variable. This SSA value always has type
-[`index`](#other-types), which is the size of the machine word.
+[`index`](#index-type), which is the size of the machine word.
The `for` statement executes its body a number of times iterating from a lower
bound to an upper bound by a stride. The stride, represented by `step`, is a
@@ -1324,7 +1333,7 @@
```
The `dim` operation takes a memref or tensor operand and a dimension index, and
-returns an ['index'](#other-types) that is the size of that dimension.
+returns an ['index'](#index-type) that is the size of that dimension.
The `dim` operation is represented with a single integer attribute named
`index`, and the type specifies the type of the memref or tensor operand.
diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h
index a3be153..89346b8 100644
--- a/include/mlir/IR/OpDefinition.h
+++ b/include/mlir/IR/OpDefinition.h
@@ -617,8 +617,8 @@
}
};
-/// This class verifies that any results of the specified op have an integer
-/// type, a vector thereof, or a tensor thereof.
+/// This class verifies that any results of the specified op have an integer or
+/// index type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
class ResultsAreIntegerLike
: public TraitBase<ConcreteType, ResultsAreIntegerLike> {
@@ -648,8 +648,8 @@
}
};
-/// This class verifies that all operands of the specified op have an integer
-/// type, a vector thereof, or a tensor thereof.
+/// This class verifies that all operands of the specified op have an integer or
+/// index type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
class OperandsAreIntegerLike
: public TraitBase<ConcreteType, OperandsAreIntegerLike> {
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index 8f3f069..1f2f242 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -500,7 +500,8 @@
/// Return true if the specified element type is ok in a tensor.
static bool isValidTensorElementType(Type type) {
return type.isa<FloatType>() || type.isa<VectorType>() ||
- type.isa<IntegerType>() || type.isa<OtherType>();
+ type.isa<IntegerType>() || type.isa<OtherType>() ||
+ type.isa<IndexType>();
}
} // end namespace mlir
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index a05a233..2c02592 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -749,7 +749,7 @@
// Check that memref is formed from allowed types.
if (!elementType.isa<IntegerType>() && !elementType.isa<FloatType>() &&
- !elementType.isa<VectorType>()) {
+ !elementType.isa<VectorType>() && !elementType.isa<IntegerType>()) {
if (location)
context->emitDiagnostic(location, "invalid memref element type",
MLIRContext::DiagnosticKind::Error);
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index c2e013d..9501128 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -362,10 +362,17 @@
return type;
}
+// Checks if the given type is an integer or an index type. Following LLVM's
+// convention, returns true if the check fails and false otherwise.
+static inline bool checkIntegerLikeType(Type type) {
+ return !(type.isa<IntegerType>() || type.isa<IndexType>());
+}
+
bool OpTrait::impl::verifyOperandsAreIntegerLike(const Operation *op) {
for (auto *operand : op->getOperands()) {
- if (!getTensorOrVectorElementType(operand->getType()).isa<IntegerType>())
- return op->emitOpError("requires an integer type");
+ auto type = getTensorOrVectorElementType(operand->getType());
+ if (checkIntegerLikeType(type))
+ return op->emitOpError("requires an integer or index type");
}
return false;
}
@@ -436,8 +443,9 @@
bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
for (auto *result : op->getResults()) {
- if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>())
- return op->emitOpError("requires an integer type");
+ auto type = getTensorOrVectorElementType(result->getType());
+ if (checkIntegerLikeType(type))
+ return op->emitOpError("requires an integer or index type");
}
return false;
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 99b8837..c55f830 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -289,6 +289,7 @@
/// Parse an arbitrary type.
///
/// type ::= integer-type
+/// | index-type
/// | float-type
/// | other-type
/// | vector-type
@@ -296,8 +297,9 @@
/// | memref-type
/// | function-type
///
+/// index-type ::= `index`
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
-/// other-type ::= `index` | `tf_control`
+/// other-type ::= `tf_control`
///
Type Parser::parseType() {
switch (getToken().getKind()) {
@@ -338,10 +340,12 @@
consumeToken(Token::kw_f64);
return builder.getF64Type();
- // other-type
+ // index-type
case Token::kw_index:
consumeToken(Token::kw_index);
return builder.getIndexType();
+
+ // other-type
case Token::kw_tf_control:
consumeToken(Token::kw_tf_control);
return builder.getTFControlType();
diff --git a/lib/StandardOps/StandardOps.cpp b/lib/StandardOps/StandardOps.cpp
index 63c3e1e..f2c22f2 100644
--- a/lib/StandardOps/StandardOps.cpp
+++ b/lib/StandardOps/StandardOps.cpp
@@ -429,7 +429,7 @@
// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Builder *build, Type type) {
auto i1Type = build->getIntegerType(1);
- if (type.isa<IntegerType>() || type.isa<FloatType>())
+ if (type.isa<IntegerType>() || type.isa<FloatType>() || type.isa<IndexType>())
return i1Type;
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return build->getTensorType(tensorType.getShape(), i1Type);
@@ -458,7 +458,8 @@
// Checks if "type" has the same shape (scalar, vector or tensor) as "pattern"
// and contains i1.
static bool checkI1SameShape(Type pattern, Type type) {
- if (pattern.isa<IntegerType>() || pattern.isa<FloatType>())
+ if (pattern.isa<IntegerType>() || pattern.isa<FloatType>() ||
+ pattern.isa<IndexType>())
return !isI1(type);
if (auto patternTensorType = pattern.dyn_cast<TensorType>())
return implCheckI1SameShape(patternTensorType, type);
diff --git a/test/IR/core-ops.mlir b/test/IR/core-ops.mlir
index 68ba145..78b0c07 100644
--- a/test/IR/core-ops.mlir
+++ b/test/IR/core-ops.mlir
@@ -24,44 +24,50 @@
return
}
-// CHECK-LABEL: cfgfunc @standard_instrs(tensor<4x4x?xf32>, f32, i32) {
-cfgfunc @standard_instrs(tensor<4x4x?xf32>, f32, i32) {
-// CHECK: bb0(%arg0: tensor<4x4x?xf32>, %arg1: f32, %arg2: i32):
-bb42(%t: tensor<4x4x?xf32>, %f: f32, %i: i32):
+// CHECK-LABEL: cfgfunc @standard_instrs(tensor<4x4x?xf32>, f32, i32, index) {
+cfgfunc @standard_instrs(tensor<4x4x?xf32>, f32, i32, index) {
+// CHECK: bb0(%arg0: tensor<4x4x?xf32>, %arg1: f32, %arg2: i32, %arg3: index):
+bb42(%t: tensor<4x4x?xf32>, %f: f32, %i: i32, %idx : index):
// CHECK: %0 = dim %arg0, 2 : tensor<4x4x?xf32>
%a = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> index
// CHECK: %1 = dim %arg0, 2 : tensor<4x4x?xf32>
%a2 = dim %t, 2 : tensor<4x4x?xf32>
-
+
// CHECK: %2 = addf %arg1, %arg1 : f32
%f2 = "addf"(%f, %f) : (f32,f32) -> f32
// CHECK: %3 = addf %2, %2 : f32
%f3 = addf %f2, %f2 : f32
-
+
// CHECK: %4 = addi %arg2, %arg2 : i32
%i2 = "addi"(%i, %i) : (i32,i32) -> i32
// CHECK: %5 = addi %4, %4 : i32
%i3 = addi %i2, %i2 : i32
-
- // CHECK: %6 = subf %arg1, %arg1 : f32
+
+ // CHECK: %{{[0-9]+}} = addi %arg3, %arg3 : index
+ %idx1 = addi %idx, %idx : index
+
+ // CHECK: %{{[0-9]+}} = addi %arg3, %{{[0-9]+}} : index
+ %idx2 = "addi"(%idx, %idx1) : (index, index) -> index
+
+ // CHECK: %8 = subf %arg1, %arg1 : f32
%f4 = "subf"(%f, %f) : (f32,f32) -> f32
- // CHECK: %7 = subf %6, %6 : f32
+ // CHECK: %9 = subf %8, %8 : f32
%f5 = subf %f4, %f4 : f32
-
- // CHECK: %8 = subi %arg2, %arg2 : i32
+
+ // CHECK: %10 = subi %arg2, %arg2 : i32
%i4 = "subi"(%i, %i) : (i32,i32) -> i32
- // CHECK: %9 = subi %8, %8 : i32
+ // CHECK: %11 = subi %10, %10 : i32
%i5 = subi %i4, %i4 : i32
-
- // CHECK: %10 = mulf %2, %2 : f32
+
+ // CHECK: %12 = mulf %2, %2 : f32
%f6 = mulf %f2, %f2 : f32
-
- // CHECK: %11 = muli %4, %4 : i32
+
+ // CHECK: %13 = muli %4, %4 : i32
%i6 = muli %i2, %i2 : i32
// CHECK: %c42_i32 = constant 42 : i32
@@ -88,6 +94,9 @@
// CHECK: %cst_3 = constant splat<vector<4xi32>, 0> : vector<4xi32>
%13 = constant splat<vector<4 x i32>, 0> : vector<4 x i32>
+ // CHECK: %cst_4 = constant splat<tensor<42xindex>, 0> : tensor<42xindex>
+ %tidx = constant splat<tensor<42 x index>, 0> : tensor<42 x index>
+
// CHECK: %{{[0-9]+}} = cmpi "eq", %{{[0-9]+}}, %{{[0-9]+}} : i32
%14 = cmpi "eq", %i3, %i4 : i32
@@ -101,6 +110,12 @@
// CHECK: %{{[0-9]+}} = cmpi "ne", %cst_3, %cst_3 : vector<4xi32>
%17 = "cmpi"(%13, %13) {predicate: 1} : (vector<4 x i32>, vector<4 x i32>) -> vector<4 x i1>
+ // CHECK: %{{[0-9]+}} = cmpi "slt", %arg3, %arg3 : index
+ %18 = cmpi "slt", %idx, %idx : index
+
+ // CHECK: %{{[0-9]+}} = cmpi "eq", %cst_4, %cst_4 : tensor<42xindex>
+ %19 = cmpi "eq", %tidx, %tidx : tensor<42 x index>
+
return
}
diff --git a/test/IR/invalid-ops.mlir b/test/IR/invalid-ops.mlir
index c799c98..5550f07 100644
--- a/test/IR/invalid-ops.mlir
+++ b/test/IR/invalid-ops.mlir
@@ -195,7 +195,7 @@
// Integer comparisons are not recognized for float types.
cfgfunc @cfgfunc_with_ops(f32, f32) {
bb0(%a : f32, %b : f32):
- %r = cmpi "eq", %a, %b : f32 // expected-error {{op requires an integer type}}
+ %r = cmpi "eq", %a, %b : f32 // expected-error {{op requires an integer or index type}}
}
// -----