Add support for building a DenseIntElementsAttr with ArrayRef<int64_t> values.
PiperOrigin-RevId: 239616595
diff --git a/include/mlir/IR/Attributes.h b/include/mlir/IR/Attributes.h
index 09ddd46..3169364 100644
--- a/include/mlir/IR/Attributes.h
+++ b/include/mlir/IR/Attributes.h
@@ -428,12 +428,18 @@
using DenseElementsAttr::getValues;
using DenseElementsAttr::ImplType;
- // Constructs a dense integer elements attribute from an array of APInt
- // values. Each APInt value is expected to have the same bitwidth as the
- // element type of 'type'.
+ /// Constructs a dense integer elements attribute from an array of APInt
+ /// values. Each APInt value is expected to have the same bitwidth as the
+ /// element type of 'type'.
static DenseIntElementsAttr get(VectorOrTensorType type,
ArrayRef<APInt> values);
+ /// Constructs a dense integer elements attribute from an array of integer
+ /// values. Each value is expected to be within the bitwidth of the element
+ /// type of 'type'.
+ static DenseIntElementsAttr get(VectorOrTensorType type,
+ ArrayRef<int64_t> values);
+
/// Gets the integer value of each of the dense elements.
void getValues(SmallVectorImpl<APInt> &values) const;
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 78eb7a5..b3aba25 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -114,6 +114,8 @@
ArrayRef<char> data);
ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
ArrayRef<Attribute> values);
+ ElementsAttr getDenseIntElementsAttr(VectorOrTensorType type,
+ ArrayRef<int64_t> values);
ElementsAttr getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values);
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
index d64f9c4..6da501f 100644
--- a/lib/IR/Attributes.cpp
+++ b/lib/IR/Attributes.cpp
@@ -356,14 +356,30 @@
/// DenseIntElementsAttr
-// Constructs a dense integer elements attribute from an array of APInt
-// values. Each APInt value is expected to have the same bitwidth as the
-// element type of 'type'.
+/// Constructs a dense integer elements attribute from an array of APInt
+/// values. Each APInt value is expected to have the same bitwidth as the
+/// element type of 'type'.
DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
ArrayRef<APInt> values) {
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
}
+/// Constructs a dense integer elements attribute from an array of integer
+/// values. Each value is expected to be within the bitwidth of the element
+/// type of 'type'.
+DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
+ ArrayRef<int64_t> values) {
+ auto eltType = type.getElementType();
+ size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
+
+ // Convert the raw integer values to APInt.
+ SmallVector<APInt, 8> apIntValues;
+ apIntValues.reserve(values.size());
+ for (auto value : values)
+ apIntValues.emplace_back(APInt(bitWidth, value));
+ return get(type, apIntValues);
+}
+
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
// Simply return the raw integer values.
getRawValues(values);
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 6f1936b..917bae6 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -185,6 +185,11 @@
return DenseElementsAttr::get(type, values);
}
+ElementsAttr Builder::getDenseIntElementsAttr(VectorOrTensorType type,
+ ArrayRef<int64_t> values) {
+ return DenseIntElementsAttr::get(type, values);
+}
+
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {