Revert "[ROCm] add hipblaslt support (#114329)"

This reverts commit b062ea38039234c80404a8f5f4d5a93c4cb9832d.

Reverted https://github.com/pytorch/pytorch/pull/114329 on behalf of https://github.com/jeanschmidt due to inconsistencies in the internal diff ([comment](https://github.com/pytorch/pytorch/pull/114329#issuecomment-1861933267))
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 24c96f4..c1a8426 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -11,9 +11,14 @@
 #include <c10/macros/Export.h>
 #include <c10/util/irange.h>
 
+// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
+// added bf16 support
+#if !defined(USE_ROCM) && !defined(_MSC_VER)
+#include <cublasLt.h>
+#endif
+
 #ifdef USE_ROCM
 // until hipblas has an API to accept flags, we must use rocblas here
-#include <hipblas/hipblas.h>
 #include <rocblas/rocblas.h>
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
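The `PYTORCH_ROCBLAS_VERSION_DECIMAL` macro above packs the rocBLAS major and minor versions into one comparable integer, so feature gates reduce to a single comparison. A minimal sketch of the arithmetic, independent of the rocBLAS headers:

```cpp
// rocBLAS 2.42 packs to 2 * 100 + 42 = 242, the threshold the FP16
// alternate-implementation GEMM flag is gated on above.
constexpr int packRocblasVersion(int major, int minor) {
  return major * 100 + minor;
}
static_assert(packRocblasVersion(2, 42) >= 242,
              "2.42 satisfies the alt-impl gate above");
```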
@@ -59,7 +64,6 @@
 // until we use hipblas v2
 // hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
 // however hipblas v1 is still using its custom type
-#ifndef HIPBLAS_V2
 #define HIP_R_16F  HIPBLAS_R_16F
 #define HIP_R_32F  HIPBLAS_R_32F
 #define HIP_R_64F  HIPBLAS_R_64F
@@ -77,7 +81,6 @@
 #define HIP_R_16BF HIPBLAS_R_16B
 #define HIP_C_16BF HIPBLAS_C_16B
 #endif
-#endif
 
 #define CUDABLAS_POSINT_CHECK(FD, X)         \
   TORCH_CHECK(                               \
@@ -164,7 +167,6 @@
   }
 }
 
-#ifndef USE_ROCM
 uint32_t _getAlignment(uintptr_t address) {
   // alignments are in bytes
   uint32_t alignment = 256;
@@ -174,25 +176,18 @@
     }
   }
 }
-#endif
 
 static size_t _parseChosenWorkspaceSize() {
   const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
-#ifdef USE_ROCM
-  if (!val) {
-    // accept either env var
-    val = getenv("HIPBLASLT_WORKSPACE_SIZE");
-  }
-#endif
   size_t workspace_size = 1024; /* default size in KiB according to #73328 */
   if (val) {
     try {
       workspace_size = std::stoi(val);
     } catch(std::invalid_argument const& e) {
-      TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,",
+      TORCH_WARN("invalid CUBLAS_LT_WORKSPACE_SIZE,",
                  " using default workspace size of ", workspace_size, " bytes.");
     } catch(std::out_of_range const& e) {
-      TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,",
+      TORCH_WARN("CUBLAS_LT_WORKSPACE_SIZE out of range,",
                  " using default workspace size of ", workspace_size, " bytes.");
     }
   }
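The restored parser reads the workspace size (in KiB, per the `#73328` default) from the single `CUBLASLT_WORKSPACE_SIZE` variable and keeps the default on malformed input rather than failing. A standalone sketch of the same parse-with-fallback pattern:

```cpp
#include <cstdlib>
#include <stdexcept>
#include <string>

// Read a size in KiB from CUBLASLT_WORKSPACE_SIZE; keep the 1024 KiB default
// on any parse failure (invalid or out-of-range input) rather than aborting.
static size_t parseWorkspaceSizeKiB() {
  size_t workspace_size = 1024;  // default, per the restored code above
  if (const char* val = std::getenv("CUBLASLT_WORKSPACE_SIZE")) {
    try {
      workspace_size = std::stoi(val);
    } catch (const std::exception&) {  // std::invalid_argument / std::out_of_range
      // fall through to the default
    }
  }
  return workspace_size;
}
```

Note that the restored warning strings say `CUBLAS_LT_WORKSPACE_SIZE` while the variable actually read is `CUBLASLT_WORKSPACE_SIZE`; the sketch uses the name that is parsed.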
@@ -346,19 +341,12 @@
   const float fbeta = beta;
   _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
-  auto compute_type = CUBLAS_COMPUTE_32F;
-#else
-  auto compute_type = CUDA_R_32F;
-#endif
   TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
                                   b, CUDA_R_16BF, (int)ldb, strideb,
                                   (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
-                                  (int)num_batches,
-                                  compute_type,
-                                  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                                  (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 }
 
 template <>
@@ -529,11 +517,6 @@
     cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
   }
 #endif
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
-  auto compute_type = CUBLAS_COMPUTE_32F;
-#else
-  auto compute_type = CUDA_R_32F;
-#endif
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
   TORCH_CUDABLAS_CHECK(cublasGemmEx(
       handle,
@@ -553,62 +536,12 @@
       c,
       CUDA_R_16BF,
       ldc,
-      compute_type,
+      CUDA_R_32F,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 }
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
-
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
-// only for rocm 5.7 where we first supported hipblaslt, it was difficult
-// to hipify correctly without this change.
-#define hipDataType hipblasDatatype_t
-#endif
-
-// hipblaslt custom types were a temporary work-around
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && HIPBLASLT_CUSTOM_DATA_TYPE
-hipblasltDatatype_t hipToLt(hipDataType type) {
-    switch (type) {
-        case HIP_R_32F: return HIPBLASLT_R_32F;
-        case HIP_R_64F: return HIPBLASLT_R_64F;
-        case HIP_R_16F: return HIPBLASLT_R_16F;
-        case HIP_R_8I: return HIPBLASLT_R_8I;
-        case HIP_C_32F: return HIPBLASLT_C_32F;
-        case HIP_C_64F: return HIPBLASLT_C_64F;
-        case HIP_C_16F: return HIPBLASLT_C_16F;
-        case HIP_C_8I: return HIPBLASLT_C_8I;
-        case HIP_R_8U: return HIPBLASLT_R_8U;
-        case HIP_C_8U: return HIPBLASLT_C_8U;
-        case HIP_R_32I: return HIPBLASLT_R_32I;
-        case HIP_C_32I: return HIPBLASLT_C_32I;
-        case HIP_R_32U: return HIPBLASLT_R_32U;
-        case HIP_C_32U: return HIPBLASLT_C_32U;
-        case HIP_R_16BF: return HIPBLASLT_R_16B;
-        case HIP_C_16BF: return HIPBLASLT_C_16B;
-        default: TORCH_CHECK(false);
-    }
-}
-#define HIPTOLT(type) hipToLt(type)
-#else
-#define HIPTOLT(type) type
-#endif
-
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && HIPBLASLT_CUSTOM_COMPUTE_TYPE
-hipblasLtComputeType_t hipblasToLt(hipblasComputeType_t type) {
-    switch (type) {
-        case HIPBLAS_COMPUTE_32F: return HIPBLASLT_COMPUTE_F32;
-        case HIPBLAS_COMPUTE_32F_FAST_16F: return HIPBLASLT_COMPUTE_F32_FAST_F16;
-        case HIPBLAS_COMPUTE_32F_FAST_TF32: return HIPBLASLT_COMPUTE_F32_FAST_XF32;
-        case HIPBLAS_COMPUTE_64F: return HIPBLASLT_COMPUTE_F64;
-        case HIPBLAS_COMPUTE_32I: return HIPBLASLT_COMPUTE_I32;
-        default: TORCH_CHECK(false);
-    }
-}
-#define HIPCOMPTOLT(type) hipblasToLt(type)
-#else
-#define HIPCOMPTOLT(type) type
-#endif
+#if !defined(USE_ROCM) && !defined(_MSC_VER)
 
 namespace {
 // Following the pattern of CuSparseDescriptor
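The `HIPTOLT`/`HIPCOMPTOLT` shim deleted in the hunk above mapped HIP's public type enums into hipblaslt's temporary private enums, collapsing to the identity once the library accepts the public enums directly, so call sites never change. A self-contained sketch of that pattern; the enum and macro names here are hypothetical stand-ins, not hipblaslt API:

```cpp
#include <stdexcept>

enum PublicType { PUBLIC_R_32F, PUBLIC_R_16F };       // stand-in public enum
enum InternalType { INTERNAL_R_32F, INTERNAL_R_16F }; // stand-in private enum

#if defined(NEEDS_PRIVATE_ENUMS)  // hypothetical feature-test macro
// Library still consumes its private enum: map each supported value,
// hard-error on anything unmapped (mirroring the TORCH_CHECK(false) above).
inline InternalType mapType(PublicType t) {
  switch (t) {
    case PUBLIC_R_32F: return INTERNAL_R_32F;
    case PUBLIC_R_16F: return INTERNAL_R_16F;
  }
  throw std::logic_error("unmapped type");
}
#define MAP_TYPE(t) mapType(t)
#else
// Library accepts the public enum directly: the shim is the identity.
#define MAP_TYPE(t) (t)
#endif
```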
@@ -647,7 +580,7 @@
       cudaDataType_t scale_type) {
     cublasLtMatmulDesc_t raw_descriptor = nullptr;
     TORCH_CUDABLAS_CHECK(
-        cublasLtMatmulDescCreate(&raw_descriptor, HIPCOMPTOLT(compute_type), HIPTOLT(scale_type)));
+        cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
     descriptor_.reset(raw_descriptor);
   }
   template <typename T>
@@ -668,7 +601,7 @@
       bool t = false) {
     cublasLtMatrixLayout_t raw_descriptor = nullptr;
     TORCH_CUDABLAS_CHECK(
-        cublasLtMatrixLayoutCreate(&raw_descriptor, HIPTOLT(type), t ? cols : rows, t ? rows : cols, ld));
+        cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
     descriptor_.reset(raw_descriptor);
   }
 };
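Both wrappers above follow the `CuSparseDescriptor` pattern: the raw cuBLASLt descriptor lives in a `unique_ptr` whose deleter is the matching `*Destroy` call, so every early return in the matmul path still releases it, and the layout constructor swaps row/column extents for transposed operands. A condensed sketch of that shape (status checking via `TORCH_CUDABLAS_CHECK` elided):

```cpp
#include <cublasLt.h>

#include <cstdint>
#include <memory>
#include <type_traits>

struct LtLayoutDeleter {
  void operator()(cublasLtMatrixLayout_t p) const {
    cublasLtMatrixLayoutDestroy(p);  // paired with Create below
  }
};

class LtMatrixLayout {
 public:
  LtMatrixLayout(cudaDataType_t type, uint64_t rows, uint64_t cols,
                 int64_t ld, bool t = false) {
    cublasLtMatrixLayout_t raw = nullptr;
    // A transposed operand swaps its row/column extents, as in the ctor above.
    cublasLtMatrixLayoutCreate(&raw, type, t ? cols : rows, t ? rows : cols, ld);
    desc_.reset(raw);
  }

 private:
  std::unique_ptr<std::remove_pointer_t<cublasLtMatrixLayout_t>, LtLayoutDeleter>
      desc_;
};
```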
@@ -712,19 +645,13 @@
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
   cudaDataType_t scaleType = CUDA_R_32F;
   if constexpr (std::is_same_v<Dtype, double>) {
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
     abcType = CUDA_R_64F;
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
-#else
-    TORCH_CHECK(false, "gemm_and_bias is only supported for double type on ROCm 6.0 and above");
-#endif
   } else if constexpr (std::is_same_v<Dtype, float>) {
-#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
-#endif
     abcType = CUDA_R_32F;
   } else if constexpr (std::is_same_v<Dtype, at::Half>) {
     abcType = CUDA_R_16F;
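The dispatch above selects the `(abcType, computeType, scaleType)` triple per `Dtype` at compile time; after the revert, the double branch is unconditional and the float branch may upgrade to TF32 tensor-core math when `allowTF32CuBLAS()` is set. A consolidated sketch of those two arms (the `at::Half`/`at::BFloat16` arms and the `globalContext()` plumbing elided):

```cpp
#include <cublasLt.h>
#include <type_traits>

template <typename Dtype>
void pickGemmTypes(cudaDataType_t& abcType, cublasComputeType_t& computeType,
                   cudaDataType_t& scaleType, bool allowTF32) {
  abcType = CUDA_R_32F;
  computeType = CUBLAS_COMPUTE_32F;
  scaleType = CUDA_R_32F;
  if constexpr (std::is_same_v<Dtype, double>) {
    abcType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;  // cuBLASLt wants alpha/beta in the scale type
  } else if constexpr (std::is_same_v<Dtype, float>) {
    if (allowTF32) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;  // TF32 tensor cores
    }
  }
}
```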
@@ -741,7 +668,7 @@
   if (activation == GEMMAndBiasActivationEpilogue::RELU) {
     epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
   } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
-#if CUDA_VERSION >= 11040 || defined(USE_ROCM)
+#if CUDA_VERSION >= 11040
     epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
 #endif
   }
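After the revert the GELU epilogue is again CUDA-only and gated on 11.4+, so on older toolkits the selection falls through to the plain bias epilogue. For reference, the per-element semantics of the three epilogues chosen above as a scalar sketch (cuBLASLt's GELU epilogue computes the tanh approximation):

```cpp
#include <cmath>

// D[i,j] = act(sum_k A[i,k]*B[k,j] + bias[i]); this sketch is just the
// per-element act(acc + bias) step for the three epilogues above.
double epilogueRef(double acc, double bias, bool relu, bool gelu) {
  double v = acc + bias;               // CUBLASLT_EPILOGUE_BIAS
  if (relu) return v > 0.0 ? v : 0.0;  // CUBLASLT_EPILOGUE_RELU_BIAS
  if (gelu) {                          // CUBLASLT_EPILOGUE_GELU_BIAS
    const double kAlpha = 0.7978845608028654;  // sqrt(2/pi)
    return 0.5 * v * (1.0 + std::tanh(kAlpha * (v + 0.044715 * v * v * v)));
  }
  return v;
}
```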
@@ -758,7 +685,6 @@
   size_t workspaceSize = _getWorkspaceSize();
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
 
-#ifndef USE_ROCM
   uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
   uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
   uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
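The `_getAlignment` probes above feed cuBLASLt's minimum-alignment preferences (restored in the next hunk) so the heuristic only proposes algorithms the actual pointers satisfy. The middle of `_getAlignment` falls outside the earlier hunk, so the loop below is the standard power-of-two reconstruction rather than a verbatim copy: report the largest power of two up to 256 that divides the address.

```cpp
#include <cstdint>

std::uint32_t alignmentOf(std::uintptr_t address) {
  std::uint32_t alignment = 256;  // max alignment cuBLASLt cares about, in bytes
  for (; alignment > 1; alignment /= 2) {
    if (address % alignment == 0) {
      break;  // largest power-of-two divisor found
    }
  }
  return alignment;
}
```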
@@ -767,14 +693,14 @@
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
-#endif
 
   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
   auto workspace = allocator.allocate(workspaceSize);
 
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
-  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
+  cublasLtHandle_t ltHandle =
+      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
       ltHandle,
       computeDesc.descriptor(),
@@ -950,7 +876,8 @@
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
-  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
+  cublasLtHandle_t ltHandle =
+      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
       ltHandle,
       computeDesc.descriptor(),
@@ -1025,7 +952,6 @@
     int64_t mat2_ld,
     int32_t* result_ptr,
     int64_t result_ld) {
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
 
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
   cudaDataType_t scaleType = CUDA_R_32I;
@@ -1044,7 +970,8 @@
   CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
   CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);
 
-  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
+  cublasLtHandle_t ltHandle =
+      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
 
   // cublas team: alpha and beta need to be the same dtype as of scaleType
   at::opmath_type<int32_t> alpha_val = 1;
@@ -1095,14 +1022,11 @@
       computeType,
       " scaleType ",
       scaleType);
-#else
-  TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
-#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
 }
-#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#endif // !defined(USE_ROCM) && !defined(_MSC_VER)
 
 // ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
-#if defined(USE_ROCM) && ROCM_VERSION <= 50600
+#if defined(USE_ROCM) && ROCM_VERSION <= 56000
 #define ROCM_CONST_BUG
 #else
 #define ROCM_CONST_BUG const
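A reading aid for the `ROCM_VERSION` gates above, assuming the usual `major * 10000 + minor * 100 + patch` packing that cmake injects (see the `-DROCM_VERSION=${ROCM_VERSION_DEV_INT}` definition in `Dependencies.cmake` below): ROCm 5.6.0 packs to 50600, so the restored `ROCM_VERSION <= 56000` bound also matches everything up to a hypothetical 5.60.00, not just 5.6 and earlier.

```cpp
constexpr int packRocmVersion(int major, int minor, int patch) {
  return major * 10000 + minor * 100 + patch;
}
static_assert(packRocmVersion(5, 6, 0) == 50600, "ROCm 5.6.0");
static_assert(packRocmVersion(5, 6, 0) <= 56000, "matches the restored gate");
static_assert(packRocmVersion(6, 0, 0) == 60000, "ROCm 6.0.0");
```

The same reading applies to the `ROCM_VERSION <= 55000` gate restored in `CUDABlas.h` below (5.5.0 packs to 50500).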
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index ee3b41b..52acf9a 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -62,7 +62,7 @@
 template <>
 void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if !defined(USE_ROCM) && !defined(_MSC_VER)
 enum GEMMAndBiasActivationEpilogue {
   None,
   RELU,
@@ -149,7 +149,7 @@
 template <>
 void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
 
-#if defined(USE_ROCM) && ROCM_VERSION <= 50500
+#if defined(USE_ROCM) && ROCM_VERSION <= 55000
 // ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
 #define CUDABLAS_TRSM_ARGTYPES(Dtype)                                  \
   hipblasHandle_t handle, hipblasSideMode_t side, hipblasFillMode_t uplo, \
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
index c189ed2..29e79c2 100644
--- a/aten/src/ATen/cuda/CUDAContextLight.h
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
@@ -7,12 +7,6 @@
 #include <cusparse.h>
 #include <cublas_v2.h>
 
-// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
-// added bf16 support
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
-#include <cublasLt.h>
-#endif
-
 #ifdef CUDART_VERSION
 #include <cusolverDn.h>
 #endif
@@ -82,9 +76,6 @@
 /* Handles */
 TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
 TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
-TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
-#endif
 
 TORCH_CUDA_CPP_API void clearCublasWorkspaces();
 
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index 05c2ba9..dae61a4 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -9,46 +9,10 @@
 #include <string>
 #include <tuple>
 
-/**
- * Note [hipblaslt handles]
- * ~~~~~~~~~~~~~~~~~~~~~~~~
- * The cublas documentation states:
- * cuBLAS handle (cublasHandle_t) encapsulates a cuBLASLt handle.
- * Any valid cublasHandle_t can be used in place of cublasLtHandle_t with a simple cast.
- *
- * hipblaslt does not behave in this way.
- * A hipblas handle does not encapsulate a hipblaslt handle.
- *
- * To work around this difference in behavior, a separate handle pool is available for ROCm builds.
- * For CUDA builds, getCurrentCUDABlasLtHandle will alias for getCurrentCUDABlasHandle,
- * whereas for ROCm builds, it is a distinct function.
- */
-
 namespace at::cuda {
 
 namespace {
 
-#ifdef USE_ROCM
-void createCublasLtHandle(cublasLtHandle_t *handle) {
-  TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
-}
-
-void destroyCublasLtHandle(cublasLtHandle_t handle) {
-// this is because of something dumb in the ordering of
-// destruction. Sometimes atexit, the cuda context (or something)
-// would already be destroyed by the time this gets destroyed. It
-// happens in fbcode setting. @colesbury and @soumith decided to not destroy
-// the handle as a workaround.
-//   - Comments of @soumith copied from cuDNN handle pool implementation
-#ifdef NO_CUDNN_DESTROY_HANDLE
-#else
-    cublasLtDestroy(handle);
-#endif
-}
-
-using CuBlasLtPoolType = DeviceThreadHandlePool<cublasLtHandle_t, createCublasLtHandle, destroyCublasLtHandle>;
-#endif
-
 std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
   static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
   return instance;
@@ -177,33 +141,4 @@
   return handle;
 }
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
-cublasLtHandle_t getCurrentCUDABlasLtHandle() {
-#ifdef USE_ROCM
-  int device;
-  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
-
-  // Thread local PoolWindows are lazily-initialized
-  // to avoid initialization issues that caused hangs on Windows.
-  // See: https://github.com/pytorch/pytorch/pull/22405
-  // This thread local unique_ptrs will be destroyed when the thread terminates,
-  // releasing its reserved handles back to the pool.
-
-  // Use a leaky singleton for the pool following standard practice around
-  // singletons: https://isocpp.org/wiki/faq/ctors#construct-on-first-use-v2
-  static auto pool = std::shared_ptr<CuBlasLtPoolType>(
-      new CuBlasLtPoolType(), [](CuBlasLtPoolType* p) {
-        // Leak the memory.
-      });
-  thread_local std::unique_ptr<CuBlasLtPoolType::PoolWindow> myPoolWindow(
-      pool->newPoolWindow());
-
-  auto handle = myPoolWindow->reserve(device);
-  return handle;
-#else
-  return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
-#endif
-}
-#endif
-
 } // namespace at::cuda
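With the ROCm-specific pool gone, the CUDA-only aliasing that the deleted Note quoted from the cuBLAS documentation is all that remains: a valid `cublasHandle_t` can stand in for a `cublasLtHandle_t` with a simple cast, which is exactly what the `reinterpret_cast` call sites restored in `CUDABlas.cpp` above do. A minimal sketch:

```cpp
#include <cublasLt.h>
#include <cublas_v2.h>

// Per the cuBLAS docs quoted in the removed Note: a cublasHandle_t
// encapsulates a cublasLtHandle_t, so this cast is well-defined for CUDA
// builds (hipblaslt did not offer this guarantee, hence the separate pool).
cublasLtHandle_t ltFromBlas(cublasHandle_t handle) {
  return reinterpret_cast<cublasLtHandle_t>(handle);
}
```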
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 35a2477..e5163c3 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -153,7 +153,7 @@
   GELU,
 };
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if !defined(USE_ROCM) && !defined(_MSC_VER)
 cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
   switch (a) {
     case Activation::None:
@@ -171,40 +171,12 @@
 
 static bool getDisableAddmmCudaLt() {
     static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
-#ifdef USE_ROCM
-    // allow both CUDA and HIP env var names for ROCm builds
-    // also, current default for ROCm builds is disable by default
-    if (env_value == nullptr) {
-        env_value = std::getenv("DISABLE_ADDMM_HIP_LT");
-    }
-    if (env_value != nullptr && strcmp(env_value, "0") == 0) {
-      return false;
-    }
-    return true;
-#else
     if (env_value != nullptr && strcmp(env_value, "1") == 0) {
       return true;
     }
     return false;
-#endif
 }
 
-#ifdef USE_ROCM
-static bool isSupportedHipLtROCmArch(int index) {
-    hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
-    std::string device_arch = prop->gcnArchName;
-    static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
-    for (std::string arch : archs) {
-        size_t substring = device_arch.find(arch);
-        if (substring != std::string::npos) {
-            return true;
-        }
-    }
-    TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
-    return false;
-}
-#endif
-
 Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) {
   // Make sure to keep addmm_cuda below in sync with this code; it
   // preflights a check to try to avoid actually needing to call
@@ -226,7 +198,7 @@
   at::ScalarType scalar_type = self.scalar_type();
   c10::MaybeOwned<Tensor> self_;
   if (&result != &self) {
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)
     // Strangely, if mat2 has only 1 row or column, we get
     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -239,17 +211,10 @@
       useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
           result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
           self.is_contiguous() && result.is_contiguous() &&
-#ifdef USE_ROCM
-          isSupportedHipLtROCmArch(self.device().index()) &&
-          (scalar_type == at::ScalarType::Float ||
-           scalar_type == at::ScalarType::Half ||
-           scalar_type == at::ScalarType::BFloat16) &&
-#else
           (scalar_type == at::ScalarType::Double ||
            scalar_type == at::ScalarType::Float ||
            scalar_type == at::ScalarType::Half ||
            scalar_type == at::ScalarType::BFloat16) &&
-#endif
           mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
           mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
           mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
@@ -269,14 +234,6 @@
     }
     self__sizes = self_->sizes();
   } else {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
-    useLtInterface = !disable_addmm_cuda_lt &&
-        result.dim() == 2 && result.is_contiguous() &&
-        isSupportedHipLtROCmArch(self.device().index()) &&
-        (scalar_type == at::ScalarType::Float ||
-          scalar_type == at::ScalarType::Half ||
-          scalar_type == at::ScalarType::BFloat16);
-#endif
     self_ = c10::MaybeOwned<Tensor>::borrowed(self);
     self__sizes = self_->sizes();
     TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
@@ -320,7 +277,7 @@
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if !defined(USE_ROCM) && !defined(_MSC_VER)
   if (useLtInterface) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half,
@@ -342,7 +299,7 @@
               self.const_data_ptr<scalar_t>(),
               args.result->data_ptr<scalar_t>(),
               args.result_ld,
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11080) || defined(USE_ROCM)
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
               activation_to_gemm_and_blas_arg(activation)
 #else
               // GELU is not supported (and does not compile!) prior
@@ -400,7 +357,7 @@
 // gating activation_to_gemm_and_blas_arg above; here we are manually
 // performing a post-GELU because we weren't able to use the GELU
 // epilogue above.
-#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 11080) && !defined(USE_ROCM)
+#if !defined(CUDA_VERSION) || CUDA_VERSION < 11080
   if (useLtInterface && activation == Activation::GELU) {
     at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
   }
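The hunk above restores the CUDA-only fallback: when cuBLASLt lacks the fused GELU epilogue (CUDA < 11.8 in this path), the Lt matmul runs with the bias-only epilogue and GELU is applied in place afterwards, with `"tanh"` selecting the same tanh approximation the fused epilogue computes. A sketch of that post-op:

```cpp
#include <ATen/ATen.h>

// After a bias-only Lt matmul, apply the activation that could not be fused.
void geluPostOp(at::Tensor& result, bool used_bias_only_lt_path) {
  if (used_bias_only_lt_path) {
    at::gelu_(result, "tanh");  // matches the fused epilogue's approximation
  }
}
```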
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 9a6898f..ebcbbdf 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1257,15 +1257,6 @@
     list(APPEND HIP_CXX_FLAGS -DCAFFE2_USE_MIOPEN)
     list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
     list(APPEND HIP_CXX_FLAGS -std=c++17)
-    if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
-      list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
-    endif()
-    if(HIPBLASLT_CUSTOM_DATA_TYPE)
-      list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
-    endif()
-    if(HIPBLASLT_CUSTOM_COMPUTE_TYPE)
-      list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_COMPUTE_TYPE)
-    endif()
     add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT})
     add_definitions(-DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
     message("TORCH_HIP_VERSION=${TORCH_HIP_VERSION} is added as a compiler defines")
@@ -1291,9 +1282,6 @@
 
     set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
       ${PYTORCH_HIP_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB} ${ROCM_ROCTX_LIB})
-    if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
-      list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS ${hipblaslt_LIBRARIES})
-    endif()
 
     list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
       roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index f7344cc..6989f57 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -136,7 +136,6 @@
   set(hiprand_DIR ${ROCM_PATH}/lib/cmake/hiprand)
   set(rocblas_DIR ${ROCM_PATH}/lib/cmake/rocblas)
   set(hipblas_DIR ${ROCM_PATH}/lib/cmake/hipblas)
-  set(hipblaslt_DIR ${ROCM_PATH}/lib/cmake/hipblaslt)
   set(miopen_DIR ${ROCM_PATH}/lib/cmake/miopen)
   set(rocfft_DIR ${ROCM_PATH}/lib/cmake/rocfft)
   set(hipfft_DIR ${ROCM_PATH}/lib/cmake/hipfft)
@@ -155,9 +154,6 @@
   find_package_and_print_version(hiprand REQUIRED)
   find_package_and_print_version(rocblas REQUIRED)
   find_package_and_print_version(hipblas REQUIRED)
-  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
-    find_package_and_print_version(hipblaslt REQUIRED)
-  endif()
   find_package_and_print_version(miopen REQUIRED)
   if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
     find_package_and_print_version(hipfft REQUIRED)
@@ -191,57 +187,4 @@
   find_library(ROCM_HIPRTC_LIB amdhip64 HINTS ${ROCM_PATH}/lib)
   # roctx is part of roctracer
   find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
-
-  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
-    # check whether hipblaslt is using its own datatype
-    set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
-    file(WRITE ${file} ""
-      "#include <hipblaslt/hipblaslt.h>\n"
-      "int main() {\n"
-      "    hipblasltDatatype_t bar = HIPBLASLT_R_16F;\n"
-      "    return 0;\n"
-      "}\n"
-      )
-
-    try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
-      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
-      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
-      OUTPUT_VARIABLE hipblaslt_compile_output)
-
-    if(hipblaslt_compile_result)
-      set(HIPBLASLT_CUSTOM_DATA_TYPE ON)
-      #message("hipblaslt is using custom data type: ${hipblaslt_compile_output}")
-      message("hipblaslt is using custom data type")
-    else()
-      set(HIPBLASLT_CUSTOM_DATA_TYPE OFF)
-      #message("hipblaslt is NOT using custom data type: ${hipblaslt_compile_output}")
-      message("hipblaslt is NOT using custom data type")
-    endif()
-
-    # check whether hipblaslt is using its own compute type
-    set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_compute_type.cc")
-    file(WRITE ${file} ""
-      "#include <hipblaslt/hipblaslt.h>\n"
-      "int main() {\n"
-      "    hipblasLtComputeType_t baz = HIPBLASLT_COMPUTE_F32;\n"
-      "    return 0;\n"
-      "}\n"
-      )
-
-    try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
-      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
-      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
-      OUTPUT_VARIABLE hipblaslt_compile_output)
-
-    if(hipblaslt_compile_result)
-      set(HIPBLASLT_CUSTOM_COMPUTE_TYPE ON)
-      #message("hipblaslt is using custom compute type: ${hipblaslt_compile_output}")
-      message("hipblaslt is using custom compute type")
-    else()
-      set(HIPBLASLT_CUSTOM_COMPUTE_TYPE OFF)
-      #message("hipblaslt is NOT using custom compute type: ${hipblaslt_compile_output}")
-      message("hipblaslt is NOT using custom compute type")
-    endif()
-  endif()
-
 endif()
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index b80e22a..d2b2e16 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -237,9 +237,6 @@
     '-DUSE_ROCM=1',
 ]
 
-if ROCM_VERSION is not None and ROCM_VERSION >= (6, 0):
-    COMMON_HIP_FLAGS.append('-DHIPBLAS_V2')
-
 COMMON_HIPCC_FLAGS = [
     '-DCUDA_HAS_FP16=1',
     '-D__HIP_NO_HALF_OPERATORS__=1',
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index 979a639..fa727a7 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -611,7 +611,6 @@
         ("vector_types.h", ("hip/hip_vector_types.h", CONV_INCLUDE, API_RUNTIME)),
         ("cublas.h", ("hipblas/hipblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
         ("cublas_v2.h", ("hipblas/hipblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
-        ("cublasLt.h", ("hipblaslt/hipblaslt.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
         ("curand.h", ("hiprand/hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND)),
         ("curand_kernel.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
         ("curand_discrete.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
@@ -3852,7 +3851,7 @@
                 HIP_UNSUPPORTED,
             ),
         ),
-        ("cudaDataType_t", ("hipDataType", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("cudaDataType_t", ("hipDataType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
         ("cudaDataType", ("hipDataType", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
         ("CUDA_R_16BF", ("HIP_R_16BF", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
         ("CUDA_C_16BF", ("HIP_C_16BF", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
@@ -7273,65 +7272,6 @@
             ("hipblasDrotmg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED),
         ),
         (
-            "cublasComputeType_t",
-            ("hipblasComputeType_t" if rocm_version >= (6, 0, 0) else "hipblasLtComputeType_t",
-                CONV_MATH_FUNC, API_BLAS)
-        ),
-        (
-            "CUBLAS_COMPUTE_32I",
-            ("HIPBLAS_COMPUTE_32I" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_I32", CONV_MATH_FUNC, API_BLAS)
-        ),
-        (
-            "CUBLAS_COMPUTE_32F",
-            ("HIPBLAS_COMPUTE_32F" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_F32", CONV_MATH_FUNC, API_BLAS)
-        ),
-        (
-            "CUBLAS_COMPUTE_64F",
-            ("HIPBLAS_COMPUTE_64F" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_F64", CONV_MATH_FUNC, API_BLAS)
-        ),
-        ("cublasLtEpilogue_t", ("hipblasLtEpilogue_t", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_DEFAULT", ("HIPBLASLT_EPILOGUE_DEFAULT", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_RELU", ("HIPBLASLT_EPILOGUE_RELU", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_BIAS", ("HIPBLASLT_EPILOGUE_BIAS", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_RELU_BIAS", ("HIPBLASLT_EPILOGUE_RELU_BIAS", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_GELU", ("HIPBLASLT_EPILOGUE_GELU", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_GELU_BIAS", ("HIPBLASLT_EPILOGUE_GELU_BIAS", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtHandle_t", ("hipblasLtHandle_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDesc_t", ("hipblasLtMatmulDesc_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDescOpaque_t", ("hipblasLtMatmulDescOpaque_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDescAttributes_t", ("hipblasLtMatmulDescAttributes_t", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_TRANSA", ("HIPBLASLT_MATMUL_DESC_TRANSA", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_TRANSB", ("HIPBLASLT_MATMUL_DESC_TRANSB", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_EPILOGUE", ("HIPBLASLT_MATMUL_DESC_EPILOGUE", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_BIAS_POINTER", ("HIPBLASLT_MATMUL_DESC_BIAS_POINTER", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_A_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_B_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreference_t", ("hipblasLtMatmulPreference_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreferenceOpaque_t", ("hipblasLtMatmulPreferenceOpaque_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreferenceAttributes_t", ("hipblasLtMatmulPreferenceAttributes_t", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_PREF_SEARCH_MODE", ("HIPBLASLT_MATMUL_PREF_SEARCH_MODE", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", ("HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulAlgo_t", ("hipblasLtMatmulAlgo_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulHeuristicResult_t", ("hipblasLtMatmulHeuristicResult_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatrixLayoutCreate", ("hipblasLtMatrixLayoutCreate", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatrixLayoutDestroy", ("hipblasLtMatrixLayoutDestroy", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtCreate", ("hipblasLtCreate", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtDestroy", ("hipblasLtDestroy", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDescCreate", ("hipblasLtMatmulDescCreate", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDescDestroy", ("hipblasLtMatmulDescDestroy", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDescSetAttribute", ("hipblasLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreferenceCreate", ("hipblasLtMatmulPreferenceCreate", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreferenceDestroy", ("hipblasLtMatmulPreferenceDestroy", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreferenceSetAttribute", ("hipblasLtMatmulPreferenceSetAttribute", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulAlgoGetHeuristic", ("hipblasLtMatmulAlgoGetHeuristic", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmul", ("hipblasLtMatmul", CONV_MATH_FUNC, API_BLAS)),
-        (
             "CURAND_STATUS_SUCCESS",
             ("HIPRAND_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_RAND),
         ),
@@ -7737,14 +7677,8 @@
                 HIP_UNSUPPORTED,
             ),
         ),
-        (
-            "cuComplex",
-            ("hipComplex" if rocm_version >= (6, 0, 0) else "hipblasComplex", CONV_TYPE, API_BLAS)
-        ),
-        (
-            "cuDoubleComplex",
-            ("hipDoubleComplex" if rocm_version >= (6, 0, 0) else "hipblasDoubleComplex", CONV_TYPE, API_BLAS),
-        ),
+        ("cuComplex", ("hipblasComplex", CONV_TYPE, API_BLAS)),
+        ("cuDoubleComplex", ("hipblasDoubleComplex", CONV_TYPE, API_BLAS)),
         ("cufftResult_t", ("hipfftResult_t", CONV_TYPE, API_FFT)),
         ("cufftResult", ("hipfftResult", CONV_TYPE, API_FFT)),
         ("CUFFT_SUCCESS", ("HIPFFT_SUCCESS", CONV_NUMERIC_LITERAL, API_FFT)),