blob: 2ea3157e06b1cc84ca753bc85185accb0aabc251 [file] [log] [blame]
#include "THCBlas.h"
#include "THCGeneral.h"
#include "THCHalf.h"
/*
 * Dot product of two single-precision device vectors, computed by cublasSdot
 * on the BLAS handle and stream currently selected in `state`.
 *
 * n, incx and incy arrive as long but cuBLAS takes int, so each must be
 * <= INT_MAX; otherwise THError is raised. For n == 1 both strides are
 * replaced by 1 before that check (a single element is read either way).
 *
 * Returns the scalar cuBLAS wrote, or 0 if THError ever returns.
 */
float THCudaBlas_Sdot(THCState *state, long n, float *x, long incx, float *y, long incy)
{
  if (n == 1) {
    incx = 1;
    incy = 1;
  }

  if (n > INT_MAX || incx > INT_MAX || incy > INT_MAX) {
    THError("Cublas_Sdot only supports n, incx and incy "
            "up to signed integer limits: %d", INT_MAX);
    return 0;
  }

  cublasHandle_t blas = THCState_getCurrentBlasHandle(state);
  cublasSetStream(blas, THCState_getCurrentStream(state));

  float dot;
  THCublasCheck(cublasSdot(blas, (int)n, x, (int)incx, y, (int)incy, &dot));
  return dot;
}
/*
 * Double-precision counterpart of THCudaBlas_Sdot: returns x . y as computed
 * by cublasDdot on the current BLAS handle/stream of `state`.
 *
 * The long-typed length and strides are narrowed to int for cuBLAS and must
 * each be <= INT_MAX, otherwise THError is raised. When n == 1 the strides
 * are normalised to 1 before the range check.
 */
double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y, long incy)
{
  if (n == 1) {
    incx = 1;
    incy = 1;
  }

  int fitsInInt = (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX);
  if (!fitsInInt) {
    THError("Cublas_Ddot only supports n, incx and incy "
            "up to signed integer limits: %d", INT_MAX);
    return 0;
  }

  cublasHandle_t blas = THCState_getCurrentBlasHandle(state);
  cublasSetStream(blas, THCState_getCurrentStream(state));

  double dot;
  THCublasCheck(cublasDdot(blas, (int)n, x, (int)incx, y, (int)incy, &dot));
  return dot;
}
/* Level 2 */
/*
 * Single-precision matrix-vector product via cublasSgemv
 * (per the cuBLAS docs: y = alpha * op(A) * x + beta * y).
 *
 * trans selects op(A): 't'/'T' transpose, 'n'/'N' none, 'c'/'C' conjugate
 * transpose; any other value raises THError.
 * m, n, lda, incx and incy are narrowed from long to int and must fit;
 * lda, incx and incy must additionally be positive. Violations raise THError.
 * Runs on the current BLAS handle and stream of `state`.
 */
void THCudaBlas_Sgemv(THCState *state, char trans, long m, long n, float alpha, float *a, long lda, float *x, long incx, float beta, float *y, long incy)
{
  /* With at most one column, lda is never used to step between columns, but
     cuBLAS still validates lda >= max(1, m). Replace whatever the caller
     passed (possibly 0 for an empty matrix) with a value it accepts. */
  if (n <= 1)
    lda = (m > 1) ? m : 1;

  cublasOperation_t op;
  switch (trans) {
    case 't': case 'T': op = CUBLAS_OP_T; break;
    case 'n': case 'N': op = CUBLAS_OP_N; break;
    case 'c': case 'C': op = CUBLAS_OP_C; break;
    default:
      /* Never hand an uninitialised operation code to cuBLAS. */
      THError("Cublas_Sgemv parameter trans should be 't', 'n' or 'c'.");
      return;
  }

  if( (m <= INT_MAX) && (n <= INT_MAX) &&
      (lda > 0) && (lda <= INT_MAX) &&
      (incx > 0) && (incx <= INT_MAX) &&
      (incy > 0) && (incy <= INT_MAX) )
  {
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasSgemv(handle, op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
    return;
  }
  THError("Cublas_Sgemv only supports m, n, lda, incx, incy"
          " in the range 0 < [val] <= %d", INT_MAX);
}
/*
 * Double-precision matrix-vector product via cublasDgemv
 * (per the cuBLAS docs: y = alpha * op(A) * x + beta * y).
 *
 * trans selects op(A): 't'/'T' transpose, 'n'/'N' none, 'c'/'C' conjugate
 * transpose; any other value raises THError.
 * m, n, lda, incx and incy are narrowed from long to int and must fit;
 * lda, incx and incy must additionally be positive. Violations raise THError.
 * Runs on the current BLAS handle and stream of `state`.
 */
void THCudaBlas_Dgemv(THCState *state, char trans, long m, long n, double alpha, double *a, long lda, double *x, long incx, double beta, double *y, long incy)
{
  /* With at most one column, lda is never used to step between columns, but
     cuBLAS still validates lda >= max(1, m). Replace whatever the caller
     passed (possibly 0 for an empty matrix) with a value it accepts. */
  if (n <= 1)
    lda = (m > 1) ? m : 1;

  cublasOperation_t op;
  switch (trans) {
    case 't': case 'T': op = CUBLAS_OP_T; break;
    case 'n': case 'N': op = CUBLAS_OP_N; break;
    case 'c': case 'C': op = CUBLAS_OP_C; break;
    default:
      /* Never hand an uninitialised operation code to cuBLAS. */
      THError("Cublas_Dgemv parameter trans should be 't', 'n' or 'c'.");
      return;
  }

  if( (m <= INT_MAX) && (n <= INT_MAX) &&
      (lda > 0) && (lda <= INT_MAX) &&
      (incx > 0) && (incx <= INT_MAX) &&
      (incy > 0) && (incy <= INT_MAX) )
  {
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasDgemv(handle, op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
    return;
  }
  THError("Cublas_Dgemv only supports m, n, lda, incx, incy"
          " in the range 0 < [val] <= %d", INT_MAX);
}
/*
 * Single-precision rank-1 update via cublasSger
 * (per the cuBLAS docs: A = alpha * x * y^T + A, A is m x n).
 *
 * m, n, lda, incx and incy are narrowed from long to int; any value above
 * INT_MAX raises THError. Runs on the current BLAS handle/stream of `state`.
 */
void THCudaBlas_Sger(THCState *state, long m, long n, float alpha, float *x, long incx, float *y, long incy, float *a, long lda)
{
  /* With at most one column, lda never steps between columns, but cuBLAS
     still validates lda >= max(1, m); normalise degenerate/empty cases. */
  if (n <= 1)
    lda = (m > 1) ? m : 1;
  if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
  {
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasSger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
    return;
  }
  THError("Cublas_Sger only supports m, n, lda, incx, incy"
          " with the bound [val] <= %d", INT_MAX);
}
/*
 * Double-precision rank-1 update via cublasDger
 * (per the cuBLAS docs: A = alpha * x * y^T + A, A is m x n).
 *
 * m, n, lda, incx and incy are narrowed from long to int; any value above
 * INT_MAX raises THError. Runs on the current BLAS handle/stream of `state`.
 */
void THCudaBlas_Dger(THCState *state, long m, long n, double alpha, double *x, long incx, double *y, long incy, double *a, long lda)
{
  /* With at most one column, lda never steps between columns, but cuBLAS
     still validates lda >= max(1, m); normalise degenerate/empty cases. */
  if (n <= 1)
    lda = (m > 1) ? m : 1;
  if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
  {
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasDger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
    return;
  }
  THError("Cublas_Dger only supports m, n, lda, incx, incy"
          " with the bound [val] <= %d", INT_MAX);
}
/*
 * Map a BLAS-style transpose character to a cublasOperation_t.
 *
 * Accepts 't'/'T' (transpose), 'n'/'N' (no transpose) and 'c'/'C'
 * (conjugate transpose). Upper case is accepted so that this agrees with
 * adjustLd, which already treats 'T' as a transpose. Anything else raises
 * THError; CUBLAS_OP_T is returned only so the function has a value should
 * THError return.
 */
cublasOperation_t convertTransToCublasOperation(char trans) {
  switch (trans) {
    case 't': case 'T': return CUBLAS_OP_T;
    case 'n': case 'N': return CUBLAS_OP_N;
    case 'c': case 'C': return CUBLAS_OP_C;
    default:
      THError("trans must be one of: t, n, c");
      return CUBLAS_OP_T;
  }
}
/*
 * Normalise the leading dimensions of a GEMM-style call (C = op(A) op(B),
 * op(A) m x k, op(B) k x n, C m x n; column-major) for degenerate operands.
 *
 * When a stored matrix has at most one column, its leading dimension is never
 * used to step between columns. Callers may therefore pass values (e.g. 0 for
 * an empty tensor) that cuBLAS's "ld >= max(1, rows)" validation rejects.
 * Such values are overwritten with max(rows, 1), where rows is the stored
 * row count of each operand:
 *   C            (m x n): rows = m, replaced when n <= 1
 *   A, op N      (m x k): rows = m, replaced when k <= 1
 *   A, op T / C  (k x m): rows = k, replaced when m <= 1
 *   B, op N      (k x n): rows = k, replaced when n <= 1
 *   B, op T / C  (n x k): rows = n, replaced when k <= 1
 *
 * transa/transb: 't'/'T' and 'c'/'C' mean the operand is stored transposed
 * (conjugate transpose has the same storage layout as transpose); anything
 * else is treated as no-transpose.
 * lda/ldb/ldc: in/out pointers, modified in place.
 */
void adjustLd(char transa, char transb, long m, long n, long k, long *lda, long *ldb, long *ldc)
{
  int transa_ = (transa == 't') || (transa == 'T') ||
                (transa == 'c') || (transa == 'C');
  int transb_ = (transb == 't') || (transb == 'T') ||
                (transb == 'c') || (transb == 'C');

  if (n <= 1)
    *ldc = (m > 1) ? m : 1;

  if (transa_)
  {
    if (m <= 1)
      *lda = (k > 1) ? k : 1;
  }
  else
  {
    if (k <= 1)
      *lda = (m > 1) ? m : 1;
  }

  if (transb_)
  {
    if (k <= 1)
      *ldb = (n > 1) ? n : 1;
  }
  else
  {
    if (n <= 1)
      *ldb = (k > 1) ? k : 1;
  }
}
/* Level 3 */
/*
 * Single-precision matrix multiply via cublasSgemm
 * (per the cuBLAS docs: C = alpha * op(A) * op(B) + beta * C).
 *
 * Leading dimensions are first normalised by adjustLd, and transa/transb are
 * mapped with convertTransToCublasOperation. All sizes are then narrowed from
 * long to int; any value above INT_MAX raises THError.
 * Runs on the current BLAS handle and stream of `state`.
 */
void THCudaBlas_Sgemm(THCState *state, char transa, char transb, long m, long n, long k, float alpha, float *a, long lda, float *b, long ldb, float beta, float *c, long ldc)
{
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);
  if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
  {
    int i_m = (int)m;
    int i_n = (int)n;
    int i_k = (int)k;
    int i_lda = (int)lda;
    int i_ldb = (int)ldb;
    int i_ldc = (int)ldc;
    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasSgemm(handle, opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
    return;
  }
  THError("Cublas_Sgemm only supports m, n, k, lda, ldb, ldc"
          " with the bound [val] <= %d", INT_MAX);
}
#ifdef CUDA_HALF_TENSOR
// CUDA 8.0 renamed the half-precision data-type constant taken by
// cublasSgemmEx; alias the new name (CUDA_R_16F) on older toolkits.
#if CUDA_VERSION < 8000
# define CUDA_R_16F CUBLAS_DATA_HALF
#endif
/*
 * Half-precision matrix multiply (C = alpha * op(A) * op(B) + beta * C).
 *
 * If THC_nativeHalfInstructions(state) reports support, cublasHgemm is used
 * directly. Otherwise cublasSgemmEx is called with half-typed A/B/C storage
 * and alpha/beta converted to float.
 * Leading dimensions are normalised by adjustLd; all sizes are narrowed from
 * long to int and any value above INT_MAX raises THError.
 */
void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n, long k, half alpha, half *a, long lda, half *b, long ldb, half beta, half *c, long ldc)
{
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);
  if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
  {
    int i_m = (int)m;
    int i_n = (int)n;
    int i_k = (int)k;
    int i_lda = (int)lda;
    int i_ldb = (int)ldb;
    int i_ldc = (int)ldc;
    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    // Check for native Hgemm support
    if (THC_nativeHalfInstructions(state)) {
      THCublasCheck(cublasHgemm(handle, opa, opb,
                                i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb,
                                &beta, c, i_ldc));
    } else {
      // Simulated Hgemm: half storage, float scalars
      float fAlpha = THC_half2float(alpha);
      float fBeta = THC_half2float(beta);
      THCublasCheck(cublasSgemmEx(handle, opa, opb,
                                  i_m, i_n, i_k, &fAlpha,
                                  a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
                                  i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
    }
    return;
  }
  THError("Cublas_Hgemm only supports m, n, k, lda, ldb, ldc"
          " with the bound [val] <= %d", INT_MAX);
}
#endif
/*
 * Double-precision matrix multiply via cublasDgemm
 * (per the cuBLAS docs: C = alpha * op(A) * op(B) + beta * C).
 *
 * Leading dimensions are first normalised by adjustLd, and transa/transb are
 * mapped with convertTransToCublasOperation. All sizes are then narrowed from
 * long to int; any value above INT_MAX raises THError.
 * Runs on the current BLAS handle and stream of `state`.
 */
void THCudaBlas_Dgemm(THCState *state, char transa, char transb, long m, long n, long k, double alpha, double *a, long lda, double *b, long ldb, double beta, double *c, long ldc)
{
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);
  if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
  {
    int i_m = (int)m;
    int i_n = (int)n;
    int i_k = (int)k;
    int i_lda = (int)lda;
    int i_ldb = (int)ldb;
    int i_ldc = (int)ldc;
    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasDgemm(handle, opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
    return;
  }
  THError("Cublas_Dgemm only supports m, n, k, lda, ldb, ldc"
          " with the bound [val] <= %d", INT_MAX);
}
/*
 * Batched single-precision matrix multiply via cublasSgemmBatched: for each
 * of batchCount pointer triples (a[i], b[i], c[i]) computes
 * c[i] = alpha * op(a[i]) * op(b[i]) + beta * c[i].
 *
 * Leading dimensions are normalised by adjustLd before validation, matching
 * THCudaBlas_Sgemm, so the bound is checked on the values actually passed.
 * All sizes are narrowed from long to int; any value above INT_MAX raises
 * THError and nothing is launched.
 */
void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
                             float alpha, const float *a[], long lda, const float *b[], long ldb,
                             float beta, float *c[], long ldc, long batchCount)
{
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  if( (m > INT_MAX) || (n > INT_MAX) || (k > INT_MAX) || (lda > INT_MAX) || (ldb > INT_MAX) || (ldc > INT_MAX) || (batchCount > INT_MAX) )
  {
    THError("Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
            " with the bound [val] <= %d", INT_MAX);
    return;
  }
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasSgemmBatched(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
                                   (int)batchCount));
}
/*
 * Batched double-precision matrix multiply via cublasDgemmBatched: for each
 * of batchCount pointer triples (a[i], b[i], c[i]) computes
 * c[i] = alpha * op(a[i]) * op(b[i]) + beta * c[i].
 *
 * Leading dimensions are normalised by adjustLd before validation, matching
 * THCudaBlas_Dgemm, so the bound is checked on the values actually passed.
 * All sizes are narrowed from long to int; any value above INT_MAX raises
 * THError and nothing is launched.
 */
void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
                             double alpha, const double *a[], long lda, const double *b[], long ldb,
                             double beta, double *c[], long ldc, long batchCount)
{
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  if( (m > INT_MAX) || (n > INT_MAX) || (k > INT_MAX) || (lda > INT_MAX) || (ldb > INT_MAX) || (ldc > INT_MAX) || (batchCount > INT_MAX) )
  {
    THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
            " with the bound [val] <= %d", INT_MAX);
    return;
  }
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgemmBatched(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
                                   (int)batchCount));
}
/* Inverse */
/*
 * Batched single-precision LU factorisation via cublasSgetrfBatched on the
 * current BLAS handle/stream of `state`. pivot and info are passed straight
 * through to cuBLAS.
 *
 * The parameters are already int, so the guard below can only reject the
 * value INT_MAX itself; the error message states that bound accurately.
 */
void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) {
  if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )
  {
    THError("Cublas_Sgetrf only supports n, lda, batchSize"
            " with the bound [val] < %d", INT_MAX);
  }
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasSgetrfBatched(handle, n, a, lda, pivot, info, batchSize));
}
/*
 * Batched double-precision LU factorisation via cublasDgetrfBatched on the
 * current BLAS handle/stream of `state`. pivot and info are passed straight
 * through to cuBLAS.
 *
 * The parameters are already int, so the guard below can only reject the
 * value INT_MAX itself; the error message states that bound accurately.
 */
void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize) {
  if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )
  {
    THError("Cublas_Dgetrf only supports n, lda, batchSize"
            " with the bound [val] < %d", INT_MAX);
  }
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgetrfBatched(handle, n, a, lda, pivot, info, batchSize));
}
/*
 * Batched single-precision matrix inverse from LU factors via
 * cublasSgetriBatched: reads factors from a (with pivot) and writes results
 * to c. Runs on the current BLAS handle/stream of `state`.
 *
 * The parameters are already int, so the guard below can only reject the
 * value INT_MAX itself; the error message states that bound accurately.
 */
void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize) {
  if( (n >= INT_MAX) || (lda >= INT_MAX) || (ldc >= INT_MAX) || (batchSize >= INT_MAX) )
  {
    THError("Cublas_Sgetri only supports n, lda, ldc, batchSize"
            " with the bound [val] < %d", INT_MAX);
  }
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasSgetriBatched(handle, n, a, lda, pivot, c, ldc, info, batchSize));
}
/*
 * Batched double-precision matrix inverse from LU factors via
 * cublasDgetriBatched: reads factors from a (with pivot) and writes results
 * to c. Runs on the current BLAS handle/stream of `state`.
 *
 * The parameters are already int, so the guard below can only reject the
 * value INT_MAX itself; the error message states that bound accurately.
 */
void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *pivot, double **c, int ldc, int *info, int batchSize) {
  if( (n >= INT_MAX) || (lda >= INT_MAX) || (ldc >= INT_MAX) || (batchSize >= INT_MAX) )
  {
    THError("Cublas_Dgetri only supports n, lda, ldc, batchSize"
            " with the bound [val] < %d", INT_MAX);
  }
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgetriBatched(handle, n, a, lda, pivot, c, ldc, info, batchSize));
}