added THCudaBlas, now handles a bunch of corner cases. Also fixes #55 and another bug in addr
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 23bb46c..ac01d16 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -15,7 +15,7 @@
INSTALL(FILES
THC.h
THCGeneral.h
- THCGeneral.h
+ THCBlas.h
THCStorage.h
THCTensor.h
THCTensorRandom.h
diff --git a/THC.cu b/THC.cu
index acc5153..7c5ff3e 100644
--- a/THC.cu
+++ b/THC.cu
@@ -1,6 +1,7 @@
/* thrust library does not allow multiple files */
+#include "THCBlas.cu"
#include "THCStorage.cu"
#include "THCTensor.cu"
#include "THCTensorCopy.cu"
diff --git a/THC.h b/THC.h
index fae5f16..9aab4d44 100644
--- a/THC.h
+++ b/THC.h
@@ -2,6 +2,7 @@
#define THC_INC
#include "THCGeneral.h"
+#include "THCBlas.h"
#include "THCStorage.h"
#include "THCTensor.h"
#include "THCTensorRandom.h"
diff --git a/THCBlas.cu b/THCBlas.cu
new file mode 100644
index 0000000..11b269f
--- /dev/null
+++ b/THCBlas.cu
@@ -0,0 +1,201 @@
+#include "THCBlas.h"
+#include "THCGeneral.h"
+
+
+/* Forwards to cublasSswap after narrowing the long arguments to int.
+ * When n == 1 both increments are forced to 1.
+ * If n, incx or incy exceeds INT_MAX, THError is raised and cuBLAS is
+ * not called. */
+void THCudaBlas_swap(long n, float *x, long incx, float *y, long incy)
+{
+  if(n == 1)
+    incx = incy = 1;
+
+  /* values above INT_MAX would not survive the casts to int below */
+  if( (n > INT_MAX) || (incx > INT_MAX) || (incy > INT_MAX) )
+  {
+    THError("Cublas_swap only supports n, incx and"
+            " incy upto signed integer limits: %d", INT_MAX);
+    return;
+  }
+
+  cublasSswap((int)n, x, (int)incx, y, (int)incy);
+  THCublasCheck();
+}
+
+/* Forwards to cublasSscal with n and incx narrowed to int; incx is forced
+ * to 1 when n == 1. Raises THError if n or incx exceeds INT_MAX. */
+void THCudaBlas_scal(long n, float a, float *x, long incx)
+{
+  const long stride = (n == 1) ? 1 : incx;
+
+  if( (n > INT_MAX) || (stride > INT_MAX) )
+  {
+    THError("Cublas_scal only supports n and incx "
+            "upto signed integer limits: %d", INT_MAX);
+    return;
+  }
+
+  cublasSscal((int)n, a, x, (int)stride);
+  THCublasCheck();
+}
+
+/* Forwards to cublasScopy after narrowing n, incx and incy to int.
+ * Both increments are forced to 1 when n == 1; THError is raised if any
+ * of the three values exceeds INT_MAX. */
+void THCudaBlas_copy(long n, float *x, long incx, float *y, long incy)
+{
+  int fits;
+
+  if(n == 1)
+    incx = incy = 1;
+
+  fits = (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX);
+  if(!fits)
+  {
+    THError("Cublas_copy only supports n, incx and incy "
+            "upto signed integer limits: %d", INT_MAX);
+    return;
+  }
+
+  cublasScopy((int)n, x, (int)incx, y, (int)incy);
+  THCublasCheck();
+}
+
+/* Forwards to cublasSaxpy with n, incx and incy narrowed to int.
+ *   n          element count; when it is 1 both increments become 1
+ *   a          scalar passed through unchanged
+ *   x, y       vector pointers passed through unchanged
+ *   incx, incy increments for x and y
+ * Raises THError, without calling cuBLAS, if n or an increment exceeds
+ * INT_MAX. */
+void THCudaBlas_axpy(long n, float a, float *x, long incx, float *y, long incy)
+{
+  const long step_x = (n == 1) ? 1 : incx;
+  const long step_y = (n == 1) ? 1 : incy;
+
+  if( (n > INT_MAX) || (step_x > INT_MAX) || (step_y > INT_MAX) )
+  {
+    THError("Cublas_axpy only supports n, incx and incy "
+            "upto signed integer limits: %d", INT_MAX);
+    return;
+  }
+
+  cublasSaxpy((int)n, a, x, (int)step_x, y, (int)step_y);
+  THCublasCheck();
+}
+/* Returns the value produced by cublasSdot for x and y, with n, incx and
+ * incy narrowed to int. Both increments are forced to 1 when n == 1.
+ * If n or an increment exceeds INT_MAX, THError is raised and -1 is
+ * returned should THError return. */
+float THCudaBlas_dot(long n, float *x, long incx, float *y, long incy)
+{
+  float dot_value;
+
+  if(n == 1)
+    incx = incy = 1;
+
+  if( (n > INT_MAX) || (incx > INT_MAX) || (incy > INT_MAX) )
+  {
+    THError("Cublas_dot only supports n, incx and incy "
+            "upto signed integer limits: %d", INT_MAX);
+    return -1;
+  }
+
+  dot_value = cublasSdot((int)n, x, (int)incx, y, (int)incy);
+  THCublasCheck();
+  return dot_value;
+}
+
+/* Level 2 */
+/* Forwards to cublasSgemv with the size arguments narrowed to int.
+ * lda is replaced by m when n == 1. m and n must not exceed INT_MAX;
+ * lda, incx and incy must lie in (0, INT_MAX]. Otherwise THError is raised
+ * and cuBLAS is not called. */
+void THCudaBlas_gemv(char trans, long m, long n, float alpha, float *a, long lda, float *x, long incx, float beta, float *y, long incy)
+{
+  if(n == 1)
+    lda = m;
+
+  if( (m <= INT_MAX) && (n <= INT_MAX) &&
+      (lda > 0) && (lda <= INT_MAX) &&
+      (incx > 0) && (incx <= INT_MAX) &&
+      (incy > 0) && (incy <= INT_MAX) )
+  {
+    cublasSgemv(trans, (int)m, (int)n, alpha, a, (int)lda,
+                x, (int)incx, beta, y, (int)incy);
+    THCublasCheck();
+    return;
+  }
+  /* the leading space keeps "incy" and "in" apart after concatenation */
+  THError("Cublas_gemv only supports m, n, lda, incx, incy"
+          " in the range 0 < [val] <= %d", INT_MAX);
+}
+/* Forwards to cublasSger with the size arguments narrowed to int.
+ * lda is replaced by m when n == 1. THError is raised, and cuBLAS is not
+ * called, if m, n, lda, incx or incy exceeds INT_MAX. */
+void THCudaBlas_ger(long m, long n, float alpha, float *x, long incx, float *y, long incy, float *a, long lda)
+{
+  if(n == 1)
+    lda = m;
+
+  if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) &&
+      (incx <= INT_MAX) && (incy <= INT_MAX) )
+  {
+    cublasSger((int)m, (int)n, alpha, x, (int)incx, y, (int)incy,
+               a, (int)lda);
+    THCublasCheck();
+    return;
+  }
+  /* the leading space keeps "incy" and "with" apart after concatenation */
+  THError("Cublas_ger only supports m, n, lda, incx, incy"
+          " with the bound [val] <= %d", INT_MAX);
+}
+
+/* Level 3 */
+/* Forwards to cublasSgemm with the size arguments narrowed to int.
+ * transa / transb count as transposed when they are 't' or 'T'.
+ * Leading dimensions are rewritten for single-row/column cases:
+ *   ldc = m  when n == 1
+ *   lda = k  when a is transposed and m == 1; lda = m when it is not and k == 1
+ *   ldb = n  when b is transposed and k == 1; ldb = k when it is not and n == 1
+ * THError is raised, and cuBLAS is not called, if m, n, k, lda, ldb or ldc
+ * exceeds INT_MAX. */
+void THCudaBlas_gemm(char transa, char transb, long m, long n, long k, float alpha, float *a, long lda, float *b, long ldb, float beta, float *c, long ldc)
+{
+  int transa_ = ((transa == 't') || (transa == 'T'));
+  int transb_ = ((transb == 't') || (transb == 'T'));
+
+  if(n == 1)
+    ldc = m;
+
+  if(transa_ && (m == 1))
+    lda = k;
+  else if(!transa_ && (k == 1))
+    lda = m;
+
+  if(transb_ && (k == 1))
+    ldb = n;
+  else if(!transb_ && (n == 1))
+    ldb = k;
+
+  if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) &&
+      (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
+  {
+    int i_m = (int)m;
+    int i_n = (int)n;
+    int i_k = (int)k;
+    int i_lda = (int)lda;
+    int i_ldb = (int)ldb;
+    int i_ldc = (int)ldc;
+
+    cublasSgemm(transa, transb, i_m, i_n, i_k, alpha, a, i_lda, b, i_ldb, beta, c, i_ldc);
+    THCublasCheck();
+    return;
+  }
+
+  /* the second literal starts with a space so that "ldc" and "with" do not
+     run together once the two literals are concatenated */
+  THError("Cublas_gemm only supports m, n, k, lda, ldb, ldc"
+          " with the bound [val] <= %d", INT_MAX);
+}
diff --git a/THCBlas.h b/THCBlas.h
new file mode 100644
index 0000000..cc3b996
--- /dev/null
+++ b/THCBlas.h
@@ -0,0 +1,27 @@
+#ifndef THC_BLAS_INC
+#define THC_BLAS_INC
+/* Header for the THCudaBlas_* wrappers; built by including generic/THBlas.h. */
+#include "THCGeneral.h"
+/* Temporarily point TH_API at THC_API and set the macros the generic header reads. */
+#undef TH_API
+#define TH_API THC_API
+#define real float
+#define Real Cuda
+#define THBlas_(NAME) TH_CONCAT_4(TH,Real,Blas_,NAME)
+/* With Real = Cuda, THBlas_(NAME) is presumably THCudaBlas_NAME (TH_CONCAT_4 is defined elsewhere). */
+#define TH_GENERIC_FILE "generic/THBlas.h"
+#include "generic/THBlas.h"
+#undef TH_GENERIC_FILE
+/* Drop the temporary macros again so they do not leak to includers. */
+#undef THBlas_
+#undef real
+#undef Real
+#undef TH_API
+/* Redefine TH_API in terms of THC_EXTERNC (dllimport variant when WIN32 is defined). */
+#ifdef WIN32
+# define TH_API THC_EXTERNC __declspec(dllimport)
+#else
+# define TH_API THC_EXTERNC
+#endif
+
+#endif
diff --git a/THCGeneral.h b/THCGeneral.h
index 0ed43e6..bb4cdbf 100644
--- a/THCGeneral.h
+++ b/THCGeneral.h
@@ -6,7 +6,6 @@
#include "cuda.h"
#include "cublas.h"
-//#include "cuda_runtime_api.h"
#ifdef __cplusplus
# define THC_EXTERNC extern "C"
diff --git a/THCTensorMath.cu b/THCTensorMath.cu
index 10c65ff..08c9996 100644
--- a/THCTensorMath.cu
+++ b/THCTensorMath.cu
@@ -1,5 +1,6 @@
#include "THCTensorMath.h"
#include "THCGeneral.h"
+#include "THCBlas.h"
#include "THCTensorRandom.h"
#include <thrust/fill.h>
@@ -55,8 +56,7 @@
{
THCudaTensor *self = THCudaTensor_newContiguous(self_);
- cublasSscal(THCudaTensor_nElement(self), value, THCudaTensor_data(self), 1);
- THCublasCheck();
+ THCudaBlas_scal(THCudaTensor_nElement(self), value, THCudaTensor_data(self), 1);
THCudaTensor_freeCopyTo(self, self_);
}
@@ -65,8 +65,7 @@
{
THCudaTensor *self = THCudaTensor_newContiguous(self_);
- cublasSscal(THCudaTensor_nElement(self), 1/value, THCudaTensor_data(self), 1);
- THCublasCheck();
+ THCudaBlas_scal(THCudaTensor_nElement(self), 1/value, THCudaTensor_data(self), 1);
THCudaTensor_freeCopyTo(self, self_);
}
@@ -79,8 +78,7 @@
THCudaTensor *self = THCudaTensor_newContiguous(self_);
src = THCudaTensor_newContiguous(src);
- cublasSaxpy(THCudaTensor_nElement(self), value, THCudaTensor_data(src), 1, THCudaTensor_data(self), 1);
- THCublasCheck();
+ THCudaBlas_axpy(THCudaTensor_nElement(self), value, THCudaTensor_data(src), 1, THCudaTensor_data(self), 1);
THCudaTensor_free(src);
THCudaTensor_freeCopyTo(self, self_);
@@ -98,8 +96,7 @@
src2 = THCudaTensor_newContiguous(src2);
THCudaTensor_copy(self, src1);
- cublasSaxpy(THCudaTensor_nElement(self), value, THCudaTensor_data(src2), 1, THCudaTensor_data(self), 1);
- THCublasCheck();
+ THCudaBlas_axpy(THCudaTensor_nElement(self), value, THCudaTensor_data(src2), 1, THCudaTensor_data(self), 1);
THCudaTensor_free(src1);
THCudaTensor_free(src2);
@@ -229,12 +226,9 @@
self = THCudaTensor_newContiguous(self);
src = THCudaTensor_newContiguous(src);
- float result = cublasSdot(THCudaTensor_nElement(self),
- THCudaTensor_data(self), 1,
- THCudaTensor_data(src), 1);
-
- THCublasCheck();
-
+ float result = THCudaBlas_dot(THCudaTensor_nElement(self),
+ THCudaTensor_data(self), 1,
+ THCudaTensor_data(src), 1);
THCudaTensor_free(src);
THCudaTensor_free(self);
@@ -517,14 +511,14 @@
if(mat->stride[0] == 1)
{
- cublasSgemv('n', mat->size[0], mat->size[1],
+ THCudaBlas_gemv('n', mat->size[0], mat->size[1],
alpha, THCudaTensor_data(mat), mat->stride[1],
THCudaTensor_data(vec), vec->stride[0],
beta, THCudaTensor_data(self), self->stride[0]);
}
else if(mat->stride[1] == 1)
{
- cublasSgemv('t', mat->size[1], mat->size[0],
+ THCudaBlas_gemv('t', mat->size[1], mat->size[0],
alpha, THCudaTensor_data(mat), mat->stride[0],
THCudaTensor_data(vec), vec->stride[0],
beta, THCudaTensor_data(self), self->stride[0]);
@@ -533,7 +527,7 @@
{
mat = THCudaTensor_newContiguous(mat);
- cublasSgemv('t', mat->size[1], mat->size[0],
+ THCudaBlas_gemv('t', mat->size[1], mat->size[0],
alpha, THCudaTensor_data(mat), mat->stride[0],
THCudaTensor_data(vec), vec->stride[0],
beta, THCudaTensor_data(self), self->stride[0]);
@@ -541,7 +535,6 @@
THCudaTensor_free(mat);
}
- THCublasCheck();
}
void THCudaTensor_addmm(THCudaTensor *self, float beta, float alpha, THCudaTensor *m1, THCudaTensor *m2)
@@ -619,21 +612,19 @@
}
/* do the operation */
- cublasSgemm(transpose_m1,
- transpose_m2,
- self_->size[0],
- self_->size[1],
- m1_->size[1],
- alpha,
- THCudaTensor_data(m1_),
- (transpose_m1 == 'n' ? m1_->stride[1] : m1_->stride[0]),
- THCudaTensor_data(m2_),
- (transpose_m2 == 'n' ? m2_->stride[1] : m2_->stride[0]),
- beta,
- THCudaTensor_data(self_),
- self_->stride[1]);
-
- THCublasCheck();
+ THCudaBlas_gemm(transpose_m1,
+ transpose_m2,
+ self_->size[0],
+ self_->size[1],
+ m1_->size[1],
+ alpha,
+ THCudaTensor_data(m1_),
+ (transpose_m1 == 'n' ? m1_->stride[1] : m1_->stride[0]),
+ THCudaTensor_data(m2_),
+ (transpose_m2 == 'n' ? m2_->stride[1] : m2_->stride[0]),
+ beta,
+ THCudaTensor_data(self_),
+ self_->stride[1]);
/* free intermediate variables */
if(m1_ != m1)
@@ -666,14 +657,14 @@
if(self->stride[0] == 1)
{
- cublasSger(vec1->size[0], vec2->size[0],
+ THCudaBlas_ger(vec1->size[0], vec2->size[0],
alpha, THCudaTensor_data(vec1), vec1->stride[0],
THCudaTensor_data(vec2), vec2->stride[0],
THCudaTensor_data(self), self->stride[1]);
}
else if(self->stride[1] == 1)
{
- cublasSger(vec2->size[0], vec1->size[0],
+ THCudaBlas_ger(vec2->size[0], vec1->size[0],
alpha, THCudaTensor_data(vec2), vec2->stride[0],
THCudaTensor_data(vec1), vec1->stride[0],
THCudaTensor_data(self), self->stride[0]);
@@ -682,7 +673,7 @@
{
THCudaTensor *cself = THCudaTensor_newClone(self);
- cublasSger(vec2->size[0], vec1->size[0],
+ THCudaBlas_ger(vec2->size[0], vec1->size[0],
alpha, THCudaTensor_data(vec2), vec2->stride[0],
THCudaTensor_data(vec1), vec1->stride[0],
THCudaTensor_data(cself), cself->stride[0]);
@@ -690,7 +681,6 @@
THCudaTensor_freeCopyTo(cself, self);
}
- THCublasCheck();
}
#define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(NAME, CFUNC) \