Materialize IndexType in the API.
Previously, index (aka affint) type was hidden under OtherType in the type API.
We will need to identify and operate on values of index types in the upcoming
MLFunc->CFGFunc(->LLVM) lowering passes. Materialize index type into a
separate class and make it visible to LLVM RTTI hierarchy directly.
Practically, index is an integer type of unknown bit width and is accetable in
most places where regular integer types are. This is purely an API change that
does not affect the IR.
After IndexType is separated out from OtherType, the remaining "other types"
are, in fact, TF-specific types only. Further renaming may be of interest.
PiperOrigin-RevId: 220614026
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 01b04a2..780652d 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -73,7 +73,8 @@
FloatType getF32Type();
FloatType getF64Type();
- OtherType getIndexType();
+ IndexType getIndexType();
+
OtherType getTFControlType();
OtherType getTFStringType();
OtherType getTFResourceType();
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index bcdecef..8f3f069 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -25,6 +25,7 @@
namespace mlir {
class AffineMap;
class FloatType;
+class IndexType;
class IntegerType;
class Location;
class MLIRContext;
@@ -33,6 +34,7 @@
namespace detail {
class TypeStorage;
+class IndexTypeStorage;
class IntegerTypeStorage;
class FloatTypeStorage;
struct OtherTypeStorage;
@@ -66,7 +68,7 @@
TFString,
/// These are marker for the first and last 'other' type.
- FIRST_OTHER_TYPE = Index,
+ FIRST_OTHER_TYPE = TFControl,
LAST_OTHER_TYPE = TFString,
// Floating point.
@@ -138,12 +140,12 @@
unsigned getBitWidth() const;
// Convenience factories.
+ static IndexType getIndex(MLIRContext *ctx);
static IntegerType getInteger(unsigned width, MLIRContext *ctx);
static FloatType getBF16(MLIRContext *ctx);
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
- static OtherType getIndex(MLIRContext *ctx);
static OtherType getTFControl(MLIRContext *ctx);
static OtherType getTFString(MLIRContext *ctx);
static OtherType getTFResource(MLIRContext *ctx);
@@ -236,6 +238,21 @@
return FloatType::get(Kind::F64, ctx);
}
+/// Index is special integer-like type with unknown platform-dependent bit width
+/// used in subscripts and loop induction variables.
+class IndexType : public Type {
+public:
+ using ImplType = detail::IndexTypeStorage;
+ IndexType() = default;
+ /* implicit */ IndexType(Type::ImplType *ptr);
+
+ /// Crete an IndexType instance, unique in the given context.
+ static IndexType get(MLIRContext *context);
+
+ /// Support method to enable LLVM-style type casting.
+ static bool kindof(Kind kind) { return kind == Kind::Index; }
+};
+
/// This is a type for the random collection of special base types.
class OtherType : public Type {
public:
@@ -251,8 +268,8 @@
}
};
-inline OtherType Type::getIndex(MLIRContext *ctx) {
- return OtherType::get(Kind::Index, ctx);
+inline IndexType Type::getIndex(MLIRContext *ctx) {
+ return IndexType::get(ctx);
}
inline OtherType Type::getTFControl(MLIRContext *ctx) {
return OtherType::get(Kind::TFControl, ctx);
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 906b580..c143857 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -60,7 +60,7 @@
FloatType Builder::getF64Type() { return Type::getF64(context); }
-OtherType Builder::getIndexType() { return Type::getIndex(context); }
+IndexType Builder::getIndexType() { return Type::getIndex(context); }
OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 4148e19..a05a233 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -323,6 +323,9 @@
// Uniqui'ing of AffineConstantExprStorage using constant value as key.
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
+ /// Unique index type (lazily constructed).
+ IndexTypeStorage *indexType = nullptr;
+
/// Integer type uniquing.
DenseMap<unsigned, IntegerTypeStorage *> integers;
@@ -554,6 +557,17 @@
// Type uniquing
//===----------------------------------------------------------------------===//
+IndexType IndexType::get(MLIRContext *context) {
+ auto &impl = context->getImpl();
+
+ if (impl.indexType)
+ return impl.indexType;
+
+ impl.indexType = impl.allocator.Allocate<IndexTypeStorage>();
+ new (impl.indexType) IndexTypeStorage{{Kind::Index, context}};
+ return impl.indexType;
+}
+
IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
auto &impl = context->getImpl();
diff --git a/lib/IR/TypeDetail.h b/lib/IR/TypeDetail.h
index c22e87a..80a13e0 100644
--- a/lib/IR/TypeDetail.h
+++ b/lib/IR/TypeDetail.h
@@ -56,6 +56,8 @@
unsigned subclassData : 24;
};
+struct IndexTypeStorage : public TypeStorage {};
+
struct IntegerTypeStorage : public TypeStorage {
unsigned width;
};
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 1a71695..a01b1c0 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -53,6 +53,8 @@
unsigned Type::getSubclassData() const { return type->getSubclassData(); }
void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
+IndexType::IndexType(Type::ImplType *ptr) : Type(ptr) {}
+
IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
unsigned IntegerType::getWidth() const {