Revert D32053748: [pytorch] use cublas lt interface for bias fusion
Test Plan: revert-hammer
Differential Revision: D32053748 (https://github.com/pytorch/pytorch/commit/702d375df524565f8f2741c8d96df254e450653c)
Original commit changeset: accf787c8727
Original Phabricator Diff: D32053748 (https://github.com/pytorch/pytorch/commit/702d375df524565f8f2741c8d96df254e450653c)
fbshipit-source-id: 735fe64de4d525d8c9f2833952b09483afeaea98
(cherry picked from commit 099bd88c628feb648baad7cb33484f5772ed1052)
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index ebf1979..34b0214 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -2,18 +2,10 @@
Provides the implementations of CUDA BLAS function templates.
*/
-#include <ATen/ATen.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
-#include <c10/cuda/CUDAFunctions.h>
-#include <c10/macros/Export.h>
#include <c10/util/irange.h>
-
-// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
-// added bf16 support
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
-#include <cublasLt.h>
-#endif
+#include <c10/macros/Export.h>
#define CUDABLAS_POSINT_CHECK(FD, X) \
TORCH_CHECK( \
@@ -584,254 +576,6 @@
}
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
-
-namespace {
-// Following the pattern of CuSparseDescriptor
-// Defined here for now because this is the only place cublas_lt interface is
-// used but can be moved to a header once cublas_lt interface is used in
-// multiple places.
-template <typename T, cublasStatus_t (*destructor)(T*)>
-struct CuBlasLtDeleter {
- void operator()(T* x) {
- if (x != nullptr) {
- TORCH_CUDABLAS_CHECK(destructor(x));
- }
- }
-};
-
-template <typename T, cublasStatus_t (*destructor)(T*)>
-class CuBlasLtDescriptor {
- public:
- T* descriptor() const {
- return descriptor_.get();
- }
- T* descriptor() {
- return descriptor_.get();
- }
-
- protected:
- std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
-};
-
-class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
- cublasLtMatmulDescOpaque_t,
- &cublasLtMatmulDescDestroy> {
- public:
- CuBlasLtMatmulDescriptor(
- cublasComputeType_t compute_type,
- cudaDataType_t scale_type) {
- cublasLtMatmulDesc_t raw_descriptor;
- TORCH_CUDABLAS_CHECK(
- cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
- descriptor_.reset(raw_descriptor);
- }
-};
-
-class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
- cublasLtMatrixLayoutOpaque_t,
- &cublasLtMatrixLayoutDestroy> {
- public:
- CuBlasLtMatrixLayout(
- cudaDataType_t type,
- uint64_t rows,
- uint64_t cols,
- int64_t ld) {
- cublasLtMatrixLayout_t raw_descriptor;
- TORCH_CUDABLAS_CHECK(
- cublasLtMatrixLayoutCreate(&raw_descriptor, type, rows, cols, ld));
- descriptor_.reset(raw_descriptor);
- }
-};
-
-class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
- cublasLtMatmulPreferenceOpaque_t,
- &cublasLtMatmulPreferenceDestroy> {
- public:
- CuBlasLtMatmulPreference() {
- cublasLtMatmulPreference_t raw_descriptor;
- TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
- descriptor_.reset(raw_descriptor);
- }
-};
-} // namespace
-
-template <typename Dtype>
-void gemm_and_bias(
- bool transpose_mat1,
- bool transpose_mat2,
- int64_t m,
- int64_t n,
- int64_t k,
- at::opmath_type<Dtype> alpha_val,
- const Dtype* mat1_ptr,
- int64_t mat1_ld,
- const Dtype* mat2_ptr,
- int64_t mat2_ld,
- const Dtype* bias,
- Dtype* result_ptr,
- int64_t result_ld) {
- using opmath_t = at::opmath_type<Dtype>;
- opmath_t beta_val = 0; // bias is added in epilogue
-
- cudaDataType_t abcType;
- cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
- cudaDataType_t scaleType = CUDA_R_32F;
- if (std::is_same<Dtype, double>::value) {
- abcType = CUDA_R_64F;
- computeType = CUBLAS_COMPUTE_64F;
- scaleType = CUDA_R_64F;
- } else if (std::is_same<Dtype, float>::value) {
- // Should set computeType = CUBLAS_COMPUTE_32F_FAST_TF32 ?
- abcType = CUDA_R_32F;
- } else if (std::is_same<Dtype, at::Half>::value) {
- abcType = CUDA_R_16F;
- } else if (std::is_same<Dtype, at::BFloat16>::value) {
- abcType = CUDA_R_16BF;
- }
-
- CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
- cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
- TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
- computeDesc.descriptor(),
- CUBLASLT_MATMUL_DESC_TRANSA,
- &transa,
- sizeof(transa)));
- cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
- TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
- computeDesc.descriptor(),
- CUBLASLT_MATMUL_DESC_TRANSB,
- &transb,
- sizeof(transb)));
- cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
- TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
- computeDesc.descriptor(),
- CUBLASLT_MATMUL_DESC_EPILOGUE,
- &epilogue,
- sizeof(epilogue)));
- TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
- computeDesc.descriptor(),
- CUBLASLT_MATMUL_DESC_BIAS_POINTER,
- &bias,
- sizeof(Dtype*)));
-
- CuBlasLtMatrixLayout Adesc(
- abcType, transpose_mat1 ? k : m, transpose_mat1 ? m : k, mat1_ld);
- CuBlasLtMatrixLayout Bdesc(
- abcType, transpose_mat2 ? n : k, transpose_mat2 ? k : n, mat2_ld);
- CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);
-
- CuBlasLtMatmulPreference preference;
- size_t workspaceSize = 0;
- TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(
- preference.descriptor(),
- CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
- &workspaceSize,
- sizeof(workspaceSize)));
-
- auto workspace = at::empty(
- {static_cast<int64_t>(workspaceSize)},
- at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
-
- cublasLtMatmulHeuristicResult_t heuristicResult = {};
- int returnedResult = 0;
- cublasLtHandle_t ltHandle =
- reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
- TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
- ltHandle,
- computeDesc.descriptor(),
- Adesc.descriptor(),
- Bdesc.descriptor(),
- Cdesc.descriptor(),
- Cdesc.descriptor(),
- preference.descriptor(),
- 1,
- &heuristicResult,
- &returnedResult));
- if (returnedResult == 0) {
- TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
- }
-
- TORCH_CUDABLAS_CHECK(cublasLtMatmul(
- ltHandle,
- computeDesc.descriptor(),
- &alpha_val,
- mat1_ptr,
- Adesc.descriptor(),
- mat2_ptr,
- Bdesc.descriptor(),
- &beta_val,
- result_ptr,
- Cdesc.descriptor(),
- result_ptr,
- Cdesc.descriptor(),
- &heuristicResult.algo,
- workspace.data_ptr(),
- workspaceSize,
- at::cuda::getCurrentCUDAStream()));
-}
-
-template void gemm_and_bias(
- bool transpose_mat1,
- bool transpose_mat2,
- int64_t m,
- int64_t n,
- int64_t k,
- at::opmath_type<double> alpha_val,
- const double* mat1_ptr,
- int64_t mat1_ld,
- const double* mat2_ptr,
- int64_t mat2_ld,
- const double* bias,
- double* result_ptr,
- int64_t result_ld);
-
-template void gemm_and_bias(
- bool transpose_mat1,
- bool transpose_mat2,
- int64_t m,
- int64_t n,
- int64_t k,
- at::opmath_type<float> alpha_val,
- const float* mat1_ptr,
- int64_t mat1_ld,
- const float* mat2_ptr,
- int64_t mat2_ld,
- const float* bias,
- float* result_ptr,
- int64_t result_ld);
-
-template void gemm_and_bias(
- bool transpose_mat1,
- bool transpose_mat2,
- int64_t m,
- int64_t n,
- int64_t k,
- at::opmath_type<at::Half> alpha_val,
- const at::Half* mat1_ptr,
- int64_t mat1_ld,
- const at::Half* mat2_ptr,
- int64_t mat2_ld,
- const at::Half* bias,
- at::Half* result_ptr,
- int64_t result_ld);
-
-template void gemm_and_bias(
- bool transpose_mat1,
- bool transpose_mat2,
- int64_t m,
- int64_t n,
- int64_t k,
- at::opmath_type<at::BFloat16> alpha_val,
- const at::BFloat16* mat1_ptr,
- int64_t mat1_ld,
- const at::BFloat16* mat2_ptr,
- int64_t mat2_ld,
- const at::BFloat16* bias,
- at::BFloat16* result_ptr,
- int64_t result_ld);
-#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
-
template <>
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
TORCH_CUDABLAS_CHECK(cublasStrsm(
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index adb4b04..8188a99 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -70,24 +70,6 @@
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#endif
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
-template <typename Dtype>
-void gemm_and_bias(
- bool transpose_mat1,
- bool transpose_mat2,
- int64_t m,
- int64_t n,
- int64_t k,
- at::opmath_type<Dtype> alpha_val,
- const Dtype* mat1_ptr,
- int64_t mat1_ld,
- const Dtype* mat2_ptr,
- int64_t mat2_ld,
- const Dtype* bias,
- Dtype* result_ptr,
- int64_t result_ld);
-#endif
-
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
const Dtype *a, int64_t lda, int64_t stridea, \
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 3032116..5415910 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -102,27 +102,9 @@
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
IntArrayRef self__sizes;
- bool useLtInterface = false;
- at::ScalarType scalar_type = self.scalar_type();
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
-#if CUDA_VERSION >= 11000
- // Strangely, if mat2 has only 1 row or column, we get
- // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
- // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
- // is to use lt interface only when self is bias.
- useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
- result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
- self.is_contiguous() &&
- (scalar_type == at::ScalarType::Double ||
- scalar_type == at::ScalarType::Float ||
- scalar_type == at::ScalarType::Half ||
- scalar_type == at::ScalarType::BFloat16) &&
- mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
-#endif
- if (!useLtInterface) {
- self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
- }
+ self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
self__sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
@@ -133,8 +115,8 @@
}
if (&result != &self) {
- at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
- if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
+ at::native::resize_output(result, self__sizes);
+ if (beta.toComplexDouble() != 0.0) {
at::native::copy_(result, *self_);
}
}
@@ -165,6 +147,7 @@
int64_t mat1_ld = mat1_->stride((transpose_mat1 == transpose_result) ? 1 : 0);
int64_t mat2_ld = mat2_->stride((transpose_mat2 == transpose_result) ? 1 : 0);
int64_t result_ld = result_->stride(transpose_result ? 0 : 1);
+ at::ScalarType scalar_type = self_->scalar_type();
if (mat1.numel() == 0) {
// By definition, when beta==0, values in self should be ignored. nans and infs
@@ -187,61 +170,24 @@
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(_MSC_VER)
- if (useLtInterface) {
- AT_DISPATCH_FLOATING_TYPES_AND2(
- at::ScalarType::Half,
- at::ScalarType::BFloat16,
- scalar_type,
- "addmm_cuda_lt",
- [&] {
- at::cuda::blas::gemm_and_bias<scalar_t>(
- transpose_mat1,
- transpose_mat2,
- m,
- n,
- k,
- alpha.to<at::opmath_type<scalar_t>>(),
- mat1_->data_ptr<scalar_t>(),
- mat1_ld,
- mat2_->data_ptr<scalar_t>(),
- mat2_ld,
- self.data_ptr<scalar_t>(),
- result_->data_ptr<scalar_t>(),
- result_ld);
- });
- } else
-#endif
- {
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
- at::ScalarType::Half,
- at::ScalarType::BFloat16,
- scalar_type,
- "addmm_cuda",
- [&] {
- using opmath_t = at::opmath_type<scalar_t>;
- opmath_t alpha_val = alpha.to<opmath_t>();
- opmath_t beta_val = beta.to<opmath_t>();
- scalar_t* mat1_ptr = mat1_->data_ptr<scalar_t>();
- scalar_t* mat2_ptr = mat2_->data_ptr<scalar_t>();
- scalar_t* result_ptr = result_->data_ptr<scalar_t>();
- at::cuda::blas::gemm<scalar_t>(
- transpose_mat1 ? mat1_->is_conj() ? 'c' : 't' : 'n',
- transpose_mat2 ? mat2_->is_conj() ? 'c' : 't' : 'n',
- m,
- n,
- k,
- alpha_val,
- mat1_ptr,
- mat1_ld,
- mat2_ptr,
- mat2_ld,
- beta_val,
- result_ptr,
- result_ld);
- });
- }
-
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "addmm_cuda", [&] {
+ using opmath_t = at::opmath_type<scalar_t>;
+ opmath_t alpha_val = alpha.to<opmath_t>();
+ opmath_t beta_val = beta.to<opmath_t>();
+ scalar_t* mat1_ptr = mat1_->data_ptr<scalar_t>();
+ scalar_t* mat2_ptr = mat2_->data_ptr<scalar_t>();
+ scalar_t* result_ptr = result_->data_ptr<scalar_t>();
+ at::cuda::blas::gemm<scalar_t>(
+ transpose_mat1 ? mat1_->is_conj() ? 'c' : 't' : 'n',
+ transpose_mat2 ? mat2_->is_conj() ? 'c' : 't' : 'n',
+ m, n, k,
+ alpha_val,
+ mat1_ptr, mat1_ld,
+ mat2_ptr, mat2_ld,
+ beta_val,
+ result_ptr, result_ld
+ );
+ });
if (!result.is_same(*result_)) {
result.copy_(*result_);
}