Remove TH/THC link for single matrix inverse (#20534)
Summary:
- Previously, single matrix inverse relied on the legacy `getri` implementations in TH (CPU) and THC (CUDA)
- This implementation has now been moved to ATen
Changelog:
- Move single matrix inverse implementation to ATen
- Remove unused code in TH and THC resulting from the change
- Minor modifications to the single matrix CPU function implementations in ATen to avoid redundancy: the `dim() == 2` special cases are removed, and the batched loops now handle a single matrix as a batch of one (see the sketch after this list)
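
User-facing behavior is unchanged: inverting a 2-D tensor now dispatches through ATen (LAPACK `getri` on CPU, MAGMA `getri` on CUDA) instead of the removed TH/THC path, while batched input keeps using the batched kernels. A minimal sanity check (a hypothetical snippet for illustration, not a test added by this PR):

    import torch

    # single matrix: goes through the new ATen single-inverse path
    a = torch.randn(4, 4)
    assert torch.allclose(a @ torch.inverse(a), torch.eye(4), atol=1e-5)

    # batched input: unchanged, still uses the batched kernel
    b = torch.randn(3, 4, 4)
    eye = torch.eye(4).expand(3, 4, 4)
    assert torch.allclose(b @ torch.inverse(b), eye, atol=1e-5)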
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20534
Differential Revision: D15393383
Pulled By: ezyang
fbshipit-source-id: 81972111cd9757d15f1d634f294c93fd0f35636c
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index f9a61fa..0fb683b 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -2364,22 +2364,6 @@
default: S
]]
[[
- name: _th_getri_single
- cname: getri
- types:
- - Float
- - Double
- backends:
- - CPU
- - CUDA
- variants: function
- return: argument 0
- arguments:
- - arg: THTensor* output
- output: True
- - THTensor* self
-]]
-[[
name: _th_potri
cname: potri
types:
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 2a7682c..c71ff61 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -138,28 +138,22 @@
#else
auto A_data = A.data<scalar_t>();
auto b_data = b.data<scalar_t>();
+ auto A_mat_stride = matrixStride(A);
+ auto b_mat_stride = matrixStride(b);
+ auto batch_size = batchCount(A);
auto n = A.size(-2);
auto nrhs = b.size(-1);
auto ipiv = at::empty({n}, b.options().dtype(kInt));
int info;
- if (b.dim() == 2) {
- lapackSolve<scalar_t>(n, nrhs, A_data, n, ipiv.data<int>(), b_data, n, &info);
- infos[0] = info;
- } else {
- auto A_mat_stride = matrixStride(A);
- auto b_mat_stride = matrixStride(b);
- auto batch_size = batchCount(A);
-
- for (int64_t i = 0; i < batch_size; i++) {
- scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
- scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
- lapackSolve<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(), b_working_ptr, n, &info);
- infos[i] = info;
- if (info != 0) {
- return;
- }
+ for (int64_t i = 0; i < batch_size; i++) {
+ scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
+ scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
+ lapackSolve<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(), b_working_ptr, n, &info);
+ infos[i] = info;
+ if (info != 0) {
+ return;
}
}
#endif
@@ -208,7 +202,6 @@
#else
auto self_data = self.data<scalar_t>();
auto self_matrix_stride = matrixStride(self);
-
auto batch_size = batchCount(self);
auto n = self.size(-2);
@@ -217,8 +210,8 @@
scalar_t wkopt;
Tensor work;
+ int info;
for (int64_t i = 0; i < batch_size; i++) {
- int info;
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
lapackLu<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
infos[i] = info;
@@ -249,7 +242,11 @@
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cpu", [&]{
apply_inverse<scalar_t>(self_working_copy, infos);
});
- batchCheckErrors(infos, "inverse_cpu");
+ if (self.dim() > 2) {
+ batchCheckErrors(infos, "inverse_cpu");
+ } else {
+ singleCheckErrors(infos[0], "inverse_cpu");
+ }
return self_working_copy;
}
@@ -257,9 +254,6 @@
if (self.size(-1) == 0) {
return at::empty_like(self);
}
- if (self.dim() == 2) {
- return at::legacy::th::_th_getri_single(self);
- }
squareCheckInputs(self);
return at::_inverse_helper(self);
}
@@ -283,25 +277,20 @@
auto A_data = A.data<scalar_t>();
auto b_data = b.data<scalar_t>();
+ auto A_mat_stride = matrixStride(A);
+ auto b_mat_stride = matrixStride(b);
+ auto batch_size = batchCount(A);
auto n = A.size(-2);
auto nrhs = b.size(-1);
int info;
- if (b.dim() == 2) {
- lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_data, n, b_data, n, &info);
- infos[0] = info;
- } else {
- auto A_mat_stride = matrixStride(A);
- auto b_mat_stride = matrixStride(b);
- auto batch_size = batchCount(A);
- for (int64_t i = 0; i < batch_size; i++) {
- scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
- scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
- lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
- infos[i] = info;
- if (info != 0) {
- return;
- }
+ for (int64_t i = 0; i < batch_size; i++) {
+ scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
+ scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
+ lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
+ infos[i] = info;
+ if (info != 0) {
+ return;
}
}
#endif
@@ -350,22 +339,17 @@
char uplo = upper ? 'U' : 'L';
auto self_data = self.data<scalar_t>();
+ auto self_matrix_stride = matrixStride(self);
+ auto batch_size = batchCount(self);
auto n = self.size(-2);
int info;
- if (self.dim() == 2) {
- lapackCholesky<scalar_t>(uplo, n, self_data, n, &info);
- infos[0] = info;
- } else {
- auto self_matrix_stride = matrixStride(self);
- auto batch_size = batchCount(self);
- for (int64_t i = 0; i < batch_size; i++) {
- scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
- lapackCholesky<scalar_t>(uplo, n, self_working_ptr, n, &info);
- infos[i] = info;
- if (info != 0) {
- return;
- }
+ for (int64_t i = 0; i < batch_size; i++) {
+ scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
+ lapackCholesky<scalar_t>(uplo, n, self_working_ptr, n, &info);
+ infos[i] = info;
+ if (info != 0) {
+ return;
}
}
#endif
@@ -417,21 +401,16 @@
auto self_data = self.data<scalar_t>();
auto pivots_data = pivots.data<int>();
auto infos_data = infos.data<int>();
-
+ auto self_matrix_stride = matrixStride(self);
+ auto pivots_matrix_stride = pivots.size(-1);
+ auto batch_size = batchCount(self);
auto n = self.size(-1);
- if (self.dim() == 2) {
- lapackLu<scalar_t>(n, n, self_data, n, pivots_data, infos_data);
- } else {
- auto self_matrix_stride = matrixStride(self);
- auto batch_size = batchCount(self);
- auto pivots_matrix_stride = pivots.size(-1);
- for (int64_t i = 0; i < batch_size; i++) {
- scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
- int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
- int* infos_working_ptr = &infos_data[i];
- lapackLu<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
- }
+ for (int64_t i = 0; i < batch_size; i++) {
+ scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
+ int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
+ int* infos_working_ptr = &infos_data[i];
+ lapackLu<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
}
#endif
}
@@ -458,10 +437,10 @@
});
}
if (check_errors) {
- if (self.dim() == 2) {
- singleCheckErrors(infos_tensor.item<int64_t>(), "lu");
- } else {
+ if (self.dim() > 2) {
batchCheckErrors(infos_tensor, "lu");
+ } else {
+ singleCheckErrors(infos_tensor.item<int64_t>(), "lu");
}
}
return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
@@ -621,21 +600,17 @@
auto A_data = A.data<scalar_t>();
auto b_data = b.data<scalar_t>();
+ auto A_mat_stride = matrixStride(A);
+ auto b_mat_stride = matrixStride(b);
+ auto batch_size = batchCount(A);
auto n = A.size(-2);
auto nrhs = b.size(-1);
int info;
- if (b.dim() == 2) {
- lapackTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n, &info);
- } else {
- auto A_mat_stride = matrixStride(A);
- auto b_mat_stride = matrixStride(b);
- auto batch_size = batchCount(A);
- for (int64_t i = 0; i < batch_size; i++) {
- scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
- scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
- lapackTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
- }
+ for (int64_t i = 0; i < batch_size; i++) {
+ scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
+ scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
+ lapackTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
}
#endif
}
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index 3984d69..0411850 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -64,6 +64,18 @@
}
template<class scalar_t>
+inline magma_int_t magmaGetriOptimalBlocksize(magma_int_t n) {
+ AT_ERROR("getri only takes float or double Tensors");
+}
+
+template<class scalar_t>
+void magmaGetri(
+ magma_int_t n, scalar_t* dA, magma_int_t ldda, magma_int_t* ipiv, scalar_t* dwork,
+ magma_int_t lwork, magma_int_t* info) {
+ AT_ERROR("getri only takes float or double Tensors");
+}
+
+template<class scalar_t>
void magmaGetriBatched(
magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
magma_int_t** ipiv_array, scalar_t** dinvA_array, magma_int_t lddia,
@@ -203,6 +215,30 @@
}
template<>
+inline magma_int_t magmaGetriOptimalBlocksize<double>(magma_int_t n) {
+ return magma_get_dgetri_nb(n);
+}
+
+template<>
+inline magma_int_t magmaGetriOptimalBlocksize<float>(magma_int_t n) {
+ return magma_get_sgetri_nb(n);
+}
+
+template<>
+void magmaGetri<double>(
+ magma_int_t n, double* dA, magma_int_t ldda, magma_int_t* ipiv, double* dwork,
+ magma_int_t lwork, magma_int_t* info) {
+ magma_dgetri_gpu(n, dA, ldda, ipiv, dwork, lwork, info);
+}
+
+template<>
+void magmaGetri<float>(
+ magma_int_t n, float* dA, magma_int_t ldda, magma_int_t* ipiv, float* dwork,
+ magma_int_t lwork, magma_int_t* info) {
+ magma_sgetri_gpu(n, dA, ldda, ipiv, dwork, lwork, info);
+}
+
+template<>
void magmaGetriBatched<double>(
magma_int_t n, double** dA_array, magma_int_t ldda,
magma_int_t** ipiv_array, double** dinvA_array, magma_int_t lddia,
@@ -382,7 +418,7 @@
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename scalar_t>
-static void apply_inverse(Tensor& self, Tensor& self_inv, std::vector<int64_t>& infos) {
+static void apply_batched_inverse(Tensor& self, Tensor& self_inv, std::vector<int64_t>& infos) {
#ifndef USE_MAGMA
AT_ERROR("inverse: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
@@ -429,17 +465,47 @@
#endif
}
-// Because this is out-of-place inverse, the predefined macros will
-// not work
+template <typename scalar_t>
+static void apply_single_inverse(Tensor& self, int64_t& info) {
+#ifndef USE_MAGMA
+AT_ERROR("inverse: MAGMA library not found in "
+ "compilation. Please rebuild with MAGMA.");
+#else
+ auto self_data = self.data<scalar_t>();
+ magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)");
+ magma_int_t lwork = n * magmaGetriOptimalBlocksize<scalar_t>(n);
+ magma_int_t info_tmp = 0;
+
+ Tensor ipiv = at::empty({n}, at::kInt);
+ Tensor dwork = at::empty({lwork}, self.options());
+ magmaLu<scalar_t>(n, n, self_data, n, ipiv.data<magma_int_t>(), &info_tmp);
+ if (info_tmp != 0) {
+ info = info_tmp;
+ return;
+ }
+ magmaGetri<scalar_t>(
+ n, self_data, n, ipiv.data<magma_int_t>(), dwork.data<scalar_t>(), lwork, &info_tmp);
+ info = info_tmp;
+#endif
+}
+
Tensor _inverse_helper_cuda(const Tensor& self) {
- std::vector<int64_t> infos(batchCount(self), 0);
- auto self_working_copy = cloneBatchedColumnMajor(self);
auto self_inv_working_copy = cloneBatchedColumnMajor(self);
- AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{
- apply_inverse<scalar_t>(
- self_working_copy, self_inv_working_copy, infos);
- });
- batchCheckErrors(infos, "inverse_cuda");
+ if (self.dim() > 2) {
+ std::vector<int64_t> infos(batchCount(self), 0);
+ auto self_working_copy = cloneBatchedColumnMajor(self);
+ AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{
+ apply_batched_inverse<scalar_t>(
+ self_working_copy, self_inv_working_copy, infos);
+ });
+ batchCheckErrors(infos, "inverse_cuda");
+ } else {
+ int64_t info = 0;
+ AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{
+ apply_single_inverse<scalar_t>(self_inv_working_copy, info);
+ });
+ singleCheckErrors(info, "inverse_cuda");
+ }
return self_inv_working_copy;
}
diff --git a/aten/src/TH/generic/THLapack.cpp b/aten/src/TH/generic/THLapack.cpp
index 1c81ed2..23a2b3f 100644
--- a/aten/src/TH/generic/THLapack.cpp
+++ b/aten/src/TH/generic/THLapack.cpp
@@ -11,12 +11,8 @@
TH_EXTERNC void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
TH_EXTERNC void dgesdd_(char *jobz, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *iwork, int *info);
TH_EXTERNC void sgesdd_(char *jobz, int *m, int *n, float *a, int *lda, float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *iwork, int *info);
-TH_EXTERNC void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
-TH_EXTERNC void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
TH_EXTERNC void dgetrs_(char *trans, int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
TH_EXTERNC void sgetrs_(char *trans, int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
-TH_EXTERNC void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
-TH_EXTERNC void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);
TH_EXTERNC void dpotri_(char *uplo, int *n, double *a, int *lda, int *info);
TH_EXTERNC void spotri_(char *uplo, int *n, float *a, int *lda, int *info);
TH_EXTERNC void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
@@ -89,20 +85,6 @@
#endif
}
-/* LU decomposition */
-void THLapack_(getrf)(int m, int n, scalar_t *a, int lda, int *ipiv, int *info)
-{
-#ifdef USE_LAPACK
-#if defined(TH_REAL_IS_DOUBLE)
- dgetrf_(&m, &n, a, &lda, ipiv, info);
-#else
- sgetrf_(&m, &n, a, &lda, ipiv, info);
-#endif
-#else
- THError("getrf : Lapack library not found in compile time\n");
-#endif
-}
-
void THLapack_(getrs)(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info)
{
#ifdef USE_LAPACK
@@ -116,20 +98,6 @@
#endif
}
-/* Matrix Inverse */
-void THLapack_(getri)(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwork, int* info)
-{
-#ifdef USE_LAPACK
-#if defined(TH_REAL_IS_DOUBLE)
- dgetri_(&n, a, &lda, ipiv, work, &lwork, info);
-#else
- sgetri_(&n, a, &lda, ipiv, work, &lwork, info);
-#endif
-#else
- THError("getri : Lapack library not found in compile time\n");
-#endif
-}
-
/* Cholesky factorization based Matrix Inverse */
void THLapack_(potri)(char uplo, int n, scalar_t *a, int lda, int *info)
{
diff --git a/aten/src/TH/generic/THLapack.h b/aten/src/TH/generic/THLapack.h
index 0557834..20d469d 100644
--- a/aten/src/TH/generic/THLapack.h
+++ b/aten/src/TH/generic/THLapack.h
@@ -10,11 +10,7 @@
TH_API void THLapack_(geev)(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info);
/* svd */
TH_API void THLapack_(gesdd)(char jobz, int m, int n, scalar_t *a, int lda, scalar_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, int *iwork, int *info);
-/* LU decomposition */
-TH_API void THLapack_(getrf)(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
TH_API void THLapack_(getrs)(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
-/* Matrix Inverse */
-TH_API void THLapack_(getri)(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwork, int* info);
/* Positive Definite matrices */
/* Matrix inverse based on Cholesky factorization */
diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp
index 5705f18..f14ed72 100644
--- a/aten/src/TH/generic/THTensorLapack.cpp
+++ b/aten/src/TH/generic/THTensorLapack.cpp
@@ -440,50 +440,6 @@
}
}
-void THTensor_(getri)(THTensor *ra_, THTensor *a)
-{
- if (a == NULL) a = ra_;
- THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional");
- THArgCheck(a->size(0) == a->size(1), 1, "A should be square");
-
- int m, n, lda, info, lwork;
- scalar_t wkopt;
- THIntTensor *ipiv;
- THTensor *work;
- THTensor *ra__ = NULL;
-
- ra__ = THTensor_(cloneColumnMajor)(ra_, a);
-
- m = ra__->size(0);
- n = ra__->size(1);
- lda = m;
- ipiv = THIntTensor_newWithSize1d((int64_t)m);
-
- /* Run LU */
- THLapack_(getrf)(n, n, ra__->data<scalar_t>(), lda, THIntTensor_data(ipiv), &info);
- THLapackCheckWithCleanup("Lapack Error %s : U(%d,%d) is 0, U is singular",
- THCleanup(
- c10::raw::intrusive_ptr::decref(ra__);
- THIntTensor_free(ipiv);),
- "getrf", info, info);
-
- /* Run inverse */
- THLapack_(getri)(n, ra__->data<scalar_t>(), lda, THIntTensor_data(ipiv), &wkopt, -1, &info);
- lwork = (int)wkopt;
- work = THTensor_(newWithSize1d)(lwork);
- THLapack_(getri)(n, ra__->data<scalar_t>(), lda, THIntTensor_data(ipiv), work->data<scalar_t>(), lwork, &info);
- THLapackCheckWithCleanup("Lapack Error %s : U(%d,%d) is 0, U is singular",
- THCleanup(
- c10::raw::intrusive_ptr::decref(ra__);
- c10::raw::intrusive_ptr::decref(work);
- THIntTensor_free(ipiv);),
- "getri", info, info);
-
- THTensor_(freeCopyTo)(ra__, ra_);
- c10::raw::intrusive_ptr::decref(work);
- THIntTensor_free(ipiv);
-}
-
void THTensor_(clearUpLoTriangle)(THTensor *a, const char *uplo)
{
THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional");
diff --git a/aten/src/TH/generic/THTensorLapack.h b/aten/src/TH/generic/THTensorLapack.h
index 4c693a8..5c512ab 100644
--- a/aten/src/TH/generic/THTensorLapack.h
+++ b/aten/src/TH/generic/THTensorLapack.h
@@ -8,7 +8,6 @@
TH_API void THTensor_(gesdd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *a, const char *some, const char* compute_uv);
TH_API void THTensor_(gesdd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a,
const char *some, const char* compute_uv);
-TH_API void THTensor_(getri)(THTensor *ra_, THTensor *a);
TH_API void THTensor_(potri)(THTensor *ra_, THTensor *a, const char *uplo);
TH_API void THTensor_(qr)(THTensor *rq_, THTensor *rr_, THTensor *a);
TH_API void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a);
diff --git a/aten/src/THC/THCBlas.cu b/aten/src/THC/THCBlas.cu
index 3904873..4927bf2 100644
--- a/aten/src/THC/THCBlas.cu
+++ b/aten/src/THC/THCBlas.cu
@@ -508,37 +508,6 @@
}
#endif
-/* Inverse */
-void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) {
-#ifndef __HIP_PLATFORM_HCC__
- if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )
- {
- THError("Cublas_Sgetrf only supports n, lda, batchSize"
- "with the bound [val] <= %d", INT_MAX);
- }
- cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
- cublasSetStream(handle, THCState_getCurrentStream(state));
- THCublasCheck(cublasSgetrfBatched(handle, n, a, lda, pivot, info, batchSize));
-#else
- THError("THCudaBlas_Sgetrf not supported in ROCM.");
-#endif
-}
-
-void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize) {
-#ifndef __HIP_PLATFORM_HCC__
- if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )
- {
- THError("Cublas_Dgetrf only supports n, lda, batchSize"
- "with the bound [val] <= %d", INT_MAX);
- }
- cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
- cublasSetStream(handle, THCState_getCurrentStream(state));
- THCublasCheck(cublasDgetrfBatched(handle, n, a, lda, pivot, info, batchSize));
-#else
- THError("THCudaBlas_Dgetrf not supported in ROCM.");
-#endif
-}
-
void THCudaBlas_Sgetrs(THCState *state, char transa, int n, int nrhs, const float **a, int lda, int *pivot, float **b, int ldb, int *info, int batchSize)
{
#ifndef __HIP_PLATFORM_HCC__
@@ -579,33 +548,3 @@
THError("THCudaBlas_Dgetrs not supported in ROCM.");
#endif
}
-
-void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize) {
-#ifndef __HIP_PLATFORM_HCC__
- if( (n >= INT_MAX) || (lda >= INT_MAX)|| (ldc >= INT_MAX) || (batchSize >= INT_MAX) )
- {
- THError("Cublas_Sgetri only supports n, lda, ldc, batchSize"
- "with the bound [val] <= %d", INT_MAX);
- }
- cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
- cublasSetStream(handle, THCState_getCurrentStream(state));
- THCublasCheck(cublasSgetriBatched(handle, n, a, lda, pivot, c, ldc, info, batchSize));
-#else
- THError("THCudaBlas_Sgetri not supported in ROCM.");
-#endif
-}
-
-void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *pivot, double **c, int ldc, int *info, int batchSize) {
-#ifndef __HIP_PLATFORM_HCC__
- if( (n >= INT_MAX) || (lda >= INT_MAX)|| (ldc >= INT_MAX) || (batchSize >= INT_MAX) )
- {
- THError("Cublas_Dgetri only supports n, lda, ldc, batchSize"
- "with the bound [val] <= %d", INT_MAX);
- }
- cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
- cublasSetStream(handle, THCState_getCurrentStream(state));
- THCublasCheck(cublasDgetriBatched(handle, n, a, lda, pivot, c, ldc, info, batchSize));
-#else
- THError("THCudaBlas_Dgetri not supported in ROCM.");
-#endif
-}
diff --git a/aten/src/THC/THCBlas.h b/aten/src/THC/THCBlas.h
index 0306468..56e0113 100644
--- a/aten/src/THC/THCBlas.h
+++ b/aten/src/THC/THCBlas.h
@@ -42,14 +42,7 @@
THHalf beta, THHalf *c, int64_t ldc, int64_t strideC, int64_t batchCount);
#endif
-/* Inverse */
-THC_API void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize);
-THC_API void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize);
-
THC_API void THCudaBlas_Sgetrs(THCState *state, char transa, int n, int nrhs, const float **a, int lda, int *pivot, float **b, int ldb, int *info, int batchSize);
THC_API void THCudaBlas_Dgetrs(THCState *state, char transa, int n, int nrhs, const double **a, int lda, int *pivot, double **b, int ldb, int *info, int batchSize);
-THC_API void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize);
-THC_API void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *pivot, double **c, int ldc, int *info, int batchSize);
-
#endif
diff --git a/aten/src/THC/generic/THCTensorMathMagma.cu b/aten/src/THC/generic/THCTensorMathMagma.cu
index a379626..a83d6d0 100644
--- a/aten/src/THC/generic/THCTensorMathMagma.cu
+++ b/aten/src/THC/generic/THCTensorMathMagma.cu
@@ -334,112 +334,6 @@
#endif
}
-void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a)
-{
- THArgCheck(!a->is_empty() && a->dim() == 2, 2, "A should be non-empty 2 dimensional");
- THArgCheck(a->size(0) == a->size(1), 2, "A should be square");
-
-#ifdef USE_MAGMA
- int info;
- int64_t n = a->size(0);
- int lwork = n * magma_get_sgetri_nb(n);
-
- THCTensor *input = THCTensor_(newColumnMajor)(state, ra_, a);
- scalar_t *input_data = THCTensor_(data)(state, input);
-
- int *ipiv = th_magma_malloc_pinned<int>(n);
-
- THCTensor *work = THCTensor_(newWithSize1d)(state, lwork);
- scalar_t *work_data = THCTensor_(data)(state, work);
-
- // Run LU
-#if defined(THC_REAL_IS_FLOAT)
- magma_sgetrf_gpu(n, n, input_data, n, ipiv, &info);
-#else
- magma_dgetrf_gpu(n, n, input_data, n, ipiv, &info);
-#endif
-
- if (info > 0)
- THError("MAGMA getrf : U(%d,%d) is 0, U is singular", info, info);
- else if (info < 0)
- THError("MAGMA getrf : Argument %d : illegal value", -info);
-
- // Inverse
-#if defined(THC_REAL_IS_FLOAT)
- magma_sgetri_gpu(n, input_data, n, ipiv, work_data, lwork, &info);
-#else
- magma_dgetri_gpu(n, input_data, n, ipiv, work_data, lwork, &info);
-#endif
-
- if (info > 0)
- THError("MAGMA getri : U(%d,%d) is 0, U is singular", info, info);
- else if (info < 0)
- THError("MAGMA getri : Argument %d : illegal value", -info);
-
- THCTensor_(free)(state, work);
- magma_free_pinned(ipiv);
- THCTensor_(freeCopyTo)(state, input, ra_);
-#else
- int64_t n = a->size(0);
-
- // input
- THCTensor *input = THCTensor_(newColumnMajor)(state, a, a);
- THCTensor_(resizeNd)(state, ra_, 2, THTensor_getSizePtr(input), THTensor_getStridePtr(input));
-
- scalar_t *matrices1[1] = { THCTensor_(data)(state, input) };
- scalar_t *matrices2[1] = { THCTensor_(data)(state, ra_) };
-
- // Copy pointers to device.
- auto d_matrices1 = static_cast<scalar_t**>(THCudaMalloc(state, sizeof(scalar_t*)));
- auto d_matrices2 = static_cast<scalar_t**>(THCudaMalloc(state, sizeof(scalar_t*)));
-
- THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, sizeof(scalar_t*),
- cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
- THCudaCheck(cudaMemcpyAsync(d_matrices2, matrices2, sizeof(scalar_t*),
- cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
- int info;
- auto info_gpu = static_cast<int*>(THCudaMalloc(state, sizeof(int)));
-
- auto ipiv_gpu = static_cast<int*>(THCudaMalloc(state, n * sizeof(int)));
-
- // Run LU
-#if defined(THC_REAL_IS_FLOAT)
- THCudaBlas_Sgetrf(state, n, d_matrices1, n, ipiv_gpu, info_gpu, 1);
-#else
- THCudaBlas_Dgetrf(state, n, d_matrices1, n, ipiv_gpu, info_gpu, 1);
-#endif
-
- THCudaCheck(cudaMemcpy(&info, info_gpu, sizeof(int), cudaMemcpyDeviceToHost));
-
- if (info > 0)
- THError("CUBLAS getrf : U(%d,%d) is 0, U is singular", info, info);
- else if (info < 0)
- THError("CUBLAS getrf : Argument %d : illegal value", -info);
-
- // Inverse
-#if defined(THC_REAL_IS_FLOAT)
- THCudaBlas_Sgetri(state, n, (const scalar_t**)d_matrices1, n, ipiv_gpu, d_matrices2, n, info_gpu, 1);
-#else
- THCudaBlas_Dgetri(state, n, (const scalar_t**)d_matrices1, n, ipiv_gpu, d_matrices2, n, info_gpu, 1);
-#endif
-
- THCudaCheck(cudaMemcpy(&info, info_gpu, sizeof(int), cudaMemcpyDeviceToHost));
-
- if (info > 0)
- THError("CUBLAS getri : U(%d,%d) is 0, U is singular", info, info);
- else if (info < 0)
- THError("CUBLAS getri : Argument %d : illegal value", -info);
-
- THCudaFree(state, ipiv_gpu);
- THCudaFree(state, info_gpu);
-
- THCudaFree(state, d_matrices1);
- THCudaFree(state, d_matrices2);
-
- THCTensor_(free)(state, input);
-#endif
-}
-
__global__ void THCTensor_(copyUpperSymmetric)(scalar_t *input, int n, int len)
{
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) {
diff --git a/aten/src/THC/generic/THCTensorMathMagma.h b/aten/src/THC/generic/THCTensorMathMagma.h
index f388f68..0ae49cd 100644
--- a/aten/src/THC/generic/THCTensorMathMagma.h
+++ b/aten/src/THC/generic/THCTensorMathMagma.h
@@ -12,7 +12,6 @@
const char *some, const char* compute_uv);
THC_API void THCTensor_(gesdd2)(THCState *state, THCTensor *ru_, THCTensor *rs_, THCTensor *rv_, THCTensor *ra_, THCTensor *a,
const char *some, const char* compute_uv);
-THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a);
THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_);
THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a);