| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ============================================================================== |
| */ |
| #ifdef GOOGLE_CUDA |
| #include "tensorflow/core/kernels/cuda_solvers.h" |
| |
| #include <chrono> |
| #include <complex> |
| #include <unordered_map> |
| #include <vector> |
| |
| #include "third_party/gpus/cuda/include/cublas_v2.h" |
| #include "third_party/gpus/cuda/include/cusolverDn.h" |
| #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/lib/core/blocking_counter.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/core/stringpiece.h" |
| #include "tensorflow/core/lib/gtl/inlined_vector.h" |
| #include "tensorflow/core/platform/cuda.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/stream_executor.h" |
| #include "tensorflow/core/platform/types.h" |
| |
| // The CUDA cublas_api.h API contains const-correctness errors: input-only |
| // pointer arrays are bound to parameter types "const T**" instead of the |
| // correct "const T* const*". Rather than cast away constness on our data, |
| // we reinterpret the cuBLAS functions as what they were clearly meant to be, |
| // and thus we can call the functions naturally. |
| extern "C" { |
| using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int, |
| const float* const*, int, const int*, float**, |
| int, int*, int); |
| using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int, |
| const double* const*, int, const int*, double**, |
| int, int*, int); |
| using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int, |
| const float2* const*, int, const int*, float2**, |
| int, int*, int); |
| using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int, |
| const double2* const*, int, const int*, |
| double2**, int, int*, int); |
| |
| using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int, |
| const int*, float**, int, int*, int); |
| using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int, |
| const int*, double**, int, int*, int); |
| using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int, |
| const int*, float2**, int, int*, int); |
| using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int, |
| const int*, double2**, int, int*, int); |
| |
| using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int, |
| float**, int, int*, int); |
| using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int, |
| double**, int, int*, int); |
| using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int, |
| float2**, int, int*, int); |
| using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int, |
| double2**, int, int*, int); |
| } |
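| |
| // For example, getrs_S matches the signature that cublasSgetrsBatched was |
| // evidently meant to have. The instantiation macros near the bottom of this |
| // file reinterpret the cuBLAS entry points to these types before calling |
| // them; a minimal sketch of the pattern: |
| // |
| //   auto* getrs = reinterpret_cast<getrs_S*>(cublasSgetrsBatched); |
| //   // 'getrs' now accepts a properly const-qualified "const float* const*" |
| //   // array of device pointers for the input-only LU factors. |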
| |
| namespace tensorflow { |
| namespace { |
| |
| using se::cuda::ScopedActivateExecutorContext; |
| |
| inline bool CopyHostToDevice(OpKernelContext* context, void* dst, |
| const void* src, uint64 bytes) { |
| auto stream = context->op_device_context()->stream(); |
| se::DeviceMemoryBase wrapped_dst(dst); |
| return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok(); |
| } |
| |
| // A set of initialized handles to the underlying CUDA libraries used by |
| // CudaSolver. We maintain one such set of handles per unique stream. |
| struct CudaSolverHandles { |
| explicit CudaSolverHandles(cudaStream_t stream) { |
| CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS) |
| << "Failed to create cuSolverDN instance."; |
| CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) == |
| CUSOLVER_STATUS_SUCCESS) |
| << "Failed to set cuSolverDN stream."; |
| CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS) |
| << "Failed to create cuBLAS instance."; |
| CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS) |
| << "Failed to set cuBLAS stream."; |
| } |
| |
| ~CudaSolverHandles() { |
| CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS) |
| << "Failed to destroy cuBLAS instance."; |
| CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS) |
| << "Failed to destroy cuSolverDN instance."; |
| } |
| cublasHandle_t cublas_handle; |
| cusolverDnHandle_t cusolver_dn_handle; |
| }; |
| |
| static mutex handle_map_mutex(LINKER_INITIALIZED); |
| |
| using HandleMap = |
| std::unordered_map<cudaStream_t, std::unique_ptr<CudaSolverHandles>>; |
| |
| // Returns a singleton map used for storing initialized handles for each |
| // unique CUDA stream. |
| HandleMap* GetHandleMapSingleton() { |
| static HandleMap* cm = new HandleMap; |
| return cm; |
| } |
| |
| } // namespace |
| |
| #define TF_RETURN_IF_CUSOLVER_ERROR(expr) \ |
| do { \ |
| auto status = (expr); \ |
| if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \ |
| return errors::Internal( \ |
| __FILE__, ":", __LINE__, \ |
| ": cuSolverDN call failed with status =", status); \ |
| } \ |
| } while (0) |
| |
| #define TF_RETURN_IF_CUBLAS_ERROR(expr) \ |
| do { \ |
| auto status = (expr); \ |
| if (TF_PREDICT_FALSE(status != CUBLAS_STATUS_SUCCESS)) { \ |
| return errors::Internal(__FILE__, ":", __LINE__, \ |
| ": cuBlas call failed status = ", status); \ |
| } \ |
| } while (0) |
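| |
| // Usage sketch: wrap each raw library call so that a failure becomes an |
| // early return carrying file/line context, as the wrappers below do, e.g. |
| // |
| //   int lwork; |
| //   TF_RETURN_IF_CUSOLVER_ERROR( |
| //       cusolverDnSpotrf_bufferSize(handle, uplo, n, A, lda, &lwork)); |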
| |
| CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) { |
| mutex_lock lock(handle_map_mutex); |
| const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL( |
| reinterpret_cast<const cudaStream_t*>(context->op_device_context() |
| ->stream() |
| ->implementation() |
| ->GpuStreamMemberHack())); |
| cuda_stream_ = *cu_stream_ptr; |
| HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton()); |
| auto it = handle_map->find(cuda_stream_); |
| if (it == handle_map->end()) { |
| LOG(INFO) << "Creating CudaSolver handles for stream " << cuda_stream_; |
| // Previously unseen CUDA stream. Initialize a set of CUDA solver library |
| // handles for it. |
| std::unique_ptr<CudaSolverHandles> new_handles( |
| new CudaSolverHandles(cuda_stream_)); |
| it = |
| handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles))) |
| .first; |
| } |
| cusolver_dn_handle_ = it->second->cusolver_dn_handle; |
| cublas_handle_ = it->second->cublas_handle; |
| } |
| |
| CudaSolver::~CudaSolver() { |
| for (auto tensor_ref : scratch_tensor_refs_) { |
| tensor_ref.Unref(); |
| } |
| } |
| |
| // static |
| void CudaSolver::CheckLapackInfoAndDeleteSolverAsync( |
| std::unique_ptr<CudaSolver> solver, |
| const std::vector<DeviceLapackInfo>& dev_lapack_infos, |
| std::function<void(const Status&, const std::vector<HostLapackInfo>&)> |
| info_checker_callback) { |
| CHECK(info_checker_callback != nullptr); |
| std::vector<HostLapackInfo> host_lapack_infos; |
| if (dev_lapack_infos.empty()) { |
| info_checker_callback(Status::OK(), host_lapack_infos); |
| return; |
| } |
| |
| // Launch memcpys to copy info back from the device to the host. |
| for (const auto& dev_lapack_info : dev_lapack_infos) { |
| bool success = true; |
| auto host_copy = dev_lapack_info.CopyToHost(&success); |
| OP_REQUIRES( |
| solver->context(), success, |
| errors::Internal( |
| "Failed to launch copy of dev_lapack_info to host, debug_info = ", |
| dev_lapack_info.debug_info())); |
| host_lapack_infos.push_back(std::move(host_copy)); |
| } |
| |
| // This callback checks that all batch items in all calls were processed |
| // successfully and passes the resulting status to info_checker_callback. |
| auto* stream = solver->context()->op_device_context()->stream(); |
| auto wrapped_info_checker_callback = |
| [stream]( |
| CudaSolver* solver, |
| std::function<void(const Status&, const std::vector<HostLapackInfo>&)> |
| info_checker_callback, |
| std::vector<HostLapackInfo> host_lapack_infos) { |
| ScopedActivateExecutorContext scoped_activation{stream->parent()}; |
| Status status; |
| for (const auto& host_lapack_info : host_lapack_infos) { |
| for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) { |
| const int info_value = host_lapack_info(i); |
| if (info_value != 0) { |
| status = errors::InvalidArgument( |
| "Got info = ", info_value, " for batch index ", i, |
| ", expected info = 0. Debug_info = ", |
| host_lapack_info.debug_info()); |
| } |
| } |
| if (!status.ok()) { |
| break; |
| } |
| } |
| // Delete solver to release the scratch tensor refs. |
| delete solver; |
| |
| // Delegate further error checking to provided functor. |
| info_checker_callback(status, host_lapack_infos); |
| }; |
| // Note: A std::function must be copy constructible, and therefore so must |
| // its arguments, so it cannot bind a unique_ptr. We therefore release the |
| // solver into a raw pointer that is deleted at the end of |
| // wrapped_info_checker_callback. |
| auto solver_raw_ptr = solver.release(); |
| auto cb = |
| std::bind(wrapped_info_checker_callback, solver_raw_ptr, |
| std::move(info_checker_callback), std::move(host_lapack_infos)); |
| |
| solver_raw_ptr->context() |
| ->device() |
| ->tensorflow_gpu_device_info() |
| ->event_mgr->ThenExecute(stream, std::move(cb)); |
| } |
| |
| // static |
| void CudaSolver::CheckLapackInfoAndDeleteSolverAsync( |
| std::unique_ptr<CudaSolver> solver, |
| const std::vector<DeviceLapackInfo>& dev_lapack_info, |
| AsyncOpKernel::DoneCallback done) { |
| OpKernelContext* context = solver->context(); |
| auto wrapped_done = [context, done]( |
| const Status& status, |
| const std::vector<HostLapackInfo>& /* unused */) { |
| if (done != nullptr) { |
| OP_REQUIRES_OK_ASYNC(context, status, done); |
| done(); |
| } else { |
| OP_REQUIRES_OK(context, status); |
| } |
| }; |
| CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info, |
| wrapped_done); |
| } |
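| |
| // Usage sketch from a typical async op kernel (a hypothetical Cholesky op; |
| // 'dev_matrix_ptr' and the surrounding kernel state are illustrative): |
| // |
| //   std::unique_ptr<CudaSolver> solver(new CudaSolver(context)); |
| //   std::vector<DeviceLapackInfo> dev_info; |
| //   dev_info.push_back(solver->GetDeviceLapackInfo(1, "potrf")); |
| //   OP_REQUIRES_OK_ASYNC( |
| //       context, |
| //       solver->Potrf(CUBLAS_FILL_MODE_UPPER, n, dev_matrix_ptr, n, |
| //                     dev_info.back().mutable_data()), |
| //       done); |
| //   CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), |
| //                                                   dev_info, |
| //                                                   std::move(done)); |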
| |
| // Allocates a temporary tensor. The CudaSolver object maintains a |
| // TensorReference to the underlying Tensor to prevent it from being deallocated |
| // prematurely. |
| Status CudaSolver::allocate_scoped_tensor(DataType type, |
| const TensorShape& shape, |
| Tensor* out_temp) { |
| const Status status = context_->allocate_temp(type, shape, out_temp); |
| if (status.ok()) { |
| scratch_tensor_refs_.emplace_back(*out_temp); |
| } |
| return status; |
| } |
| |
| Status CudaSolver::forward_input_or_allocate_scoped_tensor( |
| gtl::ArraySlice<int> candidate_input_indices, DataType type, |
| const TensorShape& shape, Tensor* out_temp) { |
| const Status status = context_->forward_input_or_allocate_temp( |
| candidate_input_indices, type, shape, out_temp); |
| if (status.ok()) { |
| scratch_tensor_refs_.emplace_back(*out_temp); |
| } |
| return status; |
| } |
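| |
| // Usage sketch (illustrative): allocate device scratch memory that stays |
| // alive until this CudaSolver object is destroyed: |
| // |
| //   Tensor scratch; |
| //   TF_RETURN_IF_ERROR(solver->allocate_scoped_tensor( |
| //       DataTypeToEnum<float>::value, TensorShape({n, n}), &scratch)); |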
| |
| // Macros that specialize a solver method for all four standard numeric |
| // types, or for the real types only when there is no complex variant. |
| #define TF_CALL_LAPACK_TYPES(m) \ |
| m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z) |
| #define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D) |
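| |
| // For example, TF_CALL_LAPACK_TYPES(GEAM_INSTANCE) below expands to |
| // |
| //   GEAM_INSTANCE(float, S) GEAM_INSTANCE(double, D) |
| //   GEAM_INSTANCE(std::complex<float>, C) |
| //   GEAM_INSTANCE(std::complex<double>, Z) |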
| |
| // Macros to construct cusolverDn method names. |
| #define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method |
| #define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method |
| #define DN_BUFSIZE_FN(method, type_prefix) \ |
| cusolverDn##type_prefix##method##_bufferSize |
| |
| // Macros to construct cublas method names. |
| #define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method |
| #define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method |
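| |
| // For example, with method == potrf and type_prefix == S these expand to: |
| // |
| //   DN_SOLVER_FN(potrf, S)    -> cusolverDnSpotrf |
| //   DN_BUFSIZE_FN(potrf, S)   -> cusolverDnSpotrf_bufferSize |
| //   DN_SOLVER_NAME(potrf, S)  -> "cusolverDnSpotrf" |
| //   BLAS_SOLVER_FN(geam, S)   -> cublasSgeam |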
| |
| //============================================================================= |
| // Wrappers of cuSolverDN computational methods begin here. |
| // |
| // WARNING to implementers: The function signatures listed in the online docs |
| // are sometimes inaccurate, e.g., they may be missing 'const' on pointers |
| // to immutable arguments, while the actual headers have them as expected. |
| // Check the actual declarations in the cusolver_api.h header file. |
| // |
| // NOTE: The cuSolver functions called below appear not to be thread-safe, |
| // so we put a global lock around the calls. Since these functions only |
| // enqueue a kernel on the shared stream, this is not a big performance hit. |
| // TODO(rmlarsen): Investigate whether the locking is still needed in CUDA 9. |
| //============================================================================= |
| |
| template <typename Scalar, typename SolverFnT> |
| static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle, |
| cublasOperation_t transa, |
| cublasOperation_t transb, int m, int n, |
| const Scalar* alpha, /* host or device pointer */ |
| const Scalar* A, int lda, |
| const Scalar* beta, /* host or device pointer */ |
| const Scalar* B, int ldb, Scalar* C, int ldc) { |
| mutex_lock lock(handle_map_mutex); |
| using CudaScalar = typename CUDAComplexT<Scalar>::type; |
| TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n, |
| reinterpret_cast<const CudaScalar*>(alpha), |
| reinterpret_cast<const CudaScalar*>(A), lda, |
| reinterpret_cast<const CudaScalar*>(beta), |
| reinterpret_cast<const CudaScalar*>(B), ldb, |
| reinterpret_cast<CudaScalar*>(C), ldc)); |
| return Status::OK(); |
| } |
| |
| #define GEAM_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::Geam<Scalar>( \ |
| cublasOperation_t transa, cublasOperation_t transb, int m, int n, \ |
| const Scalar* alpha, /* host or device pointer */ \ |
| const Scalar* A, int lda, \ |
| const Scalar* beta, /* host or device pointer */ \ |
| const Scalar* B, int ldb, Scalar* C, int ldc) const { \ |
| return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \ |
| transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES(GEAM_INSTANCE); |
| |
| template <typename Scalar, typename BufSizeFnT, typename SolverFnT> |
| static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver, |
| CudaSolver* cuda_solver, |
| OpKernelContext* context, |
| cusolverDnHandle_t cusolver_dn_handle, |
| cublasFillMode_t uplo, int n, Scalar* A, int lda, |
| int* dev_lapack_info) { |
| mutex_lock lock(handle_map_mutex); |
| /* Get amount of workspace memory required. */ |
| int lwork; |
| TF_RETURN_IF_CUSOLVER_ERROR( |
| bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork)); |
| /* Allocate device memory for workspace. */ |
| auto dev_workspace = |
| cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false); |
| /* Launch the solver kernel. */ |
| TF_RETURN_IF_CUSOLVER_ERROR(solver( |
| cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, |
| CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info)); |
| return Status::OK(); |
| } |
| |
| #define POTRF_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A, \ |
| int lda, int* dev_lapack_info) { \ |
| return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix), \ |
| DN_SOLVER_FN(potrf, type_prefix), this, context_, \ |
| cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES(POTRF_INSTANCE); |
| |
| #if CUDA_VERSION >= 9020 |
| template <typename Scalar, typename SolverFnT> |
| static inline Status PotrfBatchedImpl( |
| SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context, |
| cusolverDnHandle_t cusolver_dn_handle, cublasFillMode_t uplo, int n, |
| const Scalar* const host_a_dev_ptrs[], int lda, |
| DeviceLapackInfo* dev_lapack_info, int batch_size) { |
| mutex_lock lock(handle_map_mutex); |
| using CudaScalar = typename CUDAComplexT<Scalar>::type; |
| ScratchSpace<uint8> dev_a_dev_ptrs = |
| cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "", |
| /* on_host */ false); |
| if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */, |
| host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) { |
| return errors::Internal("PotrfBatched: failed to copy pointers to device"); |
| } |
| TF_RETURN_IF_CUSOLVER_ERROR( |
| solver(cusolver_dn_handle, uplo, n, |
| reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda, |
| dev_lapack_info->mutable_data(), batch_size)); |
| return Status::OK(); |
| } |
| |
| #define POTRF_BATCHED_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::PotrfBatched( \ |
| cublasFillMode_t uplo, int n, const Scalar* const host_a_dev_ptrs[], \ |
| int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) { \ |
| return PotrfBatchedImpl(DN_SOLVER_FN(potrfBatched, type_prefix), this, \ |
| context_, cusolver_dn_handle_, uplo, n, \ |
| host_a_dev_ptrs, lda, dev_lapack_info, \ |
| batch_size); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE); |
| #endif // CUDA_VERSION >= 9020 |
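| |
| // Usage sketch for the batched API (illustrative; 'matrices' stands for the |
| // device base address of a flattened [batch_size, n, n] tensor). The |
| // host_a_dev_ptrs argument is a host-side array holding the device address |
| // of each batch item: |
| // |
| //   std::vector<float*> host_ptrs(batch_size); |
| //   for (int i = 0; i < batch_size; ++i) { |
| //     host_ptrs[i] = matrices + i * n * n;  // device pointer to batch i |
| //   } |
| //   auto dev_info = solver->GetDeviceLapackInfo(batch_size, "potrfBatched"); |
| //   TF_RETURN_IF_ERROR(solver->PotrfBatched(CUBLAS_FILL_MODE_UPPER, n, |
| //                                           host_ptrs.data(), n, &dev_info, |
| //                                           batch_size)); |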
| |
| template <typename Scalar, typename BufSizeFnT, typename SolverFnT> |
| static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver, |
| CudaSolver* cuda_solver, |
| OpKernelContext* context, |
| cusolverDnHandle_t cusolver_dn_handle, int m, |
| int n, Scalar* A, int lda, int* dev_pivots, |
| int* dev_lapack_info) { |
| mutex_lock lock(handle_map_mutex); |
| /* Get amount of workspace memory required. */ |
| int lwork; |
| TF_RETURN_IF_CUSOLVER_ERROR( |
| bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork)); |
| /* Allocate device memory for workspace. */ |
| auto dev_workspace = |
| cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false); |
| /* Launch the solver kernel. */ |
| TF_RETURN_IF_CUSOLVER_ERROR(solver( |
| cusolver_dn_handle, m, n, CUDAComplex(A), lda, |
| CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info)); |
| return Status::OK(); |
| } |
| |
| #define GETRF_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda, \ |
| int* dev_pivots, int* dev_lapack_info) { \ |
| return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix), \ |
| DN_SOLVER_FN(getrf, type_prefix), this, context_, \ |
| cusolver_dn_handle_, m, n, A, lda, dev_pivots, \ |
| dev_lapack_info); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES(GETRF_INSTANCE); |
| |
| template <typename Scalar, typename SolverFnT> |
| static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context, |
| cusolverDnHandle_t cusolver_dn_handle, |
| cublasOperation_t trans, int n, int nrhs, |
| const Scalar* A, int lda, const int* pivots, |
| Scalar* B, int ldb, int* dev_lapack_info) { |
| mutex_lock lock(handle_map_mutex); |
| /* Launch the solver kernel. */ |
| TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs, |
| CUDAComplex(A), lda, pivots, |
| CUDAComplex(B), ldb, dev_lapack_info)); |
| return Status::OK(); |
| } |
| |
| #define GETRS_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::Getrs<Scalar>( \ |
| cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda, \ |
| const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const { \ |
| return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_, \ |
| cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \ |
| ldb, dev_lapack_info); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES(GETRS_INSTANCE); |
| |
| template <typename Scalar, typename BufSizeFnT, typename SolverFnT> |
| static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver, |
| CudaSolver* cuda_solver, |
| OpKernelContext* context, |
| cusolverDnHandle_t cusolver_dn_handle, int m, |
| int n, Scalar* A, int lda, Scalar* tau, |
| int* dev_lapack_info) { |
| mutex_lock lock(handle_map_mutex); |
| /* Get amount of workspace memory required. */ |
| int lwork; |
| TF_RETURN_IF_CUSOLVER_ERROR( |
| bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork)); |
| /* Allocate device memory for workspace. */ |
| auto dev_workspace = |
| cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false); |
| /* Launch the solver kernel. */ |
| TF_RETURN_IF_CUSOLVER_ERROR(solver( |
| cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau), |
| CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info)); |
| return Status::OK(); |
| } |
| |
| #define GEQRF_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda, \ |
| Scalar* tau, int* dev_lapack_info) { \ |
| return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix), \ |
| DN_SOLVER_FN(geqrf, type_prefix), this, context_, \ |
| cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE); |
| |
| template <typename Scalar, typename BufSizeFnT, typename SolverFnT> |
| static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver, |
| CudaSolver* cuda_solver, |
| OpKernelContext* context, |
| cusolverDnHandle_t cusolver_dn_handle, |
| cublasSideMode_t side, cublasOperation_t trans, |
| int m, int n, int k, const Scalar* dev_a, |
| int lda, const Scalar* dev_tau, Scalar* dev_c, |
| int ldc, int* dev_lapack_info) { |
| mutex_lock lock(handle_map_mutex); |
| /* Get amount of workspace memory required. */ |
| int lwork; |
| TF_RETURN_IF_CUSOLVER_ERROR( |
| bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda, |
| CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork)); |
| /* Allocate device memory for workspace. */ |
| auto dev_workspace = |
| cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false); |
| /* Launch the solver kernel. */ |
| TF_RETURN_IF_CUSOLVER_ERROR(solver( |
| cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda, |
| CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, |
| CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info)); |
| return Status::OK(); |
| } |
| |
| // Unfortunately the LAPACK function name differs between the real and |
| // complex cases (the real ones are prefixed with "or" for "orthogonal" and |
| // the complex ones with "un" for "unitary"), so we instantiate each one |
| // separately. |
| #define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix) \ |
| template <> \ |
| Status CudaSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans, \ |
| int m, int n, int k, const Scalar* dev_a, int lda, \ |
| const Scalar* dev_tau, Scalar* dev_c, int ldc, \ |
| int* dev_lapack_info) { \ |
| return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix), \ |
| DN_SOLVER_FN(function_prefix##mqr, type_prefix), this, \ |
| context_, cusolver_dn_handle_, side, trans, m, n, k, \ |
| dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info); \ |
| } |
| |
| UNMQR_INSTANCE(float, or, S); |
| UNMQR_INSTANCE(double, or, D); |
| UNMQR_INSTANCE(complex64, un, C); |
| UNMQR_INSTANCE(complex128, un, Z); |
| |
| template <typename Scalar, typename BufSizeFnT, typename SolverFnT> |
| static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver, |
| CudaSolver* cuda_solver, |
| OpKernelContext* context, |
| cusolverDnHandle_t cusolver_dn_handle, int m, |
| int n, int k, Scalar* dev_a, int lda, |
| const Scalar* dev_tau, int* dev_lapack_info) { |
| mutex_lock lock(handle_map_mutex); |
| /* Get amount of workspace memory required. */ |
| int lwork; |
| TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k, |
| CUDAComplex(dev_a), lda, |
| CUDAComplex(dev_tau), &lwork)); |
| /* Allocate device memory for workspace. */ |
| auto dev_workspace = |
| cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false); |
| /* Launch the solver kernel. */ |
| TF_RETURN_IF_CUSOLVER_ERROR( |
| solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda, |
| CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()), |
| lwork, dev_lapack_info)); |
| return Status::OK(); |
| } |
| |
| #define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix) \ |
| template <> \ |
| Status CudaSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda, \ |
| const Scalar* dev_tau, int* dev_lapack_info) { \ |
| return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix), \ |
| DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \ |
| context_, cusolver_dn_handle_, m, n, k, dev_a, lda, \ |
| dev_tau, dev_lapack_info); \ |
| } |
| |
| UNGQR_INSTANCE(float, or, S); |
| UNGQR_INSTANCE(double, or, D); |
| UNGQR_INSTANCE(complex64, un, C); |
| UNGQR_INSTANCE(complex128, un, Z); |
| |
| template <typename Scalar, typename BufSizeFnT, typename SolverFnT> |
| static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver, |
| CudaSolver* cuda_solver, |
| OpKernelContext* context, |
| cusolverDnHandle_t cusolver_dn_handle, |
| cusolverEigMode_t jobz, cublasFillMode_t uplo, |
| int n, Scalar* dev_A, int lda, |
| typename Eigen::NumTraits<Scalar>::Real* dev_W, |
| int* dev_lapack_info) { |
| mutex_lock lock(handle_map_mutex); |
| /* Get amount of workspace memory required. */ |
| int lwork; |
| TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n, |
| CUDAComplex(dev_A), lda, |
| CUDAComplex(dev_W), &lwork)); |
| /* Allocate device memory for workspace. */ |
| auto dev_workspace = |
| cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false); |
| /* Launch the solver kernel. */ |
| TF_RETURN_IF_CUSOLVER_ERROR( |
| solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda, |
| CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()), |
| lwork, dev_lapack_info)); |
| return Status::OK(); |
| } |
| |
| #define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix) \ |
| template <> \ |
| Status CudaSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, \ |
| int n, Scalar* dev_A, int lda, \ |
| typename Eigen::NumTraits<Scalar>::Real* dev_W, \ |
| int* dev_lapack_info) { \ |
| return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix), \ |
| DN_SOLVER_FN(function_prefix##evd, type_prefix), this, \ |
| context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \ |
| dev_W, dev_lapack_info); \ |
| } |
| |
| HEEVD_INSTANCE(float, sy, S); |
| HEEVD_INSTANCE(double, sy, D); |
| HEEVD_INSTANCE(complex64, he, C); |
| HEEVD_INSTANCE(complex128, he, Z); |
| |
| template <typename Scalar, typename BufSizeFnT, typename SolverFnT> |
| static inline Status GesvdImpl( |
| BufSizeFnT bufsize, SolverFnT solver, CudaSolver* cuda_solver, |
| OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle, |
| signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda, |
| Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) { |
| mutex_lock lock(handle_map_mutex); |
| /* Get amount of workspace memory required. */ |
| int lwork; |
| TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork)); |
| /* Allocate device memory for workspace. */ |
| auto dev_workspace = |
| cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false); |
| TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n, |
| CUDAComplex(A), lda, S, CUDAComplex(U), |
| ldu, CUDAComplex(VT), ldvt, |
| CUDAComplex(dev_workspace.mutable_data()), |
| lwork, nullptr, dev_lapack_info)); |
| return Status::OK(); |
| } |
| |
| #define GESVD_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::Gesvd<Scalar>( \ |
| signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A, \ |
| int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, \ |
| int ldvt, int* dev_lapack_info) { \ |
| return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix), \ |
| DN_SOLVER_FN(gesvd, type_prefix), this, context_, \ |
| cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \ |
| dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE); |
| |
| template <typename Scalar, typename BufSizeFnT, typename SolverFnT> |
| static inline Status GesvdjBatchedImpl(BufSizeFnT bufsize, SolverFnT solver, |
| CudaSolver* cuda_solver, |
| OpKernelContext* context, |
| cusolverDnHandle_t cusolver_dn_handle, |
| cusolverEigMode_t jobz, int m, int n, |
| Scalar* A, int lda, Scalar* S, Scalar* U, |
| int ldu, Scalar* V, int ldv, |
| int* dev_lapack_info, int batch_size) { |
| mutex_lock lock(handle_map_mutex); |
| /* Get amount of workspace memory required. */ |
| int lwork; |
| /* Default parameters for gesvdj and gesvdjBatched. */ |
| gesvdjInfo_t svdj_info; |
| TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnCreateGesvdjInfo(&svdj_info)); |
| TF_RETURN_IF_CUSOLVER_ERROR(bufsize( |
| cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U), |
| ldu, CUDAComplex(V), ldv, &lwork, svdj_info, batch_size)); |
| /* Allocate device memory for workspace. */ |
| auto dev_workspace = |
| cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false); |
| TF_RETURN_IF_CUSOLVER_ERROR(solver( |
| cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U), |
| ldu, CUDAComplex(V), ldv, CUDAComplex(dev_workspace.mutable_data()), |
| lwork, dev_lapack_info, svdj_info, batch_size)); |
| TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnDestroyGesvdjInfo(svdj_info)); |
| return Status::OK(); |
| } |
| |
| #define GESVDJBATCHED_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::GesvdjBatched<Scalar>( \ |
| cusolverEigMode_t jobz, int m, int n, Scalar* dev_A, int lda, \ |
| Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_V, int ldv, \ |
| int* dev_lapack_info, int batch_size) { \ |
| return GesvdjBatchedImpl(DN_BUFSIZE_FN(gesvdjBatched, type_prefix), \ |
| DN_SOLVER_FN(gesvdjBatched, type_prefix), this, \ |
| context_, cusolver_dn_handle_, jobz, m, n, dev_A, \ |
| lda, dev_S, dev_U, ldu, dev_V, ldv, \ |
| dev_lapack_info, batch_size); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVDJBATCHED_INSTANCE); |
| |
| //============================================================================= |
| // Wrappers of cuBLAS computational methods begin here. |
| // |
| // WARNING to implementers: The function signatures listed in the online docs |
| // are sometimes inaccurate, e.g., they may be missing 'const' on pointers |
| // to immutable arguments, while the actual headers have them as expected. |
| // Check the actual declarations in the cublas_api.h header file. |
| //============================================================================= |
| template <typename Scalar, typename SolverFnT> |
| static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver, |
| OpKernelContext* context, |
| cublasHandle_t cublas_handle, int n, |
| const Scalar* const host_a_dev_ptrs[], |
| int lda, int* dev_pivots, |
| DeviceLapackInfo* dev_lapack_info, |
| int batch_size) { |
| mutex_lock lock(handle_map_mutex); |
| using CudaScalar = typename CUDAComplexT<Scalar>::type; |
| ScratchSpace<uint8> dev_a_dev_ptrs = |
| cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "", |
| /* on_host */ false); |
| if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */, |
| host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) { |
| return errors::Internal("GetrfBatched: failed to copy pointers to device"); |
| } |
| TF_RETURN_IF_CUBLAS_ERROR( |
| solver(cublas_handle, n, |
| reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda, |
| dev_pivots, dev_lapack_info->mutable_data(), batch_size)); |
| return Status::OK(); |
| } |
| |
| #define GETRF_BATCHED_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::GetrfBatched( \ |
| int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots, \ |
| DeviceLapackInfo* dev_lapack_info, int batch_size) { \ |
| return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this, \ |
| context_, cublas_handle_, n, host_a_dev_ptrs, lda, \ |
| dev_pivots, dev_lapack_info, batch_size); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE); |
| |
| template <typename Scalar, typename SolverFnT> |
| static inline Status GetrsBatchedImpl( |
| SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context, |
| cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs, |
| const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, |
| const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info, |
| int batch_size) { |
| mutex_lock lock(handle_map_mutex); |
| using CudaScalar = typename CUDAComplexT<Scalar>::type; |
| ScratchSpace<uint8> dev_a_dev_ptrs = |
| cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "", |
| /* on_host */ false); |
| ScratchSpace<uint8> dev_b_dev_ptrs = |
| cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "", |
| /* on_host */ false); |
| if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */, |
| host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) { |
| return errors::Internal("GetrsBatched: failed to copy pointers to device"); |
| } |
| if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */, |
| host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) { |
| return errors::Internal("GetrsBatched: failed to copy pointers to device"); |
| } |
| TF_RETURN_IF_CUBLAS_ERROR(solver( |
| cublas_handle, trans, n, nrhs, |
| reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda, |
| dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()), |
| ldb, host_lapack_info, batch_size)); |
| return Status::OK(); |
| } |
| |
| #define GETRS_BATCHED_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::GetrsBatched( \ |
| cublasOperation_t trans, int n, int nrhs, \ |
| const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, \ |
| const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info, \ |
| int batch_size) { \ |
| return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>( \ |
| BLAS_SOLVER_FN(getrsBatched, type_prefix)), \ |
| this, context_, cublas_handle_, trans, n, nrhs, \ |
| host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \ |
| ldb, host_lapack_info, batch_size); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE); |
| |
| template <typename Scalar, typename SolverFnT> |
| static inline Status GetriBatchedImpl( |
| SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context, |
| cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[], |
| int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[], |
| int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) { |
| mutex_lock lock(handle_map_mutex); |
| using CudaScalar = typename CUDAComplexT<Scalar>::type; |
| ScratchSpace<uint8> dev_a_dev_ptrs = |
| cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "", |
| /* on_host */ false); |
| ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>( |
| sizeof(CudaScalar*) * batch_size, "", /* on_host */ false); |
| if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */, |
| host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) || |
| !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(), |
| host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) { |
| return errors::Internal("GetriBatched: failed to copy pointers to device"); |
| } |
| TF_RETURN_IF_CUBLAS_ERROR( |
| solver(cublas_handle, n, |
| reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), |
| lda, dev_pivots, |
| reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), |
| ldainv, dev_lapack_info->mutable_data(), batch_size)); |
| return Status::OK(); |
| } |
| |
| #define GETRI_BATCHED_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::GetriBatched( \ |
| int n, const Scalar* const host_a_dev_ptrs[], int lda, \ |
| const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[], \ |
| int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) { \ |
| return GetriBatchedImpl( \ |
| reinterpret_cast<getri_##type_prefix*>( \ |
| BLAS_SOLVER_FN(getriBatched, type_prefix)), \ |
| this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \ |
| host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE); |
| |
| template <typename Scalar, typename SolverFnT> |
| static inline Status MatInvBatchedImpl( |
| SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context, |
| cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[], |
| int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv, |
| DeviceLapackInfo* dev_lapack_info, int batch_size) { |
| mutex_lock lock(handle_map_mutex); |
| using CudaScalar = typename CUDAComplexT<Scalar>::type; |
| ScratchSpace<uint8> dev_a_dev_ptrs = |
| cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "", |
| /* on_host */ false); |
| ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>( |
| sizeof(CudaScalar*) * batch_size, "", /* on_host */ false); |
| if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */, |
| host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) || |
| !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(), |
| host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) { |
| return errors::Internal("MatInvBatched: failed to copy pointers to device"); |
| } |
| TF_RETURN_IF_CUBLAS_ERROR(solver( |
| cublas_handle, n, |
| reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda, |
| reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv, |
| dev_lapack_info->mutable_data(), batch_size)); |
| return Status::OK(); |
| } |
| |
| #define MATINV_BATCHED_INSTANCE(Scalar, type_prefix) \ |
| template <> \ |
| Status CudaSolver::MatInvBatched( \ |
| int n, const Scalar* const host_a_dev_ptrs[], int lda, \ |
| const Scalar* const host_a_inv_dev_ptrs[], int ldainv, \ |
| DeviceLapackInfo* dev_lapack_info, int batch_size) { \ |
| return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>( \ |
| BLAS_SOLVER_FN(matinvBatched, type_prefix)), \ |
| this, context_, cublas_handle_, n, \ |
| host_a_dev_ptrs, lda, host_a_inv_dev_ptrs, \ |
| ldainv, dev_lapack_info, batch_size); \ |
| } |
| |
| TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE); |
| |
| } // namespace tensorflow |
| |
| #endif // GOOGLE_CUDA |