[mlir][vector][memref] Add `alignment` attribute to memory access ops (#144344)
Alignment information is important to allow LLVM backends such as AMDGPU
to select wide memory accesses (e.g., dwordx4 or b128). Since this info
is not always inferable, it's better to inform LLVM backends explicitly
about it. Furthermore, alignment is not necessarily a property of the
element type, but of each individual memory access op (we can have
overaligned and underaligned accesses compared to the natural/preferred
alignment of the element type).
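
As a rough illustration of why alignment belongs to the access rather than to the element type, the sketch below measures the strongest alignment satisfied by individual element addresses inside an overaligned buffer. The names `alignmentOf` and `Buffer` are hypothetical and exist only for this example; nothing here is part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: returns the largest power of two that divides the
// address, i.e. the strongest alignment this particular access satisfies.
inline std::uintptr_t alignmentOf(const void *ptr) {
  auto address = reinterpret_cast<std::uintptr_t>(ptr);
  assert(address != 0 && "a null address has no meaningful alignment");
  // Isolate the lowest set bit of the address.
  return address & (~address + 1);
}

// Hypothetical buffer whose storage is overaligned relative to `float`.
struct alignas(16) Buffer {
  float data[8];
};
```

With a 4-byte `float`, an access to `data[0]` or `data[4]` is at least 16-byte aligned, while an access to `data[1]` is only 4-byte aligned, even though all three share the same element type.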
This patch introduces an optional `alignment` attribute on the
memref.load/store and vector.load/store ops.
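
The rule the new attribute enforces, and the way the new builders treat an `alignment` argument of 0, can be sketched in standalone C++ as follows. This is a simplified model, not the generated code: `isValidAlignment` and `toAlignmentAttr` are hypothetical names, and `std::optional` stands in for the optional `I64Attr`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Models the combined `IntPositive` + `IntPowerOf2` constraint:
// the value must be strictly positive and have exactly one bit set.
constexpr bool isValidAlignment(int64_t value) {
  return value > 0 && (value & (value - 1)) == 0;
}

// Hypothetical model of the builders in this patch: an `alignment` argument
// of 0 means "attach no attribute". In the patch itself invalid values are
// rejected by the op verifier; the assert here is only for illustration.
inline std::optional<int64_t> toAlignmentAttr(uint64_t alignment) {
  if (alignment == 0)
    return std::nullopt;
  assert(isValidAlignment(static_cast<int64_t>(alignment)) &&
         "alignment must be a positive power of two");
  return static_cast<int64_t>(alignment);
}

// Same values as the test cases in the patch: 16 is accepted, 3 and -1 are not.
static_assert(isValidAlignment(16));
static_assert(!isValidAlignment(3));
static_assert(!isValidAlignment(-1));
static_assert(!isValidAlignment(0));
```

Note that under this model 0 is never a valid attribute value; it is only a builder-level sentinel for omitting the attribute.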
Follow-up PRs will:
1. Propagate the attribute to LLVM/SPIR-V.
2. Introduce the `alignment` attribute to other vector memory access ops:
   - vector.gather + vector.scatter
   - vector.transfer_read + vector.transfer_write
   - vector.compressstore + vector.expandload
   - vector.maskedload + vector.maskedstore
3. Replace `--convert-vector-to-llvm='use-vector-alignment=1'` with a
   simple pass to populate alignment attributes based on the vector
   types.

diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md
index b3bde05..2225329 100644
--- a/mlir/docs/DefiningDialects/Operations.md
+++ b/mlir/docs/DefiningDialects/Operations.md
@@ -306,6 +306,8 @@
* `IntPositive`: Specifying an integer attribute whose value is positive
* `IntNonNegative`: Specifying an integer attribute whose value is
non-negative
+* `IntPowerOf2`: Specifying an integer attribute whose value is a power of
+ two > 0
* `ArrayMinCount<N>`: Specifying an array attribute to have at least `N`
elements
* `ArrayMaxCount<N>`: Specifying an array attribute to have at most `N`
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 09bb393..9321089a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1216,6 +1216,11 @@
be reused in the cache. For details, refer to the
[https://llvm.org/docs/LangRef.html#load-instruction](LLVM load instruction).
+ An optional `alignment` attribute allows to specify the byte alignment of the
+ load operation. It must be a positive power of 2. The operation must access
+ memory at an address aligned to this boundary. Violations may lead to
+ architecture-specific faults or performance penalties.
+ A value of 0 indicates no specific alignment requirement.
Example:
```mlir
@@ -1226,7 +1231,39 @@
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
Variadic<Index>:$indices,
- DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+ ConfinedAttr<OptionalAttr<I64Attr>,
+ [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+
+ let builders = [
+ OpBuilder<(ins "Value":$memref,
+ "ValueRange":$indices,
+ CArg<"bool", "false">:$nontemporal,
+ CArg<"uint64_t", "0">:$alignment), [{
+ return build($_builder, $_state, memref, indices, nontemporal,
+ alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+ nullptr);
+ }]>,
+ OpBuilder<(ins "Type":$resultType,
+ "Value":$memref,
+ "ValueRange":$indices,
+ CArg<"bool", "false">:$nontemporal,
+ CArg<"uint64_t", "0">:$alignment), [{
+ return build($_builder, $_state, resultType, memref, indices, nontemporal,
+ alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+ nullptr);
+ }]>,
+ OpBuilder<(ins "TypeRange":$resultTypes,
+ "Value":$memref,
+ "ValueRange":$indices,
+ CArg<"bool", "false">:$nontemporal,
+ CArg<"uint64_t", "0">:$alignment), [{
+ return build($_builder, $_state, resultTypes, memref, indices, nontemporal,
+ alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+ nullptr);
+ }]>
+ ];
+
let results = (outs AnyType:$result);
let extraClassDeclaration = [{
@@ -1912,6 +1949,11 @@
be reused in the cache. For details, refer to the
[https://llvm.org/docs/LangRef.html#store-instruction](LLVM store instruction).
+ An optional `alignment` attribute allows to specify the byte alignment of the
+ store operation. It must be a positive power of 2. The operation must access
+ memory at an address aligned to this boundary. Violations may lead to
+ architecture-specific faults or performance penalties.
+ A value of 0 indicates no specific alignment requirement.
Example:
```mlir
@@ -1923,13 +1965,25 @@
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref,
Variadic<Index>:$indices,
- DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+ ConfinedAttr<OptionalAttr<I64Attr>,
+ [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
let builders = [
+ OpBuilder<(ins "Value":$valueToStore,
+ "Value":$memref,
+ "ValueRange":$indices,
+ CArg<"bool", "false">:$nontemporal,
+ CArg<"uint64_t", "0">:$alignment), [{
+ return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
+ alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+ nullptr);
+ }]>,
OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
$_state.addOperands(valueToStore);
$_state.addOperands(memref);
- }]>];
+ }]>
+ ];
let extraClassDeclaration = [{
Value getValueToStore() { return getOperand(0); }
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cbe490f..e07188a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1809,12 +1809,42 @@
```mlir
%result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
```
+
+ An optional `alignment` attribute allows to specify the byte alignment of the
+ load operation. It must be a positive power of 2. The operation must access
+ memory at an address aligned to this boundary. Violations may lead to
+ architecture-specific faults or performance penalties.
+ A value of 0 indicates no specific alignment requirement.
}];
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$base,
Variadic<Index>:$indices,
- DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+ ConfinedAttr<OptionalAttr<I64Attr>,
+ [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+
+ let builders = [
+ OpBuilder<(ins "VectorType":$resultType,
+ "Value":$base,
+ "ValueRange":$indices,
+ CArg<"bool", "false">:$nontemporal,
+ CArg<"uint64_t", "0">:$alignment), [{
+ return build($_builder, $_state, resultType, base, indices, nontemporal,
+ alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+ nullptr);
+ }]>,
+ OpBuilder<(ins "TypeRange":$resultTypes,
+ "Value":$base,
+ "ValueRange":$indices,
+ CArg<"bool", "false">:$nontemporal,
+ CArg<"uint64_t", "0">:$alignment), [{
+ return build($_builder, $_state, resultTypes, base, indices, nontemporal,
+ alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+ nullptr);
+ }]>
+ ];
+
let results = (outs AnyVectorOfAnyRank:$result);
let extraClassDeclaration = [{
@@ -1895,6 +1925,12 @@
```mlir
vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
```
+
+ An optional `alignment` attribute allows to specify the byte alignment of the
+ store operation. It must be a positive power of 2. The operation must access
+ memory at an address aligned to this boundary. Violations may lead to
+ architecture-specific faults or performance penalties.
+ A value of 0 indicates no specific alignment requirement.
}];
let arguments = (ins
@@ -1902,8 +1938,21 @@
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$base,
Variadic<Index>:$indices,
- DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
- );
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+ ConfinedAttr<OptionalAttr<I64Attr>,
+ [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+
+ let builders = [
+ OpBuilder<(ins "Value":$valueToStore,
+ "Value":$base,
+ "ValueRange":$indices,
+ CArg<"bool", "false">:$nontemporal,
+ CArg<"uint64_t", "0">:$alignment), [{
+ return build($_builder, $_state, valueToStore, base, indices, nontemporal,
+ alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+ nullptr);
+ }]>
+ ];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index e91a13f..18da85a 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -796,6 +796,10 @@
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isStrictlyPositive()">,
"whose value is positive">;
+def IntPowerOf2 : AttrConstraint<
+ CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">,
+ "whose value is a power of two > 0">;
+
class ArrayMaxCount<int n> : AttrConstraint<
CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
"with at most " # n # " elements">;
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 704cdaf..fa803ef 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -962,6 +962,24 @@
// -----
+func.func @invalid_load_alignment(%memref: memref<4xi32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error @below {{'memref.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
+ return
+}
+
+// -----
+
+func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: i32) {
+ %c0 = arith.constant 0 : index
+ // expected-error @below {{'memref.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ memref.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>
+ return
+}
+
+// -----
+
func.func @test_alloc_memref_map_rank_mismatch() {
^bb0:
// expected-error@+1 {{memref layout mismatch between rank and affine map: 2 != 1}}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index e11de7b..6c2298a 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -265,6 +265,17 @@
// CHECK: memref.store %{{.*}}, %{{.*}}[] : memref<i32>
}
+
+// CHECK-LABEL: func @load_store_alignment
+func.func @load_store_alignment(%memref: memref<4xi32>) {
+ %c0 = arith.constant 0 : index
+ // CHECK: memref.load {{.*}} {alignment = 16 : i64}
+ %val = memref.load %memref[%c0] { alignment = 16 } : memref<4xi32>
+ // CHECK: memref.store {{.*}} {alignment = 16 : i64}
+ memref.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>
+ return
+}
+
// CHECK-LABEL: func @memref_view(%arg0
func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = memref.alloc() : memref<2048xi8>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5038646..8017140 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1995,6 +1995,15 @@
// -----
+func.func @invalid_load_alignment(%memref: memref<4xi32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//
@@ -2005,3 +2014,12 @@
vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8>
return
}
+
+// -----
+
+func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
+ return
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 10bf0f1..39578ac 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -853,6 +853,16 @@
return
}
+// CHECK-LABEL: func @load_store_alignment
+func.func @load_store_alignment(%memref: memref<4xi32>) {
+ %c0 = arith.constant 0 : index
+ // CHECK: vector.load {{.*}} {alignment = 16 : i64}
+ %val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
+ // CHECK: vector.store {{.*}} {alignment = 16 : i64}
+ vector.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
+ return
+}
+
// CHECK-LABEL: @masked_load_and_store
func.func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
%c0 = arith.constant 0 : index