Simplify TypeUniquer/AttributeUniquer to not require multiple overloads when constructing a new storage.

--

PiperOrigin-RevId: 246356767
diff --git a/include/mlir/IR/TypeSupport.h b/include/mlir/IR/TypeSupport.h
index f174d6d..fc25ea2 100644
--- a/include/mlir/IR/TypeSupport.h
+++ b/include/mlir/IR/TypeSupport.h
@@ -104,34 +104,14 @@
 // MLIRContext. This class manages all creation and uniquing of types.
 class TypeUniquer {
 public:
-  /// Get an uniqued instance of a type T. This overload is used for derived
-  /// types that have complex storage or uniquing constraints.
+  /// Get an uniqued instance of a type T.
   template <typename T, typename... Args>
-  static typename std::enable_if<
-      !std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
-  get(MLIRContext *ctx, unsigned kind, Args &&... args) {
-    // Lookup an instance of this complex storage type.
-    using ImplType = typename T::ImplType;
-    return ctx->getTypeUniquer().getComplex<ImplType>(
-        [&](ImplType *storage) {
-          storage->initializeDialect(lookupDialectForType<T>(ctx));
-        },
-        kind, std::forward<Args>(args)...);
-  }
-
-  /// Get an uniqued instance of a type T. This overload is used for derived
-  /// types that use the DefaultTypeStorage and thus need no additional storage
-  /// or uniquing.
-  template <typename T, typename... Args>
-  static typename std::enable_if<
-      std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
-  get(MLIRContext *ctx, unsigned kind) {
-    // Lookup an instance of this simple storage type.
-    return ctx->getTypeUniquer().getSimple<TypeStorage>(
+  static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
+    return ctx->getTypeUniquer().get<typename T::ImplType>(
         [&](TypeStorage *storage) {
           storage->initializeDialect(lookupDialectForType<T>(ctx));
         },
-        kind);
+        kind, std::forward<Args>(args)...);
   }
 
 private:
diff --git a/include/mlir/Support/StorageUniquer.h b/include/mlir/Support/StorageUniquer.h
index 2a9bb4a..5b408f3 100644
--- a/include/mlir/Support/StorageUniquer.h
+++ b/include/mlir/Support/StorageUniquer.h
@@ -122,11 +122,12 @@
   /// that can be used to initialize a newly inserted storage instance. This
   /// function is used for derived types that have complex storage or uniquing
   /// constraints.
-  template <typename Storage, typename... Args>
-  Storage *getComplex(std::function<void(Storage *)> initFn, unsigned kind,
-                      Args &&... args) {
+  template <typename Storage, typename Arg, typename... Args>
+  Storage *get(std::function<void(Storage *)> initFn, unsigned kind, Arg &&arg,
+               Args &&... args) {
     // Construct a value of the derived key type.
-    auto derivedKey = getKey<Storage>(args...);
+    auto derivedKey =
+        getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
 
     // Create a hash of the kind and the derived key.
     unsigned hashValue = getHash<Storage>(kind, derivedKey);
@@ -155,7 +156,7 @@
   /// function is used for derived types that use no additional storage or
   /// uniquing outside of the kind.
   template <typename Storage>
-  Storage *getSimple(std::function<void(Storage *)> initFn, unsigned kind) {
+  Storage *get(std::function<void(Storage *)> initFn, unsigned kind) {
     auto ctorFn = [&](StorageAllocator &allocator) {
       auto *storage = new (allocator.allocate<Storage>()) Storage();
       if (initFn)
@@ -167,10 +168,11 @@
 
   /// Erases a uniqued instance of 'Storage'. This function is used for derived
   /// types that have complex storage or uniquing constraints.
-  template <typename Storage, typename... Args>
-  void eraseComplex(unsigned kind, Args &&... args) {
+  template <typename Storage, typename Arg, typename... Args>
+  void erase(unsigned kind, Arg &&arg, Args &&... args) {
     // Construct a value of the derived key type.
-    auto derivedKey = getKey<Storage>(args...);
+    auto derivedKey =
+        getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
 
     // Create a hash of the kind and the derived key.
     unsigned hashValue = getHash<Storage>(kind, derivedKey);
diff --git a/lib/IR/AttributeDetail.h b/lib/IR/AttributeDetail.h
index 3db5ec8..bb4d8de 100644
--- a/lib/IR/AttributeDetail.h
+++ b/lib/IR/AttributeDetail.h
@@ -67,45 +67,23 @@
 // MLIRContext. This class manages all creation and uniquing of attributes.
 class AttributeUniquer {
 public:
-  /// Get an uniqued instance of attribute T. This overload is used for
-  /// derived attributes that have complex storage or uniquing constraints.
+  /// Get an uniqued instance of attribute T.
   template <typename T, typename... Args>
-  static typename std::enable_if<
-      !std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
-  get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
-    return ctx->getAttributeUniquer().getComplex<typename T::ImplType>(
-        getInitFn(ctx), static_cast<unsigned>(kind),
-        std::forward<Args>(args)...);
-  }
-
-  /// Get an uniqued instance of attribute T. This overload is used for
-  /// derived attributes that use the AttributeStorage directly and thus need no
-  /// additional storage or uniquing.
-  template <typename T, typename... Args>
-  static typename std::enable_if<
-      std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
-  get(MLIRContext *ctx, Attribute::Kind kind) {
-    return ctx->getAttributeUniquer().getSimple<AttributeStorage>(
-        getInitFn(ctx), static_cast<unsigned>(kind));
-  }
-
-  /// Erase a uniqued instance of attribute T. This overload is used for
-  /// derived attributes that have complex storage or uniquing constraints.
-  template <typename T, typename... Args>
-  static typename std::enable_if<
-      !std::is_same<typename T::ImplType, AttributeStorage>::value>::type
-  erase(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
-    return ctx->getAttributeUniquer().eraseComplex<typename T::ImplType>(
+  static T get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
+    return ctx->getAttributeUniquer().get<typename T::ImplType>(
+        [ctx](AttributeStorage *storage) {
+          // If the attribute did not provide a type, then default to NoneType.
+          if (!storage->getType())
+            storage->setType(NoneType::get(ctx));
+        },
         static_cast<unsigned>(kind), std::forward<Args>(args)...);
   }
 
-  /// Generate a functor to initialize a new attribute storage instance.
-  static std::function<void(AttributeStorage *)> getInitFn(MLIRContext *ctx) {
-    return [ctx](AttributeStorage *storage) {
-      // If the attribute did not provide a type, then default to NoneType.
-      if (!storage->getType())
-        storage->setType(NoneType::get(ctx));
-    };
+  /// Erase a uniqued instance of attribute T.
+  template <typename T, typename... Args>
+  static void erase(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
+    return ctx->getAttributeUniquer().erase<typename T::ImplType>(
+        static_cast<unsigned>(kind), std::forward<Args>(args)...);
   }
 };