[cutorch mag2gen] move eig to generic
diff --git a/THCTensorMath.h b/THCTensorMath.h
index 32e18cf..7ce7504 100644
--- a/THCTensorMath.h
+++ b/THCTensorMath.h
@@ -45,7 +45,6 @@
#include "THCGenerateAllTypes.h"
// MAGMA (i.e. CUDA implementation of LAPACK functions)
-THC_API void THCudaTensor_geev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobvr);
THC_API void THCudaTensor_gesvd(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_, THCudaTensor *rv_, THCudaTensor *a, const char *jobu);
THC_API void THCudaTensor_gesvd2(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_, THCudaTensor *rv_, THCudaTensor *ra_, THCudaTensor *a, const char *jobu);
THC_API void THCudaTensor_getri(THCState *state, THCudaTensor *ra_, THCudaTensor *a);
diff --git a/THCTensorMathMagma.cu b/THCTensorMathMagma.cu
index 029811e..47bc484 100644
--- a/THCTensorMathMagma.cu
+++ b/THCTensorMathMagma.cu
@@ -23,67 +23,6 @@
#endif
}
-void THCudaTensor_geev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobvrs)
-{
-#ifdef USE_MAGMA
- THArgCheck(a_->nDimension == 2, 3, "A should be 2 dimensional");
- THArgCheck(a_->size[0] == a_->size[1], 3, "A should be square");
-
- magma_vec_t jobvr = jobvrs[0] == 'N' ? MagmaNoVec : MagmaVec;
- int n = a_->size[0];
-
- float *a_data = th_magma_malloc_pinned<float>(n * n);
- THCudaTensor_copyTensor2d(state, a_data, a_);
-
- float *wr = th_magma_malloc_pinned<float>(n);
- float *wi = th_magma_malloc_pinned<float>(n);
-
- float *vr_data = NULL;
- int ldvr = 1;
- if (jobvr == MagmaVec)
- {
- vr_data = th_magma_malloc_pinned<float>(n * n);
- ldvr = n;
- }
-
- float wkopt;
- int info;
-
- magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info);
-
- int lwork = (int) wkopt;
- float *work_data = th_magma_malloc_pinned<float>(lwork);
-
- magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info);
-
- if (info > 0)
- THError("MAGMA geev : Failed to converge. %d off-diagonal elements of an didn't converge to zero", info);
- else if (info < 0)
- THError("MAGMA geev : Argument %d : illegal value", -info);
-
- {
- THCudaTensor_resize2d(state, re_, 2, n);
- THCudaTensor *re = THCudaTensor_newContiguous(state, re_);
- THCudaCheck(cudaMemcpy(re->storage->data + re->storageOffset, wr, n*sizeof(float), cudaMemcpyHostToDevice));
- THCudaCheck(cudaMemcpy(re->storage->data + re->storageOffset + n, wi, n*sizeof(float), cudaMemcpyHostToDevice));
- THCudaTensor_freeCopyTo(state, re, re_);
- THCudaTensor_transpose(state, re_, NULL, 0, 1);
- }
-
- if (jobvr == MagmaVec)
- THCudaTensor_copyArray2d(state, rv_, vr_data, n, n);
-
- magma_free_pinned(work_data);
- magma_free_pinned(vr_data);
- magma_free_pinned(wi);
- magma_free_pinned(wr);
- magma_free_pinned(a_data);
-
-#else
- THError(NoMagma(geev));
-#endif
-}
-
void THCudaTensor_gesvd(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_, THCudaTensor *rv_, THCudaTensor *a, const char *jobu)
{
#ifdef USE_MAGMA
diff --git a/generic/THCTensorMathMagma.cu b/generic/THCTensorMathMagma.cu
index feab665..e0d505c 100644
--- a/generic/THCTensorMathMagma.cu
+++ b/generic/THCTensorMathMagma.cu
@@ -145,6 +145,74 @@
#endif
}
+/* Computes the eigenvalues and, optionally, the right eigenvectors of the
+ * square matrix a_ with MAGMA's sgeev (float) / dgeev (double).
+ *
+ *  re_    : resized to 2 x n, filled with wr in the first row and wi in the
+ *           second, then transposed so callers see n rows of (wr, wi).
+ *  rv_    : receives the n x n vr array when eigenvectors are requested;
+ *           not written otherwise.
+ *  a_     : 2-dimensional square input; its contents are copied into a
+ *           pinned host buffer before MAGMA is called.
+ *  jobvrs : if the first character is 'N' no eigenvectors are computed,
+ *           any other value requests the right eigenvectors.
+ *
+ * Raises THError when a_ is not a square 2-d tensor, when MAGMA reports a
+ * non-zero info, or when the library was built without MAGMA.
+ *
+ * All pinned host buffers are released *before* THError is raised for a
+ * MAGMA failure, so a failing decomposition does not leak pinned memory
+ * when the error handler does not return.
+ */
+THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvrs)
+{
+#ifdef USE_MAGMA
+  THArgCheck(a_->nDimension == 2, 3, "A should be 2 dimensional");
+  THArgCheck(a_->size[0] == a_->size[1], 3, "A should be square");
+
+  magma_vec_t jobvr = jobvrs[0] == 'N' ? MagmaNoVec : MagmaVec;
+  int n = a_->size[0];
+
+  real *a_data = th_magma_malloc_pinned<real>(n * n);
+  THCTensor_(copyTensor2d)(state, a_data, a_);
+
+  real *wr = th_magma_malloc_pinned<real>(n);
+  real *wi = th_magma_malloc_pinned<real>(n);
+
+  /* vr is only allocated when eigenvectors were requested; ldvr stays 1
+   * otherwise, matching the dummy leading dimension passed for vl. */
+  real *vr_data = NULL;
+  int ldvr = 1;
+  if (jobvr == MagmaVec)
+  {
+    vr_data = th_magma_malloc_pinned<real>(n * n);
+    ldvr = n;
+  }
+
+  real *work_data = NULL;
+  real wkopt = 0;
+  int info = 0;
+
+  /* Workspace query: lwork == -1 asks MAGMA to report the optimal size in wkopt. */
+#if defined(THC_REAL_IS_FLOAT)
+  magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info);
+#else
+  magma_dgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info);
+#endif
+
+  /* Only trust wkopt (and run the real decomposition) if the query succeeded;
+   * otherwise wkopt is meaningless and must not be used as an allocation size. */
+  if (info == 0)
+  {
+    int lwork = (int) wkopt;
+    work_data = th_magma_malloc_pinned<real>(lwork);
+
+#if defined(THC_REAL_IS_FLOAT)
+    magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info);
+#else
+    magma_dgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info);
+#endif
+  }
+
+  /* Publish results only on success, so the outputs are left untouched on failure. */
+  if (info == 0)
+  {
+    THCTensor_(resize2d)(state, re_, 2, n);
+    THCTensor *re = THCTensor_(newContiguous)(state, re_);
+    THCudaCheck(cudaMemcpy(re->storage->data + re->storageOffset, wr, n*sizeof(real), cudaMemcpyHostToDevice));
+    THCudaCheck(cudaMemcpy(re->storage->data + re->storageOffset + n, wi, n*sizeof(real), cudaMemcpyHostToDevice));
+    THCTensor_(freeCopyTo)(state, re, re_);
+    THCTensor_(transpose)(state, re_, NULL, 0, 1);
+
+    if (jobvr == MagmaVec)
+      THCTensor_(copyArray2d)(state, rv_, vr_data, n, n);
+  }
+
+  /* Release every pinned buffer before reporting an error. */
+  if (work_data != NULL)
+    magma_free_pinned(work_data);
+  if (vr_data != NULL)
+    magma_free_pinned(vr_data);
+  magma_free_pinned(wi);
+  magma_free_pinned(wr);
+  magma_free_pinned(a_data);
+
+  /* NOTE(review): the info > 0 wording appears to be borrowed from syev; for
+   * geev LAPACK documents info > 0 as the QR algorithm failing to compute all
+   * eigenvalues. Text kept unchanged in case callers match on it. */
+  if (info > 0)
+    THError("MAGMA geev : Failed to converge. %d off-diagonal elements of an didn't converge to zero", info);
+  else if (info < 0)
+    THError("MAGMA geev : Argument %d : illegal value", -info);
+
+#else
+  THError(NoMagma(geev));
+#endif
+}
#endif
diff --git a/generic/THCTensorMathMagma.h b/generic/THCTensorMathMagma.h
index c09a7bb..1e72ef6 100644
--- a/generic/THCTensorMathMagma.h
+++ b/generic/THCTensorMathMagma.h
@@ -62,6 +62,7 @@
THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobz, const char *uplo);
+THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvr); // eigenvalues of square a_ into re_ as (real, imag) pairs; right eigenvectors into rv_ unless jobvr starts with 'N'
#endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)