Revert D32053748: [pytorch] use cublas lt interface for bias fusion

Test Plan: revert-hammer

Differential Revision: D32053748 (https://github.com/pytorch/pytorch/commit/702d375df524565f8f2741c8d96df254e450653c)

Original commit changeset: accf787c8727

Original Phabricator Diff: D32053748 (https://github.com/pytorch/pytorch/commit/702d375df524565f8f2741c8d96df254e450653c)

fbshipit-source-id: 735fe64de4d525d8c9f2833952b09483afeaea98
(cherry picked from commit 099bd88c628feb648baad7cb33484f5772ed1052)
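
Note for reviewers: the reverted diff routed addmm through cublasLtMatmul with the bias add fused into the matmul epilogue. The sketch below condenses the removed gemm_and_bias from CUDABlas.cpp into a single function for orientation (float only, no transposes, zero workspace; all TORCH_CUDABLAS_CHECK status handling and the RAII descriptor wrappers are elided, and the helper name is illustrative, not part of the tree):

    #include <cstdint>
    #include <cublasLt.h>

    // Sketch of the reverted pattern: C = alpha * A @ B with the bias add
    // fused into the epilogue, so beta stays 0. Column-major, ld == rows.
    void lt_gemm_bias_sketch(cublasLtHandle_t handle,
                             int64_t m, int64_t n, int64_t k,
                             const float* A, const float* B,
                             const float* bias, float* C,
                             cudaStream_t stream) {
      float alpha = 1.0f, beta = 0.0f;

      cublasLtMatmulDesc_t op;
      cublasLtMatmulDescCreate(&op, CUBLAS_COMPUTE_32F, CUDA_R_32F);
      // The two attributes that make this a fused bias GEMM.
      cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
      cublasLtMatmulDescSetAttribute(
          op, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
      cublasLtMatmulDescSetAttribute(
          op, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));

      cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc;
      cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_32F, m, k, m);
      cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_32F, k, n, k);
      cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32F, m, n, m);

      // Ask the heuristic for a single algorithm; the removed code raised
      // CUBLAS_STATUS_NOT_SUPPORTED when none was returned.
      cublasLtMatmulPreference_t pref;
      cublasLtMatmulPreferenceCreate(&pref);
      cublasLtMatmulHeuristicResult_t heuristic{};
      int returned = 0;
      cublasLtMatmulAlgoGetHeuristic(handle, op, Adesc, Bdesc, Cdesc, Cdesc,
                                     pref, 1, &heuristic, &returned);

      cublasLtMatmul(handle, op, &alpha, A, Adesc, B, Bdesc,
                     &beta, C, Cdesc, C, Cdesc,
                     &heuristic.algo, /*workspace=*/nullptr,
                     /*workspaceSizeInBytes=*/0, stream);

      cublasLtMatmulPreferenceDestroy(pref);
      cublasLtMatrixLayoutDestroy(Cdesc);
      cublasLtMatrixLayoutDestroy(Bdesc);
      cublasLtMatrixLayoutDestroy(Adesc);
      cublasLtMatmulDescDestroy(op);
    }
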
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index ebf1979..34b0214 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -2,18 +2,10 @@
   Provides the implementations of CUDA BLAS function templates.
  */
 
-#include <ATen/ATen.h>
 #include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/Exceptions.h>
-#include <c10/cuda/CUDAFunctions.h>
-#include <c10/macros/Export.h>
 #include <c10/util/irange.h>
-
-// cublasLt was introduced in CUDA 10.1, but we enable it only for CUDA 11,
-// which also added bf16 support
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
-#include <cublasLt.h>
-#endif
+#include <c10/macros/Export.h>
 
 #define CUDABLAS_POSINT_CHECK(FD, X)         \
   TORCH_CHECK(                               \
@@ -584,254 +576,6 @@
 }
 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
 
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
-
-namespace {
-// Following the pattern of CuSparseDescriptor
-// Defined here for now because this is the only place cublas_lt interface is
-// used but can be moved to a header once cublas_lt interface is used in
-// multiple places.
-template <typename T, cublasStatus_t (*destructor)(T*)>
-struct CuBlasLtDeleter {
-  void operator()(T* x) {
-    if (x != nullptr) {
-      TORCH_CUDABLAS_CHECK(destructor(x));
-    }
-  }
-};
-
-template <typename T, cublasStatus_t (*destructor)(T*)>
-class CuBlasLtDescriptor {
- public:
-  T* descriptor() const {
-    return descriptor_.get();
-  }
-  T* descriptor() {
-    return descriptor_.get();
-  }
-
- protected:
-  std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
-};
-
-class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
-                                     cublasLtMatmulDescOpaque_t,
-                                     &cublasLtMatmulDescDestroy> {
- public:
-  CuBlasLtMatmulDescriptor(
-      cublasComputeType_t compute_type,
-      cudaDataType_t scale_type) {
-    cublasLtMatmulDesc_t raw_descriptor;
-    TORCH_CUDABLAS_CHECK(
-        cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
-    descriptor_.reset(raw_descriptor);
-  }
-};
-
-class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
-                                 cublasLtMatrixLayoutOpaque_t,
-                                 &cublasLtMatrixLayoutDestroy> {
- public:
-  CuBlasLtMatrixLayout(
-      cudaDataType_t type,
-      uint64_t rows,
-      uint64_t cols,
-      int64_t ld) {
-    cublasLtMatrixLayout_t raw_descriptor;
-    TORCH_CUDABLAS_CHECK(
-        cublasLtMatrixLayoutCreate(&raw_descriptor, type, rows, cols, ld));
-    descriptor_.reset(raw_descriptor);
-  }
-};
-
-class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
-                                     cublasLtMatmulPreferenceOpaque_t,
-                                     &cublasLtMatmulPreferenceDestroy> {
- public:
-  CuBlasLtMatmulPreference() {
-    cublasLtMatmulPreference_t raw_descriptor;
-    TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
-    descriptor_.reset(raw_descriptor);
-  }
-};
-} // namespace
-
-template <typename Dtype>
-void gemm_and_bias(
-    bool transpose_mat1,
-    bool transpose_mat2,
-    int64_t m,
-    int64_t n,
-    int64_t k,
-    at::opmath_type<Dtype> alpha_val,
-    const Dtype* mat1_ptr,
-    int64_t mat1_ld,
-    const Dtype* mat2_ptr,
-    int64_t mat2_ld,
-    const Dtype* bias,
-    Dtype* result_ptr,
-    int64_t result_ld) {
-  using opmath_t = at::opmath_type<Dtype>;
-  opmath_t beta_val = 0; // bias is added in epilogue
-
-  cudaDataType_t abcType;
-  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
-  cudaDataType_t scaleType = CUDA_R_32F;
-  if (std::is_same<Dtype, double>::value) {
-    abcType = CUDA_R_64F;
-    computeType = CUBLAS_COMPUTE_64F;
-    scaleType = CUDA_R_64F;
-  } else if (std::is_same<Dtype, float>::value) {
-    // Should set computeType = CUBLAS_COMPUTE_32F_FAST_TF32 ?
-    abcType = CUDA_R_32F;
-  } else if (std::is_same<Dtype, at::Half>::value) {
-    abcType = CUDA_R_16F;
-  } else if (std::is_same<Dtype, at::BFloat16>::value) {
-    abcType = CUDA_R_16BF;
-  }
-
-  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
-  cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
-  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
-      computeDesc.descriptor(),
-      CUBLASLT_MATMUL_DESC_TRANSA,
-      &transa,
-      sizeof(transa)));
-  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
-  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
-      computeDesc.descriptor(),
-      CUBLASLT_MATMUL_DESC_TRANSB,
-      &transb,
-      sizeof(transb)));
-  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
-  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
-      computeDesc.descriptor(),
-      CUBLASLT_MATMUL_DESC_EPILOGUE,
-      &epilogue,
-      sizeof(epilogue)));
-  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
-      computeDesc.descriptor(),
-      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
-      &bias,
-      sizeof(Dtype*)));
-
-  CuBlasLtMatrixLayout Adesc(
-      abcType, transpose_mat1 ? k : m, transpose_mat1 ? m : k, mat1_ld);
-  CuBlasLtMatrixLayout Bdesc(
-      abcType, transpose_mat2 ? n : k, transpose_mat2 ? k : n, mat2_ld);
-  CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);
-
-  CuBlasLtMatmulPreference preference;
-  size_t workspaceSize = 0;
-  TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(
-      preference.descriptor(),
-      CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
-      &workspaceSize,
-      sizeof(workspaceSize)));
-
-  auto workspace = at::empty(
-      {static_cast<int64_t>(workspaceSize)},
-      at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
-
-  cublasLtMatmulHeuristicResult_t heuristicResult = {};
-  int returnedResult = 0;
-  cublasLtHandle_t ltHandle =
-      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
-      ltHandle,
-      computeDesc.descriptor(),
-      Adesc.descriptor(),
-      Bdesc.descriptor(),
-      Cdesc.descriptor(),
-      Cdesc.descriptor(),
-      preference.descriptor(),
-      1,
-      &heuristicResult,
-      &returnedResult));
-  if (returnedResult == 0) {
-    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
-  }
-
-  TORCH_CUDABLAS_CHECK(cublasLtMatmul(
-      ltHandle,
-      computeDesc.descriptor(),
-      &alpha_val,
-      mat1_ptr,
-      Adesc.descriptor(),
-      mat2_ptr,
-      Bdesc.descriptor(),
-      &beta_val,
-      result_ptr,
-      Cdesc.descriptor(),
-      result_ptr,
-      Cdesc.descriptor(),
-      &heuristicResult.algo,
-      workspace.data_ptr(),
-      workspaceSize,
-      at::cuda::getCurrentCUDAStream()));
-}
-
-template void gemm_and_bias(
-    bool transpose_mat1,
-    bool transpose_mat2,
-    int64_t m,
-    int64_t n,
-    int64_t k,
-    at::opmath_type<double> alpha_val,
-    const double* mat1_ptr,
-    int64_t mat1_ld,
-    const double* mat2_ptr,
-    int64_t mat2_ld,
-    const double* bias,
-    double* result_ptr,
-    int64_t result_ld);
-
-template void gemm_and_bias(
-    bool transpose_mat1,
-    bool transpose_mat2,
-    int64_t m,
-    int64_t n,
-    int64_t k,
-    at::opmath_type<float> alpha_val,
-    const float* mat1_ptr,
-    int64_t mat1_ld,
-    const float* mat2_ptr,
-    int64_t mat2_ld,
-    const float* bias,
-    float* result_ptr,
-    int64_t result_ld);
-
-template void gemm_and_bias(
-    bool transpose_mat1,
-    bool transpose_mat2,
-    int64_t m,
-    int64_t n,
-    int64_t k,
-    at::opmath_type<at::Half> alpha_val,
-    const at::Half* mat1_ptr,
-    int64_t mat1_ld,
-    const at::Half* mat2_ptr,
-    int64_t mat2_ld,
-    const at::Half* bias,
-    at::Half* result_ptr,
-    int64_t result_ld);
-
-template void gemm_and_bias(
-    bool transpose_mat1,
-    bool transpose_mat2,
-    int64_t m,
-    int64_t n,
-    int64_t k,
-    at::opmath_type<at::BFloat16> alpha_val,
-    const at::BFloat16* mat1_ptr,
-    int64_t mat1_ld,
-    const at::BFloat16* mat2_ptr,
-    int64_t mat2_ld,
-    const at::BFloat16* bias,
-    at::BFloat16* result_ptr,
-    int64_t result_ld);
-#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
-
 template <>
 void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
   TORCH_CUDABLAS_CHECK(cublasStrsm(
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index adb4b04..8188a99 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -70,24 +70,6 @@
 void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
 #endif
 
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
-template <typename Dtype>
-void gemm_and_bias(
-    bool transpose_mat1,
-    bool transpose_mat2,
-    int64_t m,
-    int64_t n,
-    int64_t k,
-    at::opmath_type<Dtype> alpha_val,
-    const Dtype* mat1_ptr,
-    int64_t mat1_ld,
-    const Dtype* mat2_ptr,
-    int64_t mat2_ld,
-    const Dtype* bias,
-    Dtype* result_ptr,
-    int64_t result_ld);
-#endif
-
 #define CUDABLAS_BGEMM_ARGTYPES(Dtype)                                                        \
   char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha,    \
       const Dtype *a, int64_t lda, int64_t stridea,                                           \
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 3032116..5415910 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -102,27 +102,9 @@
   IntArrayRef mat1_sizes = mat1.sizes();
   IntArrayRef mat2_sizes = mat2.sizes();
   IntArrayRef self__sizes;
-  bool useLtInterface = false;
-  at::ScalarType scalar_type = self.scalar_type();
   c10::MaybeOwned<Tensor> self_;
   if (&result != &self) {
-#if CUDA_VERSION >= 11000
-    // Strangely, if mat2 has only 1 row or column, we get
-    // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
-    // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
-    // is to use lt interface only when self is bias.
-    useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
-        result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
-        self.is_contiguous() &&
-        (scalar_type == at::ScalarType::Double ||
-         scalar_type == at::ScalarType::Float ||
-         scalar_type == at::ScalarType::Half ||
-         scalar_type == at::ScalarType::BFloat16) &&
-        mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
-#endif
-    if (!useLtInterface) {
-      self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
-    }
+    self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
     self__sizes = self_->sizes();
   } else {
     self_ = c10::MaybeOwned<Tensor>::borrowed(self);
@@ -133,8 +115,8 @@
   }
 
   if (&result != &self) {
-    at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
-    if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
+    at::native::resize_output(result, self__sizes);
+    if (beta.toComplexDouble() != 0.0) {
       at::native::copy_(result, *self_);
     }
   }
@@ -165,6 +147,7 @@
   int64_t mat1_ld = mat1_->stride((transpose_mat1 == transpose_result) ? 1 : 0);
   int64_t mat2_ld = mat2_->stride((transpose_mat2 == transpose_result) ? 1 : 0);
   int64_t result_ld = result_->stride(transpose_result ? 0 : 1);
+  at::ScalarType scalar_type = self_->scalar_type();
 
   if (mat1.numel() == 0) {
     // By definition, when beta==0, values in self should be ignored. nans and infs
@@ -187,61 +170,24 @@
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
 
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
-  if (useLtInterface) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(
-        at::ScalarType::Half,
-        at::ScalarType::BFloat16,
-        scalar_type,
-        "addmm_cuda_lt",
-        [&] {
-          at::cuda::blas::gemm_and_bias<scalar_t>(
-              transpose_mat1,
-              transpose_mat2,
-              m,
-              n,
-              k,
-              alpha.to<at::opmath_type<scalar_t>>(),
-              mat1_->data_ptr<scalar_t>(),
-              mat1_ld,
-              mat2_->data_ptr<scalar_t>(),
-              mat2_ld,
-              self.data_ptr<scalar_t>(),
-              result_->data_ptr<scalar_t>(),
-              result_ld);
-        });
-  } else
-#endif
-  {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-        at::ScalarType::Half,
-        at::ScalarType::BFloat16,
-        scalar_type,
-        "addmm_cuda",
-        [&] {
-          using opmath_t = at::opmath_type<scalar_t>;
-          opmath_t alpha_val = alpha.to<opmath_t>();
-          opmath_t beta_val = beta.to<opmath_t>();
-          scalar_t* mat1_ptr = mat1_->data_ptr<scalar_t>();
-          scalar_t* mat2_ptr = mat2_->data_ptr<scalar_t>();
-          scalar_t* result_ptr = result_->data_ptr<scalar_t>();
-          at::cuda::blas::gemm<scalar_t>(
-              transpose_mat1 ? mat1_->is_conj() ? 'c' : 't' : 'n',
-              transpose_mat2 ? mat2_->is_conj() ? 'c' : 't' : 'n',
-              m,
-              n,
-              k,
-              alpha_val,
-              mat1_ptr,
-              mat1_ld,
-              mat2_ptr,
-              mat2_ld,
-              beta_val,
-              result_ptr,
-              result_ld);
-        });
-  }
-
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "addmm_cuda", [&] {
+    using opmath_t = at::opmath_type<scalar_t>;
+    opmath_t alpha_val = alpha.to<opmath_t>();
+    opmath_t beta_val = beta.to<opmath_t>();
+    scalar_t* mat1_ptr = mat1_->data_ptr<scalar_t>();
+    scalar_t* mat2_ptr = mat2_->data_ptr<scalar_t>();
+    scalar_t* result_ptr = result_->data_ptr<scalar_t>();
+    at::cuda::blas::gemm<scalar_t>(
+      transpose_mat1 ? mat1_->is_conj() ? 'c' : 't' : 'n',
+      transpose_mat2 ? mat2_->is_conj() ? 'c' : 't' : 'n',
+      m, n, k,
+      alpha_val,
+      mat1_ptr, mat1_ld,
+      mat2_ptr, mat2_ld,
+      beta_val,
+      result_ptr, result_ld
+    );
+  });
   if (!result.is_same(*result_)) {
     result.copy_(*result_);
   }
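
After this revert, every addmm call takes the plain cuBLAS path restored in the `+` hunk above: when beta != 0, `self_` is expanded and copied into `result`, and the GEMM then accumulates onto it. For reference, a minimal standalone sketch of what that dispatch reduces to for float with no transposes (hypothetical helper name; status handling elided):

    #include <cublas_v2.h>

    // Surviving path: result = alpha * mat1 @ mat2 + beta * result,
    // where result was pre-filled with the expanded self (the bias).
    cublasStatus_t gemm_with_preloaded_bias(cublasHandle_t handle,
                                            int m, int n, int k,
                                            const float* mat1, int mat1_ld,
                                            const float* mat2, int mat2_ld,
                                            float* result, int result_ld) {
      const float alpha = 1.0f;
      const float beta = 1.0f;  // beta == 1 accumulates onto the copied-in bias
      return cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                         &alpha, mat1, mat1_ld, mat2, mat2_ld,
                         &beta, result, result_ld);
    }

Unlike the fused-epilogue variant sketched before the diff, this fills `result` with the bias first and reads it back during the GEMM; avoiding that separate pass was the point of the reverted diff.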