/*
Provides the implementations of CUDA BLAS function templates.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDADataType.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/macros/Export.h>
#include <c10/util/irange.h>
// cublasLt was introduced in CUDA 10.1, but we enable it only for CUDA 11.1+,
// which also added bf16 support
#if !defined(USE_ROCM) && !defined(_MSC_VER)
#include <cublasLt.h>
#endif
#ifdef USE_ROCM
// until hipblas has an API to accept flags, we must use rocblas here
#include <rocblas/rocblas.h>
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
// needed because we call the rocblas API directly instead of the hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
return HIPBLAS_STATUS_INTERNAL_ERROR;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// hipblas does not have hipblasSetMathMode
#define hipblasSetMathMode(handle, flags) HIPBLAS_STATUS_SUCCESS
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_64F HIPBLAS_R_64F
#define HIP_C_16F HIPBLAS_C_16F
#define HIP_C_32F HIPBLAS_C_32F
#define HIP_C_64F HIPBLAS_C_64F
#define HIP_R_8I HIPBLAS_R_8I
#define HIP_R_8U HIPBLAS_R_8U
#define HIP_R_32I HIPBLAS_R_32I
#define HIP_R_32U HIPBLAS_R_32U
#define HIP_C_8I HIPBLAS_C_8I
#define HIP_C_8U HIPBLAS_C_8U
#define HIP_C_32I HIPBLAS_C_32I
#define HIP_C_32U HIPBLAS_C_32U
#define HIP_R_16BF HIPBLAS_R_16B
#define HIP_C_16BF HIPBLAS_C_16B
#endif
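// cuBLAS takes plain `int` dimension/stride arguments; the checks below guard
// the int64_t -> int narrowing performed by the wrappers in this file.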
#define CUDABLAS_POSINT_CHECK(FD, X)         \
  TORCH_CHECK(                               \
      (X > 0 && X <= INT_MAX),               \
      "at::cuda::blas::" #FD " argument " #X \
      " must be positive and at most ",      \
      INT_MAX,                               \
      " but got ",                           \
      X)
#define CUDABLAS_NONNEGINT_CHECK(FD, X)       \
  TORCH_CHECK(                                \
      (X >= 0 && X <= INT_MAX),               \
      "at::cuda::blas::" #FD " argument " #X  \
      " must be non-negative and at most ",   \
      INT_MAX,                                \
      " but got ",                            \
      X)
namespace {
static cublasOperation_t _cublasOpFromChar(char op) {
switch (op) {
case 'n':
case 'N':
return CUBLAS_OP_N;
case 't':
case 'T':
return CUBLAS_OP_T;
case 'c':
case 'C':
return CUBLAS_OP_C;
}
AT_ERROR(
"_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}
static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
  // Note: leading dimensions are generally checked to be > 0 and at least as
  // big as the result requires (even if the value won't be used).
  // Q: Why does Level3 check trans but this doesn't?
  // A: In level 2, the sizes (m, n) specify the size of A
  // (independent of trans value). In level 3, the sizes (m, n, k)
  // specify the sizes of op(A), op(B), where op depends on the trans
  // values.
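  // e.g. a call with n == 1 may come in with a degenerate lda; raise it to
  // max(m, 1) so cuBLAS's lda >= max(1, m) requirement holds.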
if (n <= 1)
*lda = std::max<int64_t>(m, 1);
}
static void _cublasAdjustLdLevel3(
char transa,
char transb,
int64_t m,
int64_t n,
int64_t k,
int64_t* lda,
int64_t* ldb,
int64_t* ldc) {
bool transa_ = ((transa != 'n') && (transa != 'N'));
bool transb_ = ((transb != 'n') && (transb != 'N'));
  // Note: leading dimensions are generally checked to be > 0 and at least as
  // big as the result requires (even if the value won't be used).
if (n <= 1)
*ldc = std::max<int64_t>(m, 1);
if (transa_) {
if (m <= 1)
*lda = std::max<int64_t>(k, 1);
} else {
if (k <= 1)
*lda = std::max<int64_t>(m, 1);
}
if (transb_) {
if (k <= 1)
*ldb = std::max<int64_t>(n, 1);
} else {
if (n <= 1)
*ldb = std::max<int64_t>(k, 1);
}
}
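// Returns the largest power-of-two divisor of `address`, capped at 256 bytes;
// used below as a pointer-alignment hint for the cublasLt heuristics.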
uint32_t _getAlignment(uintptr_t address) {
  // alignments are in bytes
uint32_t alignment = 256;
for (; ; alignment /= 2) {
if (!(address % alignment)) {
return alignment;
}
}
}
static size_t _parseChosenWorkspaceSize() {
const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
size_t workspace_size = 1024; /* default size in KiB according to #73328 */
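  // e.g. CUBLASLT_WORKSPACE_SIZE=4096 requests a 4 MiB workspace (the value
  // is interpreted in KiB).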
if (val) {
try {
workspace_size = std::stoi(val);
    } catch(std::invalid_argument const& e) {
      TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,",
          " using default workspace size of ", workspace_size, " KiB.");
    } catch(std::out_of_range const& e) {
      TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,",
          " using default workspace size of ", workspace_size, " KiB.");
}
}
return workspace_size * 1024;
}
static size_t _getWorkspaceSize() {
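  // Parsed once per process; the function-local static caches the result.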
static size_t workspace_size = _parseChosenWorkspaceSize();
return workspace_size;
}
} // anonymous namespace
namespace at::cuda::blas {
/* LEVEL 3 BLAS FUNCTIONS */
#define GEMM_CHECK_ARGVALUES(Dtype) \
do { \
CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, m); \
CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, n); \
CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, k); \
CUDABLAS_POSINT_CHECK(gemm<Dtype>, lda); \
CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldb); \
CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc); \
} while (0)
#define BGEMM_CHECK_ARGVALUES(Dtype) \
do { \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, m); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, n); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, k); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, lda); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldb); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldc); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
} while (0)
template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(double);
TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched(
handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(float);
TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched(
handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
}
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(c10::complex<float>);
TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched(
handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
}
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(at::Half);
float falpha = alpha;
float fbeta = beta;
#ifdef USE_ROCM
int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
hipOperationToRocOperation(opa),
hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
b, rocblas_datatype_f16_r, (int)ldb, strideb,
(void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
c, rocblas_datatype_f16_r, (int)ldc, stridec,
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, flag)));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 5) {
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
handle, opa, opb, m, n, k,
(void*)(&falpha), a, CUDA_R_16F, lda, stridea,
b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
c, CUDA_R_16F, ldc, stridec,
num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
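    // Pre-Maxwell devices (compute capability < 5.0) lack fp16 support in
    // cublasGemmStridedBatchedEx, so fall back to one gemm call per batch.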
for (const auto i : c10::irange(num_batches)) {
at::cuda::blas::gemm<at::Half>(
transa, transb,
m, n, k,
alpha, (a + i * stridea), lda,
(b + i * strideb), ldb, beta,
(c + i * stridec), ldc);
}
}
#endif // USE_ROCM
}
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
BGEMM_CHECK_ARGVALUES(at::BFloat16);
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
const float falpha = alpha;
const float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
b, CUDA_R_16BF, (int)ldb, strideb,
(void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
(int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(double);
TORCH_CUDABLAS_CHECK(cublasDgemm(
handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}
template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(float);
TORCH_CUDABLAS_CHECK(cublasSgemm(
handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(cublasZgemm(
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, reinterpret_cast<const cuDoubleComplex*>(b), ldb, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(c), ldc));
}
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(c10::complex<float>);
TORCH_CUDABLAS_CHECK(cublasCgemm(
handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, reinterpret_cast<const cuComplex*>(b), ldb, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(c), ldc));
}
template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
float falpha = alpha;
float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::Half);
#ifdef USE_ROCM
int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle)handle,
hipOperationToRocOperation(opa),
hipOperationToRocOperation(opb),
m,
n,
k,
&falpha,
a,
rocblas_datatype_f16_r,
lda,
b,
rocblas_datatype_f16_r,
ldb,
&fbeta,
c,
rocblas_datatype_f16_r,
ldc,
c,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag)));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5) {
    cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
    // Disallow fp16 reductions that could lead to unexpected overflow issues.
    if (!at::globalContext().allowFP16ReductionCuBLAS()) {
      cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
    }
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
opa,
opb,
m,
n,
k,
&falpha,
a,
CUDA_R_16F,
lda,
b,
CUDA_R_16F,
ldb,
&fbeta,
c,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
} else {
TORCH_CUDABLAS_CHECK(cublasSgemmEx(
handle,
opa,
opb,
m,
n,
k,
&falpha,
a,
CUDA_R_16F,
lda,
b,
CUDA_R_16F,
ldb,
&fbeta,
c,
CUDA_R_16F,
ldc));
}
#endif
}
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
float falpha = alpha;
float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::BFloat16);
#ifndef USE_ROCM
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}
#endif
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
opa,
opb,
m,
n,
k,
&falpha,
a,
CUDA_R_16BF,
lda,
b,
CUDA_R_16BF,
ldb,
&fbeta,
c,
CUDA_R_16BF,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#if !defined(USE_ROCM) && !defined(_MSC_VER)
namespace {
// Following the pattern of CuSparseDescriptor.
// Defined here for now because this is the only place the cublasLt interface
// is used; these can be moved to a header once the cublasLt interface is used
// in multiple places.
template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
void operator()(T* x) {
if (x != nullptr) {
TORCH_CUDABLAS_CHECK(destructor(x));
}
}
};
template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
public:
T* descriptor() const {
return descriptor_.get();
}
T* descriptor() {
return descriptor_.get();
}
protected:
std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};
class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
cublasLtMatmulDescOpaque_t,
&cublasLtMatmulDescDestroy> {
public:
CuBlasLtMatmulDescriptor(
cublasComputeType_t compute_type,
cudaDataType_t scale_type) {
cublasLtMatmulDesc_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(
cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};
class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
cublasLtMatrixLayoutOpaque_t,
&cublasLtMatrixLayoutDestroy> {
public:
CuBlasLtMatrixLayout(
cudaDataType_t type,
uint64_t rows,
uint64_t cols,
int64_t ld,
bool t = false) {
cublasLtMatrixLayout_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(
cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
descriptor_.reset(raw_descriptor);
}
};
class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
cublasLtMatmulPreferenceOpaque_t,
&cublasLtMatmulPreferenceDestroy> {
public:
CuBlasLtMatmulPreference() {
cublasLtMatmulPreference_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
TORCH_CUDABLAS_CHECK(::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};
} // namespace
template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
at::opmath_type<Dtype> alpha_val,
const Dtype* mat1_ptr,
int64_t mat1_ld,
const Dtype* mat2_ptr,
int64_t mat2_ld,
const Dtype* bias,
Dtype* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation) {
using opmath_t = at::opmath_type<Dtype>;
opmath_t beta_val = 0; // bias is added in epilogue
cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
if constexpr (std::is_same_v<Dtype, double>) {
abcType = CUDA_R_64F;
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
if (at::globalContext().allowTF32CuBLAS()) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
abcType = CUDA_R_32F;
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
abcType = CUDA_R_16F;
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
abcType = CUDA_R_16BF;
}
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
} else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif
}
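  // The epilogue fuses the bias add (and optional ReLU/GELU) into the matmul
  // kernel itself, which is why beta_val stays 0 above: no separate bias or
  // activation kernel needs to run afterwards.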
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
CuBlasLtMatrixLayout Adesc(abcType, m, k, mat1_ld, transpose_mat1);
CuBlasLtMatrixLayout Bdesc(abcType, k, n, mat2_ld, transpose_mat2);
CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);
CuBlasLtMatmulPreference preference;
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = _getWorkspaceSize();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
uint32_t d_alignment = _getAlignment(reinterpret_cast<uintptr_t>(bias));
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspaceSize);
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
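  // returnedResult is the number of algorithms the heuristic produced; zero
  // means no algorithm satisfied the preference constraints (e.g. the
  // alignment or workspace limits set above).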
if (returnedResult == 0) {
TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
}
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
&beta_val,
result_ptr,
Cdesc.descriptor(),
result_ptr,
Cdesc.descriptor(),
&heuristicResult.algo,
workspace.mutable_get(),
workspaceSize,
at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
" when calling cublasLtMatmul with transpose_mat1 ",
transpose_mat1,
" transpose_mat2 ",
transpose_mat2,
" m ",
m,
" n ",
n,
" k ",
k,
" mat1_ld ",
mat1_ld,
" mat2_ld ",
mat2_ld,
" result_ld ",
result_ld,
" abcType ",
abcType,
" computeType ",
computeType,
" scaleType ",
scaleType);
}
template void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
at::opmath_type<double> alpha_val,
const double* mat1_ptr,
int64_t mat1_ld,
const double* mat2_ptr,
int64_t mat2_ld,
const double* bias,
double* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
template void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
at::opmath_type<float> alpha_val,
const float* mat1_ptr,
int64_t mat1_ld,
const float* mat2_ptr,
int64_t mat2_ld,
const float* bias,
float* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
template void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
at::opmath_type<at::Half> alpha_val,
const at::Half* mat1_ptr,
int64_t mat1_ld,
const at::Half* mat2_ptr,
int64_t mat2_ld,
const at::Half* bias,
at::Half* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
template void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
at::opmath_type<at::BFloat16> alpha_val,
const at::BFloat16* mat1_ptr,
int64_t mat1_ld,
const at::BFloat16* mat2_ptr,
int64_t mat2_ld,
const at::BFloat16* bias,
at::BFloat16* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
void scaled_gemm(
char transa,
char transb,
int64_t m,
int64_t n,
int64_t k,
const void* mat1_ptr,
const void* mat1_scale_ptr,
int64_t mat1_ld,
ScalarType mat1_dtype,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,
const void *result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
void* amax_ptr,
bool use_fast_accum) {
#if CUDA_VERSION >= 11080
const auto computeType = CUBLAS_COMPUTE_32F;
const auto scaleType = CUDA_R_32F;
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
CuBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't');
CuBlasLtMatrixLayout Bdesc(ScalarTypeToCudaDataType(mat2_dtype), k, n, mat2_ld, transb == 't');
CuBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(bias_dtype), m, n, result_ld);
CuBlasLtMatrixLayout Ddesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
if (bias_ptr) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
size_t workspaceSize = _getWorkspaceSize();
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspaceSize);
CuBlasLtMatmulPreference preference;
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Ddesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
if (returnedResult == 0) {
TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
}
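  // cublasLtMatmul requires alpha/beta to be of the scale type (fp32 here);
  // the per-tensor scaling is applied through the A/B/D scale pointers set
  // on the descriptor above.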
float alpha_val = 1.0;
float beta_val = 0.0;
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
&beta_val,
nullptr,
Cdesc.descriptor(),
result_ptr,
Ddesc.descriptor(),
&heuristicResult.algo,
workspace.mutable_get(),
workspaceSize,
at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
" when calling cublasLtMatmul with transpose_mat1 ",
transa,
" transpose_mat2 ",
transb,
" m ",
m,
" n ",
n,
" k ",
k,
" mat1_ld ",
mat1_ld,
" mat2_ld ",
mat2_ld,
" result_ld ",
result_ld,
" computeType ",
computeType,
" scaleType ",
scaleType);
return;
#endif // CUDA_VERSION >= 11080
TORCH_CHECK(false, "scaled_gemm is only supported for CUDA 11.8 and above");
}
void int8_gemm(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
const int8_t* mat1_ptr,
int64_t mat1_ld,
const int8_t* mat2_ptr,
int64_t mat2_ld,
int32_t* result_ptr,
int64_t result_ld) {
cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
cudaDataType_t scaleType = CUDA_R_32I;
cudaDataType_t abType = CUDA_R_8I;
cudaDataType_t cType = CUDA_R_32I;
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);
cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  // Per the cuBLAS team: alpha and beta need to be the same dtype as scaleType
at::opmath_type<int32_t> alpha_val = 1;
int32_t beta_val = 0;
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
&beta_val,
result_ptr,
Cdesc.descriptor(),
result_ptr,
Cdesc.descriptor(),
nullptr, // Heuristics don't seem to work for int8
nullptr, // Non-zero workspace doesn't seem to work.
0,
at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
" when calling cublasLtMatmul with transpose_mat1 ",
transpose_mat1,
" transpose_mat2 ",
transpose_mat2,
" m ",
m,
" n ",
n,
" k ",
k,
" mat1_ld ",
mat1_ld,
" mat2_ld ",
mat2_ld,
" result_ld ",
result_ld,
" abType ",
abType,
" cType ",
cType,
" computeType ",
computeType,
" scaleType ",
scaleType);
}
#endif // !defined(USE_ROCM) && !defined(_MSC_VER)
// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
#if defined(USE_ROCM) && ROCM_VERSION < 50600
#define ROCM_CONST_BUG
#else
#define ROCM_CONST_BUG const
#endif
template <>
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
TORCH_CUDABLAS_CHECK(cublasStrsm(
handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb));
}
template <>
void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double)) {
TORCH_CUDABLAS_CHECK(cublasDtrsm(
handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb));
}
template <>
void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>)) {
TORCH_CUDABLAS_CHECK(cublasCtrsm(
handle,
side,
uplo,
trans,
diag,
m,
n,
reinterpret_cast<const cuComplex*>(alpha),
reinterpret_cast<ROCM_CONST_BUG cuComplex*>(A),
lda,
reinterpret_cast<cuComplex*>(B),
ldb));
}
template <>
void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>)) {
TORCH_CUDABLAS_CHECK(cublasZtrsm(
handle,
side,
uplo,
trans,
diag,
m,
n,
reinterpret_cast<const cuDoubleComplex*>(alpha),
reinterpret_cast<ROCM_CONST_BUG cuDoubleComplex*>(A),
lda,
reinterpret_cast<cuDoubleComplex*>(B),
ldb));
}
template <>
void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float)) {
TORCH_CUDABLAS_CHECK(cublasStrsmBatched(
handle,
side,
uplo,
trans,
diag,
m,
n,
alpha,
A,
lda,
B,
ldb,
batchCount));
}
template <>
void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double)) {
TORCH_CUDABLAS_CHECK(cublasDtrsmBatched(
handle,
side,
uplo,
trans,
diag,
m,
n,
alpha,
A,
lda,
B,
ldb,
batchCount));
}
template <>
void trsmBatched<c10::complex<float>>(
CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>)) {
TORCH_CUDABLAS_CHECK(cublasCtrsmBatched(
handle,
side,
uplo,
trans,
diag,
m,
n,
reinterpret_cast<const cuComplex*>(alpha),
reinterpret_cast<cuComplex**>(A),
lda,
reinterpret_cast<cuComplex**>(B),
ldb,
batchCount));
}
template <>
void trsmBatched<c10::complex<double>>(
CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>)) {
TORCH_CUDABLAS_CHECK(cublasZtrsmBatched(
handle,
side,
uplo,
trans,
diag,
m,
n,
reinterpret_cast<const cuDoubleComplex*>(alpha),
reinterpret_cast<cuDoubleComplex**>(A),
lda,
reinterpret_cast<cuDoubleComplex**>(B),
ldb,
batchCount));
}
/* LEVEL 2 BLAS FUNCTIONS */
#define GEMV_CHECK_ARGVALUES(Dtype) \
do { \
CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, m); \
CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, n); \
CUDABLAS_POSINT_CHECK(gemv<Dtype>, lda); \
CUDABLAS_POSINT_CHECK(gemv<Dtype>, incx); \
CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \
} while (0)
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
GEMV_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(
cublasZgemv(handle, op, m, n, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(y), incy));
}
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
  // gemv is bandwidth bound, and does not benefit from TF32. But the precision
  // loss still happens with TF32, so we disable it here.
NoTF32Guard disable_tf32;
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
GEMV_CHECK_ARGVALUES(c10::complex<float>);
TORCH_CUDABLAS_CHECK(
cublasCgemv(handle, op, m, n, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(y), incy));
}
template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
GEMV_CHECK_ARGVALUES(double);
TORCH_CUDABLAS_CHECK(
cublasDgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy));
}
template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
  // gemv is bandwidth bound, and does not benefit from TF32. But the precision
  // loss still happens with TF32, so we disable it here.
NoTF32Guard disable_tf32;
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
GEMV_CHECK_ARGVALUES(float);
TORCH_CUDABLAS_CHECK(
cublasSgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy));
}
template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half)) {
// In general, cublas regards matrices as column-major.
// The cublasS/Dgemv usages in cuda::blas::gemv<float>/<double> above
// require that external blas::gemv callers obey the following convention:
//
// If "a" is row-major with shape (output, summed) in blas::gemv's caller,
// caller interprets it as column-major with shape (summed, output), passes
// summed and output respectively to our local vars m, n, and requests that cublas
// internally transpose ("trans") the column-major interpretation of a.
//
// There's no such thing as "cublasHalfgemv", so here we hack gemv with a gemm.
// However, we must allow the same calling convention, because the caller shouldn't
// have to swap args based on whether it's calling blas::gemv<at::Half> or <float>.
bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N);
if (trans_bool) {
std::swap(m, n);
}
// After swap, local vars m, n contain the output and summed sizes respectively,
// regardless of whether "a" was row-major or column-major in gemv<>'s caller.
  // To handle the possibility that incy > 1, we interpret vector y as a
  // column-major matrix with one row (shape (1, output)) and leading dim incy.
// trans(a)*x would compute a matrix with one column (shape (output, 1)) which wouldn't match y.
// So instead, we interpret x similarly to y, as a column-major matrix with one row
// (shape (1, summed)) and leading dim incx. The gemm then carries out x*transpose(trans(a)) to
// produce a matrix with one row (shape (1, output)), matching y.
char trans_flipped = (trans_bool ? 'n' : 't');
gemm<at::Half>(
'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)) {
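  // Same gemv-as-gemm trick as gemv<at::Half> above; see the comments there.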
bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N);
if (trans_bool) {
std::swap(m, n);
}
char trans_flipped = (trans_bool ? 'n' : 't');
gemm<at::BFloat16>(
'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}
/* LEVEL 1 BLAS FUNCTIONS */
template <>
void dot<double>(CUDABLAS_DOT_ARGTYPES(double)) {
TORCH_CUDABLAS_CHECK(cublasDdot(handle, n, x, incx, y, incy, result));
}
template <>
void dot<float>(CUDABLAS_DOT_ARGTYPES(float)) {
TORCH_CUDABLAS_CHECK(cublasSdot(handle, n, x, incx, y, incy, result));
}
template <>
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
TORCH_CUDABLAS_CHECK(cublasZdotu(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
reinterpret_cast<cuDoubleComplex*>(result)));
}
template <>
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
TORCH_CUDABLAS_CHECK(cublasCdotu(handle, n, reinterpret_cast<const cuComplex*>(x),
incx, reinterpret_cast<const cuComplex*>(y), incy,
reinterpret_cast<cuComplex*>(result)));
}
template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
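  // cublasDotEx with execution type CUDA_R_32F: inputs and result are fp16,
  // but the accumulation happens in fp32 for accuracy.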
TORCH_CUDABLAS_CHECK(cublasDotEx(
handle,
n,
x,
CUDA_R_16F,
incx,
y,
CUDA_R_16F,
incy,
result,
CUDA_R_16F,
CUDA_R_32F));
}
template <>
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
TORCH_CUDABLAS_CHECK(cublasDotEx(
handle,
n,
x,
CUDA_R_16BF,
incx,
y,
CUDA_R_16BF,
incy,
result,
CUDA_R_16BF,
CUDA_R_32F));
}
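// vdot computes the conjugated dot product, sum_i conj(x_i) * y_i, which is
// what the "c" in cdotc/zdotc stands for.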
template <>
void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
TORCH_CUDABLAS_CHECK(cublasCdotc(handle, n, reinterpret_cast<const cuComplex*>(x),
incx, reinterpret_cast<const cuComplex*>(y), incy,
reinterpret_cast<cuComplex*>(result)));
}
template <>
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
TORCH_CUDABLAS_CHECK(cublasZdotc(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
reinterpret_cast<cuDoubleComplex*>(result)));
}
template <>
void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float)) {
TORCH_CUDABLAS_CHECK(cublasSgetrsBatched(
handle,
trans,
n,
nrhs,
dA_array,
lda,
ipiv_array,
dB_array,
ldb,
info_array,
batchsize));
}
template <>
void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double)) {
TORCH_CUDABLAS_CHECK(cublasDgetrsBatched(
handle,
trans,
n,
nrhs,
dA_array,
lda,
ipiv_array,
dB_array,
ldb,
info_array,
batchsize));
}
template <>
void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>)) {
TORCH_CUDABLAS_CHECK(cublasCgetrsBatched(
handle,
trans,
n,
nrhs,
reinterpret_cast<cuComplex**>(dA_array),
lda,
ipiv_array,
reinterpret_cast<cuComplex**>(dB_array),
ldb,
info_array,
batchsize));
}
template <>
void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>)) {
TORCH_CUDABLAS_CHECK(cublasZgetrsBatched(
handle,
trans,
n,
nrhs,
reinterpret_cast<cuDoubleComplex**>(dA_array),
lda,
ipiv_array,
reinterpret_cast<cuDoubleComplex**>(dB_array),
ldb,
info_array,
batchsize));
}
template <>
void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)) {
TORCH_CUDABLAS_CHECK(cublasSgeqrfBatched(
handle, m, n, A_array, lda, tau_array, info, batchsize));
}
template <>
void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double)) {
TORCH_CUDABLAS_CHECK(cublasDgeqrfBatched(
handle, m, n, A_array, lda, tau_array, info, batchsize));
}
template <>
void geqrfBatched<c10::complex<float>>(
CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>)) {
TORCH_CUDABLAS_CHECK(cublasCgeqrfBatched(
handle,
m,
n,
reinterpret_cast<cuComplex**>(A_array),
lda,
reinterpret_cast<cuComplex**>(tau_array),
info,
batchsize));
}
template <>
void geqrfBatched<c10::complex<double>>(
CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>)) {
TORCH_CUDABLAS_CHECK(cublasZgeqrfBatched(
handle,
m,
n,
reinterpret_cast<cuDoubleComplex**>(A_array),
lda,
reinterpret_cast<cuDoubleComplex**>(tau_array),
info,
batchsize));
}
template <>
void getrfBatched<double>(
int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasDgetrfBatched(
handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}
template <>
void getrfBatched<float>(
int n, float** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasSgetrfBatched(
handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}
template <>
void getrfBatched<c10::complex<double>>(
int n,
c10::complex<double>** dA_array,
int ldda,
int* ipiv_array,
int* info_array,
int batchsize) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasZgetrfBatched(
handle,
n,
reinterpret_cast<cuDoubleComplex**>(dA_array),
ldda,
ipiv_array,
info_array,
batchsize));
}
template <>
void getrfBatched<c10::complex<float>>(
int n,
c10::complex<float>** dA_array,
int ldda,
int* ipiv_array,
int* info_array,
int batchsize) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasCgetrfBatched(
handle,
n,
reinterpret_cast<cuComplex**>(dA_array),
ldda,
ipiv_array,
info_array,
batchsize));
}
template <>
void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double)) {
TORCH_CUDABLAS_CHECK(cublasDgelsBatched(
handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize));
}
template <>
void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float)) {
TORCH_CUDABLAS_CHECK(cublasSgelsBatched(
handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize));
}
template <>
void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>)) {
TORCH_CUDABLAS_CHECK(cublasZgelsBatched(
handle, trans,
m, n, nrhs,
reinterpret_cast<cuDoubleComplex**>(dA_array),
ldda,
reinterpret_cast<cuDoubleComplex**>(dC_array),
lddc,
info,
devInfoArray,
batchSize));
}
template <>
void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>)) {
TORCH_CUDABLAS_CHECK(cublasCgelsBatched(
handle, trans,
m, n, nrhs,
reinterpret_cast<cuComplex**>(dA_array),
ldda,
reinterpret_cast<cuComplex**>(dC_array),
lddc,
info,
devInfoArray,
batchSize));
}
} // namespace at::cuda::blas