| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "rocm/include/rocblas.h" |
| |
| #include "tensorflow/stream_executor/rocm/rocm_blas.h" |
| |
| #define EIGEN_USE_GPU |
| #include <assert.h> |
| |
| #include <complex> |
| |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/str_format.h" |
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
| #include "tensorflow/stream_executor/device_memory.h" |
| #include "tensorflow/stream_executor/gpu/gpu_activation.h" |
| #include "tensorflow/stream_executor/gpu/gpu_executor.h" |
| #include "tensorflow/stream_executor/gpu/gpu_helpers.h" |
| #include "tensorflow/stream_executor/gpu/gpu_stream.h" |
| #include "tensorflow/stream_executor/gpu/gpu_timer.h" |
| #include "tensorflow/stream_executor/lib/env.h" |
| #include "tensorflow/stream_executor/lib/initialize.h" |
| #include "tensorflow/stream_executor/lib/status.h" |
| #include "tensorflow/stream_executor/lib/status_macros.h" |
| #include "tensorflow/stream_executor/platform/dso_loader.h" |
| #include "tensorflow/stream_executor/platform/logging.h" |
| #include "tensorflow/stream_executor/platform/port.h" |
| #include "tensorflow/stream_executor/plugin_registry.h" |
| #include "tensorflow/stream_executor/rocm/rocm_platform_id.h" |
| #include "tensorflow/stream_executor/scratch_allocator.h" |
| #include "tensorflow/stream_executor/stream_executor.h" |
| |
| namespace stream_executor { |
| namespace gpu { |
| |
| PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin); |
| |
| namespace wrap { |
| |
// STREAM_EXECUTOR_ROCBLAS_WRAP(routine) declares a functor object named after
// the rocBLAS routine. Calling the functor constructs a
// ScopedActivateExecutorContext for `parent` and then forwards the remaining
// arguments to the routine. Each functor also exposes the routine name as
// `kName`.
#ifdef PLATFORM_GOOGLE

// Direct-call flavor: the shim invokes the global rocBLAS symbol itself.
#define STREAM_EXECUTOR_ROCBLAS_WRAP(routine) \
  struct WrapperShim__##routine { \
    static const char *kName; \
    template <typename... Args> \
    rocblas_status operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext activation{parent}; \
      return ::routine(args...); \
    } \
  } routine; \
  const char *WrapperShim__##routine::kName = #routine;

#else

// Dynamic-loading flavor: the symbol is looked up by name in the rocBLAS DSO
// on first use (the process dies if it cannot be found) and the resulting
// function pointer is kept in a function-local static.
#define STREAM_EXECUTOR_ROCBLAS_WRAP(routine) \
  struct DynLoadShim__##routine { \
    static const char *kName; \
    using FuncPtrT = std::add_pointer<decltype(::routine)>::type; \
    static void *GetDsoHandle() { \
      auto dso = internal::CachedDsoLoader::GetRocblasDsoHandle(); \
      return dso.ValueOrDie(); \
    } \
    static FuncPtrT LoadOrDie() { \
      void *raw_symbol; \
      auto s = port::Env::Default()->GetSymbolFromLibrary( \
          GetDsoHandle(), kName, &raw_symbol); \
      CHECK(s.ok()) << "could not find " << kName \
                    << " in rocblas DSO; dlerror: " << s.error_message(); \
      return reinterpret_cast<FuncPtrT>(raw_symbol); \
    } \
    static FuncPtrT DynLoad() { \
      static FuncPtrT resolved = LoadOrDie(); \
      return resolved; \
    } \
    template <typename... Args> \
    rocblas_status operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext activation{parent}; \
      return DynLoad()(args...); \
    } \
  } routine; \
  const char *DynLoadShim__##routine::kName = #routine;

#endif

// The V2 spelling is a plain alias of STREAM_EXECUTOR_ROCBLAS_WRAP in both
// build flavors.
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(routine) \
  STREAM_EXECUTOR_ROCBLAS_WRAP(routine)
| |
// X-macro listing the rocBLAS routines that get a wrapper: APPLY is expanded
// once per routine, in the order written. Entries mentioned in "disabled"
// comments are intentionally not expanded.
// clang-format off
#define ROCBLAS_BLAS_ROUTINE_EACH(APPLY) \
  APPLY(rocblas_snrm2) APPLY(rocblas_dnrm2) APPLY(rocblas_scnrm2) APPLY(rocblas_dznrm2) \
  APPLY(rocblas_sdot) APPLY(rocblas_ddot) \
  APPLY(rocblas_cdotu) APPLY(rocblas_cdotc) APPLY(rocblas_zdotu) APPLY(rocblas_zdotc) \
  APPLY(rocblas_sscal) APPLY(rocblas_dscal) APPLY(rocblas_cscal) \
  APPLY(rocblas_csscal) APPLY(rocblas_zscal) APPLY(rocblas_zdscal) \
  APPLY(rocblas_saxpy) APPLY(rocblas_daxpy) APPLY(rocblas_caxpy) APPLY(rocblas_zaxpy) \
  APPLY(rocblas_scopy) APPLY(rocblas_dcopy) APPLY(rocblas_ccopy) APPLY(rocblas_zcopy) \
  APPLY(rocblas_sswap) APPLY(rocblas_dswap) APPLY(rocblas_cswap) APPLY(rocblas_zswap) \
  APPLY(rocblas_isamax) APPLY(rocblas_idamax) APPLY(rocblas_icamax) APPLY(rocblas_izamax) \
  APPLY(rocblas_isamin) APPLY(rocblas_idamin) APPLY(rocblas_icamin) APPLY(rocblas_izamin) \
  APPLY(rocblas_sasum) APPLY(rocblas_dasum) APPLY(rocblas_scasum) APPLY(rocblas_dzasum) \
  APPLY(rocblas_srot) APPLY(rocblas_drot) APPLY(rocblas_crot) \
  APPLY(rocblas_csrot) APPLY(rocblas_zrot) APPLY(rocblas_zdrot) \
  APPLY(rocblas_srotg) APPLY(rocblas_drotg) APPLY(rocblas_crotg) APPLY(rocblas_zrotg) \
  APPLY(rocblas_srotm) APPLY(rocblas_drotm) \
  APPLY(rocblas_srotmg) APPLY(rocblas_drotmg) \
  APPLY(rocblas_sgemv) APPLY(rocblas_dgemv) APPLY(rocblas_cgemv) APPLY(rocblas_zgemv) \
  APPLY(rocblas_sgbmv) APPLY(rocblas_dgbmv) APPLY(rocblas_cgbmv) APPLY(rocblas_zgbmv) \
  APPLY(rocblas_strmv) APPLY(rocblas_dtrmv) APPLY(rocblas_ctrmv) APPLY(rocblas_ztrmv) \
  APPLY(rocblas_stbmv) APPLY(rocblas_dtbmv) APPLY(rocblas_ctbmv) APPLY(rocblas_ztbmv) \
  APPLY(rocblas_stpmv) APPLY(rocblas_dtpmv) APPLY(rocblas_ctpmv) APPLY(rocblas_ztpmv) \
  APPLY(rocblas_strsv) APPLY(rocblas_dtrsv) APPLY(rocblas_ctrsv) APPLY(rocblas_ztrsv) \
  APPLY(rocblas_stpsv) APPLY(rocblas_dtpsv) APPLY(rocblas_ctpsv) APPLY(rocblas_ztpsv) \
  APPLY(rocblas_stbsv) APPLY(rocblas_dtbsv) APPLY(rocblas_ctbsv) APPLY(rocblas_ztbsv) \
  APPLY(rocblas_ssymv) APPLY(rocblas_dsymv) \
  /* disabled: rocblas_csymv, rocblas_zsymv */ \
  APPLY(rocblas_chemv) APPLY(rocblas_zhemv) \
  APPLY(rocblas_ssbmv) APPLY(rocblas_dsbmv) \
  APPLY(rocblas_chbmv) APPLY(rocblas_zhbmv) \
  APPLY(rocblas_sspmv) APPLY(rocblas_dspmv) \
  APPLY(rocblas_chpmv) APPLY(rocblas_zhpmv) \
  APPLY(rocblas_sger) APPLY(rocblas_dger) \
  APPLY(rocblas_cgeru) APPLY(rocblas_cgerc) APPLY(rocblas_zgeru) APPLY(rocblas_zgerc) \
  APPLY(rocblas_ssyr) APPLY(rocblas_dsyr) \
  /* disabled: rocblas_csyr, rocblas_zsyr */ \
  APPLY(rocblas_cher) APPLY(rocblas_zher) \
  APPLY(rocblas_sspr) APPLY(rocblas_dspr) \
  APPLY(rocblas_chpr) APPLY(rocblas_zhpr) \
  APPLY(rocblas_ssyr2) APPLY(rocblas_dsyr2) \
  /* disabled: rocblas_csyr2, rocblas_zsyr2 */ \
  APPLY(rocblas_cher2) APPLY(rocblas_zher2) \
  APPLY(rocblas_sspr2) APPLY(rocblas_dspr2) \
  APPLY(rocblas_chpr2) APPLY(rocblas_zhpr2) \
  APPLY(rocblas_sgemm) APPLY(rocblas_dgemm) APPLY(rocblas_hgemm) \
  APPLY(rocblas_cgemm) APPLY(rocblas_zgemm) \
  APPLY(rocblas_ssyrk) APPLY(rocblas_dsyrk) APPLY(rocblas_csyrk) APPLY(rocblas_zsyrk) \
  APPLY(rocblas_cherk) APPLY(rocblas_zherk) \
  APPLY(rocblas_ssyr2k) APPLY(rocblas_dsyr2k) APPLY(rocblas_csyr2k) APPLY(rocblas_zsyr2k) \
  APPLY(rocblas_cher2k) APPLY(rocblas_zher2k) \
  /* disabled: rocblas_ssyrkx, rocblas_dsyrkx, rocblas_csyrkx, */ \
  /*           rocblas_zsyrkx, rocblas_cherkx, rocblas_zherkx  */ \
  APPLY(rocblas_ssymm) APPLY(rocblas_dsymm) APPLY(rocblas_csymm) APPLY(rocblas_zsymm) \
  APPLY(rocblas_chemm) APPLY(rocblas_zhemm) \
  APPLY(rocblas_strsm) APPLY(rocblas_dtrsm) APPLY(rocblas_ctrsm) APPLY(rocblas_ztrsm) \
  APPLY(rocblas_strmm) APPLY(rocblas_dtrmm) APPLY(rocblas_ctrmm) APPLY(rocblas_ztrmm) \
  APPLY(rocblas_sgeam) APPLY(rocblas_dgeam) \
  APPLY(rocblas_gemm_ex) \
  APPLY(rocblas_gemm_strided_batched_ex) \
  /* disabled: rocblas_cgeam, rocblas_zgeam, rocblas_sdgmm, */ \
  /*           rocblas_ddgmm, rocblas_cdgmm, rocblas_zdgmm  */
// clang-format on
| |
| STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_create_handle) |
| STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_destroy_handle) |
| STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_stream) |
| // STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_pointer_mode) |
| // STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_get_pointer_mode) |
| // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_batched) |
| STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched) |
| STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched) |
| // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched) |
| STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched) |
| STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_strided_batched) |
| STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_strided_batched) |
| // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched) |
| // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched) |
| ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP) |
| |
| } // namespace wrap |
| |
| template <class T> |
| const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast( |
| const DeviceMemory<T> &a) { |
| return reinterpret_cast< |
| const typename RocBlasTypeConversionHelper<T>::mapped_type *>( |
| GpuMemory(a)); |
| } |
| |
| template <class T> |
| const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast( |
| const T &a) { |
| return reinterpret_cast< |
| const typename RocBlasTypeConversionHelper<T>::mapped_type *>(&a); |
| } |
| template <class T> |
| typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast( |
| DeviceMemory<T> *a) { |
| return reinterpret_cast< |
| typename RocBlasTypeConversionHelper<T>::mapped_type *>( |
| GpuMemoryMutable(a)); |
| } |
| |
// Logging hook invoked by some of the BLAS entry points; currently a no-op
// that ignores its argument.
static void blas_log(const char *c) { (void)c; }
| |
| static string ToString(rocblas_status status) { |
| switch (status) { |
| case rocblas_status_success: |
| return "rocblas_status_success"; |
| case rocblas_status_invalid_handle: |
| return "rocblas_status_invalid_handle"; |
| case rocblas_status_not_implemented: |
| return "rocblas_status_not_implemented"; |
| case rocblas_status_invalid_pointer: |
| return "rocblas_status_invalid_pointer"; |
| case rocblas_status_invalid_size: |
| return "rocblas_status_invalid_size"; |
| case rocblas_status_memory_error: |
| return "rocblas_status_memory_error"; |
| case rocblas_status_internal_error: |
| return "rocblas_status_internal_error"; |
| default: |
| return absl::StrCat("<invalid rocBLAS status: ", status, ">"); |
| } |
| } |
| |
| bool ROCMBlas::Init() { |
| rocblas_status ret = wrap::rocblas_create_handle(parent_, &blas_); |
| if (ret != rocblas_status_success) { |
| LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret); |
| return false; |
| } |
| |
| return true; |
| } |
| |
| ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent) |
| : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {} |
| |
| ROCMBlas::~ROCMBlas() { |
| if (blas_ != nullptr) { |
| wrap::rocblas_destroy_handle(parent_, blas_); |
| } |
| } |
| |
| bool ROCMBlas::SetStream(Stream *stream) { |
| CHECK(stream != nullptr); |
| CHECK(AsGpuStreamValue(stream) != nullptr); |
| CHECK(blas_ != nullptr); |
| rocblas_status ret = |
| wrap::rocblas_set_stream(parent_, blas_, AsGpuStreamValue(stream)); |
| if (ret != rocblas_status_success) { |
| LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret); |
| return false; |
| } |
| |
| return true; |
| } |
| |
| namespace { |
| |
| // Helper functions transforming blas arguments into rocBLAS arguments. |
| |
| rocblas_operation ROCMBlasTranspose(blas::Transpose trans) { |
| switch (trans) { |
| case blas::Transpose::kNoTranspose: |
| return rocblas_operation_none; |
| case blas::Transpose::kTranspose: |
| return rocblas_operation_transpose; |
| case blas::Transpose::kConjugateTranspose: |
| return rocblas_operation_conjugate_transpose; |
| default: |
| LOG(FATAL) << "Invalid value of blas::Transpose."; |
| } |
| } |
| |
| rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) { |
| switch (uplo) { |
| case blas::UpperLower::kUpper: |
| return rocblas_fill_upper; |
| case blas::UpperLower::kLower: |
| return rocblas_fill_lower; |
| default: |
| LOG(FATAL) << "Invalid value of blas::UpperLower."; |
| } |
| } |
| |
| rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) { |
| switch (diag) { |
| case blas::Diagonal::kUnit: |
| return rocblas_diagonal_unit; |
| case blas::Diagonal::kNonUnit: |
| return rocblas_diagonal_non_unit; |
| default: |
| LOG(FATAL) << "Invalid value of blas::Diagonal."; |
| } |
| } |
| |
| rocblas_side ROCMBlasSide(blas::Side side) { |
| switch (side) { |
| case blas::Side::kLeft: |
| return rocblas_side_left; |
| case blas::Side::kRight: |
| return rocblas_side_right; |
| default: |
| LOG(FATAL) << "Invalid value of blas::Side."; |
| } |
| } |
| |
| } // namespace |
| |
| template <typename FuncT, typename... Args> |
| bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, |
| bool pointer_mode_host, bool err_on_failure, |
| Args... args) { |
| absl::MutexLock lock{&mu_}; |
| |
| CHECK(blas_ != nullptr); |
| if (!SetStream(stream)) { |
| return false; |
| } |
| |
| rocblas_status ret = rocblas_func(parent_, blas_, args...); |
| if (err_on_failure && ret != rocblas_status_success) { |
| LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": " |
| << ToString(ret); |
| } |
| return ret == rocblas_status_success; |
| } |
| |
| bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *result) { |
| return DoBlasInternal(wrap::rocblas_sasum, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| GpuMemory(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *result) { |
| return DoBlasInternal(wrap::rocblas_dasum, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| GpuMemory(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<float> *result) { |
| return DoBlasInternal(wrap::rocblas_scasum, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<double> *result) { |
| return DoBlasInternal(wrap::rocblas_dzasum, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *y, int incy) { |
| blas_log("DoBlasAxpy"); |
| return DoBlasInternal(wrap::rocblas_saxpy, stream, |
| /* pointer_mode_host = */ true, elem_count, &alpha, |
| GpuMemory(x), incx, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *y, int incy) { |
| blas_log("DoBlasAxpy"); |
| return DoBlasInternal(wrap::rocblas_daxpy, stream, |
| /* pointer_mode_host = */ true, elem_count, &alpha, |
| GpuMemory(x), incx, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_caxpy, stream, /* pointer_mode_host = */ true, elem_count, |
| complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_zaxpy, stream, /* pointer_mode_host = */ true, elem_count, |
| complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_scopy, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| GpuMemory(x), incx, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_dcopy, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| GpuMemory(x), incx, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_ccopy, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| complex_cast(x), incx, complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_zcopy, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| complex_cast(x), incx, complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count, |
| const DeviceMemory<float> &x, int incx, |
| const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *result) { |
| blas_log("DoBlasDot"); |
| return DoBlasInternal( |
| wrap::rocblas_sdot, stream, /* pointer_mode_host = */ false, elem_count, |
| GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count, |
| const DeviceMemory<double> &x, int incx, |
| const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *result) { |
| blas_log("DoBlasDot"); |
| return DoBlasInternal( |
| wrap::rocblas_ddot, stream, /* pointer_mode_host = */ false, elem_count, |
| GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *result) { |
| return DoBlasInternal( |
| wrap::rocblas_cdotc, stream, /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, complex_cast(y), incy, complex_cast(result)); |
| } |
| |
| bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *result) { |
| return DoBlasInternal( |
| wrap::rocblas_zdotc, stream, /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, complex_cast(y), incy, complex_cast(result)); |
| } |
| |
| bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *result) { |
| return DoBlasInternal( |
| wrap::rocblas_cdotu, stream, /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, complex_cast(y), incy, complex_cast(result)); |
| } |
| |
| bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *result) { |
| return DoBlasInternal( |
| wrap::rocblas_zdotu, stream, /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, complex_cast(y), incy, complex_cast(result)); |
| } |
| |
| bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *result) { |
| return DoBlasInternal(wrap::rocblas_snrm2, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| GpuMemory(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *result) { |
| return DoBlasInternal(wrap::rocblas_dnrm2, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| GpuMemory(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<float> *result) { |
| return DoBlasInternal(wrap::rocblas_scnrm2, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<double> *result) { |
| return DoBlasInternal(wrap::rocblas_dznrm2, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, |
| DeviceMemory<float> *x, int incx, |
| DeviceMemory<float> *y, int incy, float c, float s) { |
| return DoBlasInternal( |
| wrap::rocblas_srot, stream, /* pointer_mode_host = */ true, elem_count, |
| GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, &c, &s); |
| } |
| |
| bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, |
| DeviceMemory<double> *x, int incx, |
| DeviceMemory<double> *y, int incy, double c, |
| double s) { |
| return DoBlasInternal( |
| wrap::rocblas_drot, stream, /* pointer_mode_host = */ true, elem_count, |
| GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, &c, &s); |
| } |
| |
| bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, |
| DeviceMemory<std::complex<float>> *x, int incx, |
| DeviceMemory<std::complex<float>> *y, int incy, |
| float c, float s) { |
| return DoBlasInternal(wrap::rocblas_csrot, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| complex_cast(x), incx, complex_cast(y), incy, &c, &s); |
| } |
| |
| bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, |
| DeviceMemory<std::complex<double>> *x, int incx, |
| DeviceMemory<std::complex<double>> *y, int incy, |
| double c, double s) { |
| return DoBlasInternal(wrap::rocblas_zdrot, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| complex_cast(x), incx, complex_cast(y), incy, &c, &s); |
| } |
| |
| bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a, |
| DeviceMemory<float> *b, DeviceMemory<float> *c, |
| DeviceMemory<float> *s) { |
| return DoBlasInternal(wrap::rocblas_srotg, stream, |
| /* pointer_mode_host = */ false, GpuMemoryMutable(a), |
| GpuMemoryMutable(b), GpuMemoryMutable(c), |
| GpuMemoryMutable(s)); |
| } |
| |
| bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a, |
| DeviceMemory<double> *b, DeviceMemory<double> *c, |
| DeviceMemory<double> *s) { |
| return DoBlasInternal(wrap::rocblas_drotg, stream, |
| /* pointer_mode_host = */ false, GpuMemoryMutable(a), |
| GpuMemoryMutable(b), GpuMemoryMutable(c), |
| GpuMemoryMutable(s)); |
| } |
| |
| bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a, |
| DeviceMemory<std::complex<float>> *b, |
| DeviceMemory<float> *c, |
| DeviceMemory<std::complex<float>> *s) { |
| return DoBlasInternal(wrap::rocblas_crotg, stream, |
| /* pointer_mode_host = */ false, complex_cast(a), |
| complex_cast(b), GpuMemoryMutable(c), complex_cast(s)); |
| } |
| |
| bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a, |
| DeviceMemory<std::complex<double>> *b, |
| DeviceMemory<double> *c, |
| DeviceMemory<std::complex<double>> *s) { |
| return DoBlasInternal(wrap::rocblas_zrotg, stream, |
| /* pointer_mode_host = */ false, complex_cast(a), |
| complex_cast(b), GpuMemoryMutable(c), complex_cast(s)); |
| } |
| |
| bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count, |
| DeviceMemory<float> *x, int incx, |
| DeviceMemory<float> *y, int incy, |
| const DeviceMemory<float> ¶m) { |
| return DoBlasInternal( |
| wrap::rocblas_srotm, stream, /* pointer_mode_host = */ false, elem_count, |
| GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, GpuMemory(param)); |
| } |
| |
| bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count, |
| DeviceMemory<double> *x, int incx, |
| DeviceMemory<double> *y, int incy, |
| const DeviceMemory<double> ¶m) { |
| return DoBlasInternal( |
| wrap::rocblas_drotm, stream, /* pointer_mode_host = */ false, elem_count, |
| GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, GpuMemory(param)); |
| } |
| |
| bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, |
| DeviceMemory<float> *d2, DeviceMemory<float> *x1, |
| const DeviceMemory<float> &y1, |
| DeviceMemory<float> *param) { |
| return DoBlasInternal(wrap::rocblas_srotmg, stream, |
| /* pointer_mode_host = */ false, GpuMemoryMutable(d1), |
| GpuMemoryMutable(d2), GpuMemoryMutable(x1), |
| GpuMemory(y1), GpuMemoryMutable(param)); |
| } |
| |
| bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, |
| DeviceMemory<double> *d2, DeviceMemory<double> *x1, |
| const DeviceMemory<double> &y1, |
| DeviceMemory<double> *param) { |
| return DoBlasInternal(wrap::rocblas_drotmg, stream, |
| /* pointer_mode_host = */ false, GpuMemoryMutable(d1), |
| GpuMemoryMutable(d2), GpuMemoryMutable(x1), |
| GpuMemory(y1), GpuMemoryMutable(param)); |
| } |
| |
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, |
| DeviceMemory<float> *x, int incx) { |
| blas_log("DoBlasScal<float>"); |
| return DoBlasInternal(wrap::rocblas_sscal, stream, |
| /* pointer_mode_host = */ true, elem_count, &alpha, |
| GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, |
| DeviceMemory<double> *x, int incx) { |
| return DoBlasInternal(wrap::rocblas_dscal, stream, |
| /* pointer_mode_host = */ true, elem_count, &alpha, |
| GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, |
| DeviceMemory<std::complex<float>> *x, int incx) { |
| return DoBlasInternal(wrap::rocblas_csscal, stream, |
| /* pointer_mode_host = */ true, elem_count, &alpha, |
| complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, |
| DeviceMemory<std::complex<double>> *x, int incx) { |
| return DoBlasInternal(wrap::rocblas_zdscal, stream, |
| /* pointer_mode_host = */ true, elem_count, &alpha, |
| complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, |
| std::complex<float> alpha, |
| DeviceMemory<std::complex<float>> *x, int incx) { |
| return DoBlasInternal(wrap::rocblas_cscal, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| complex_cast(alpha), complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, |
| std::complex<double> alpha, |
| DeviceMemory<std::complex<double>> *x, int incx) { |
| return DoBlasInternal(wrap::rocblas_zscal, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| complex_cast(alpha), complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, |
| DeviceMemory<float> *x, int incx, |
| DeviceMemory<float> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_sswap, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, |
| DeviceMemory<double> *x, int incx, |
| DeviceMemory<double> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_dswap, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, |
| DeviceMemory<std::complex<float>> *x, int incx, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_cswap, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| complex_cast(x), incx, complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, |
| DeviceMemory<std::complex<double>> *x, int incx, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_zswap, stream, |
| /* pointer_mode_host = */ true, elem_count, |
| complex_cast(x), incx, complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<int> *result) { |
| return DoBlasInternal(wrap::rocblas_isamax, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| GpuMemory(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<int> *result) { |
| return DoBlasInternal(wrap::rocblas_idamax, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| GpuMemory(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<int> *result) { |
| return DoBlasInternal(wrap::rocblas_icamax, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, DeviceMemory<int> *result) { |
| return DoBlasInternal(wrap::rocblas_izamax, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<int> *result) { |
| return DoBlasInternal(wrap::rocblas_isamin, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| GpuMemory(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<int> *result) { |
| return DoBlasInternal(wrap::rocblas_idamin, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| GpuMemory(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<int> *result) { |
| return DoBlasInternal(wrap::rocblas_icamin, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, DeviceMemory<int> *result) { |
| return DoBlasInternal(wrap::rocblas_izamin, stream, |
| /* pointer_mode_host = */ false, elem_count, |
| complex_cast(x), incx, GpuMemoryMutable(result)); |
| } |
| |
| bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, |
| uint64 n, uint64 kl, uint64 ku, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_sgbmv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(trans), m, n, kl, ku, &alpha, GpuMemory(a), lda, |
| GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, |
| uint64 n, uint64 kl, uint64 ku, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &x, int incx, double beta, |
| DeviceMemory<double> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_dgbmv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(trans), m, n, kl, ku, &alpha, GpuMemory(a), lda, |
| GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, |
| uint64 n, uint64 kl, uint64 ku, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_cgbmv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(trans), m, n, kl, ku, complex_cast(alpha), |
| complex_cast(a), lda, complex_cast(x), incx, complex_cast(beta), |
| complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, |
| uint64 n, uint64 kl, uint64 ku, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_zgbmv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(trans), m, n, kl, ku, complex_cast(alpha), |
| complex_cast(a), lda, complex_cast(x), incx, complex_cast(beta), |
| complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, |
| uint64 n, float alpha, const DeviceMemory<float> &a, |
| int lda, const DeviceMemory<float> &x, int incx, |
| float beta, DeviceMemory<float> *y, int incy) { |
| blas_log("DoBlasGemv"); |
| return DoBlasInternal( |
| wrap::rocblas_sgemv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), |
| incx, &beta, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, |
| uint64 n, double alpha, const DeviceMemory<double> &a, |
| int lda, const DeviceMemory<double> &x, int incx, |
| double beta, DeviceMemory<double> *y, int incy) { |
| blas_log("DoBlasGemv"); |
| return DoBlasInternal( |
| wrap::rocblas_dgemv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), |
| incx, &beta, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, |
| uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| blas_log("DoBlasGemv"); |
| return DoBlasInternal( |
| wrap::rocblas_cgemv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda, |
| complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, |
| uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| blas_log("DoBlasGemv\n"); |
| return DoBlasInternal( |
| wrap::rocblas_zgemv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda, |
| complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *a, int lda) { |
| return DoBlasInternal( |
| wrap::rocblas_sger, stream, /* pointer_mode_host = */ true, m, n, &alpha, |
| GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *a, int lda) { |
| return DoBlasInternal( |
| wrap::rocblas_dger, stream, /* pointer_mode_host = */ true, m, n, &alpha, |
| GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *a, int lda) { |
| return DoBlasInternal(wrap::rocblas_cgerc, stream, |
| /* pointer_mode_host = */ true, m, n, |
| complex_cast(alpha), complex_cast(x), incx, |
| complex_cast(y), incy, complex_cast(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *a, int lda) { |
| return DoBlasInternal(wrap::rocblas_zgerc, stream, |
| /* pointer_mode_host = */ true, m, n, |
| complex_cast(alpha), complex_cast(x), incx, |
| complex_cast(y), incy, complex_cast(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *a, int lda) { |
| return DoBlasInternal(wrap::rocblas_cgeru, stream, |
| /* pointer_mode_host = */ true, m, n, |
| complex_cast(alpha), complex_cast(x), incx, |
| complex_cast(y), incy, complex_cast(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *a, int lda) { |
| return DoBlasInternal(wrap::rocblas_zgeru, stream, |
| /* pointer_mode_host = */ true, m, n, |
| complex_cast(alpha), complex_cast(x), incx, |
| complex_cast(y), incy, complex_cast(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| uint64 k, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_chbmv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, k, complex_cast(alpha), complex_cast(a), lda, |
| complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| uint64 k, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_zhbmv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, k, complex_cast(alpha), complex_cast(a), lda, |
| complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_chemv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(a), lda, |
| complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_zhemv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(a), lda, |
| complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, |
| float alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<std::complex<float>> *a, int lda) { |
| return DoBlasInternal(wrap::rocblas_cher, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), |
| complex_cast(x), incx, complex_cast(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, |
| double alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<std::complex<double>> *a, int lda) { |
| return DoBlasInternal(wrap::rocblas_zher, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), |
| complex_cast(x), incx, complex_cast(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *a, int lda) { |
| return DoBlasInternal( |
| wrap::rocblas_cher2, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx, |
| complex_cast(y), incy, complex_cast(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *a, int lda) { |
| return DoBlasInternal( |
| wrap::rocblas_zher2, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx, |
| complex_cast(y), incy, complex_cast(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &ap, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_chpmv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(ap), |
| complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &ap, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_zhpmv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(ap), |
| complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, |
| float alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<std::complex<float>> *ap) { |
| return DoBlasInternal(wrap::rocblas_chpr, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), |
| complex_cast(x), incx, complex_cast(ap)); |
| } |
| |
| bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, |
| double alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<std::complex<double>> *ap) { |
| return DoBlasInternal(wrap::rocblas_zhpr, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), |
| complex_cast(x), incx, complex_cast(ap)); |
| } |
| |
| bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *ap) { |
| return DoBlasInternal( |
| wrap::rocblas_chpr2, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx, |
| complex_cast(y), incy, complex_cast(ap)); |
| } |
| |
| bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *ap) { |
| return DoBlasInternal( |
| wrap::rocblas_zhpr2, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx, |
| complex_cast(y), incy, complex_cast(ap)); |
| } |
| |
| bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| uint64 k, float alpha, const DeviceMemory<float> &a, |
| int lda, const DeviceMemory<float> &x, int incx, |
| float beta, DeviceMemory<float> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_ssbmv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x), |
| incx, &beta, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| uint64 k, double alpha, const DeviceMemory<double> &a, |
| int lda, const DeviceMemory<double> &x, int incx, |
| double beta, DeviceMemory<double> *y, int incy) { |
| return DoBlasInternal( |
| wrap::rocblas_dsbmv, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x), |
| incx, &beta, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| float alpha, const DeviceMemory<float> &ap, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_sspmv, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(ap), |
| GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| double alpha, const DeviceMemory<double> &ap, |
| const DeviceMemory<double> &x, int incx, double beta, |
| DeviceMemory<double> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_dspmv, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(ap), |
| GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, |
| float alpha, const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *ap) { |
| return DoBlasInternal(wrap::rocblas_sspr, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx, |
| GpuMemoryMutable(ap)); |
| } |
| |
| bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, |
| double alpha, const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *ap) { |
| return DoBlasInternal(wrap::rocblas_dspr, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx, |
| GpuMemoryMutable(ap)); |
| } |
| |
| bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, |
| float alpha, const DeviceMemory<float> &x, int incx, |
| const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *ap) { |
| return DoBlasInternal(wrap::rocblas_sspr2, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx, |
| GpuMemory(y), incy, GpuMemoryMutable(ap)); |
| } |
| |
| bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, |
| double alpha, const DeviceMemory<double> &x, int incx, |
| const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *ap) { |
| return DoBlasInternal(wrap::rocblas_dspr2, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx, |
| GpuMemory(y), incy, GpuMemoryMutable(ap)); |
| } |
| |
| bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| float alpha, const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_ssymv, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda, |
| GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, |
| double alpha, const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &x, int incx, double beta, |
| DeviceMemory<double> *y, int incy) { |
| return DoBlasInternal(wrap::rocblas_dsymv, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda, |
| GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy); |
| } |
| |
| bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, |
| float alpha, const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *a, int lda) { |
| return DoBlasInternal(wrap::rocblas_ssyr, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx, |
| GpuMemoryMutable(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, |
| double alpha, const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *a, int lda) { |
| return DoBlasInternal(wrap::rocblas_dsyr, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx, |
| GpuMemoryMutable(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, |
| float alpha, const DeviceMemory<float> &x, int incx, |
| const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *a, int lda) { |
| return DoBlasInternal(wrap::rocblas_ssyr2, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx, |
| GpuMemory(y), incy, GpuMemoryMutable(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, |
| double alpha, const DeviceMemory<double> &x, int incx, |
| const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *a, int lda) { |
| return DoBlasInternal(wrap::rocblas_dsyr2, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx, |
| GpuMemory(y), incy, GpuMemoryMutable(a), lda); |
| } |
| |
| bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| uint64 k, const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx) { |
| return DoBlasInternal(wrap::rocblas_stbmv, stream, |
| /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda, |
| GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| uint64 k, const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx) { |
| return DoBlasInternal(wrap::rocblas_dtbmv, stream, |
| /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda, |
| GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| uint64 k, const DeviceMemory<std::complex<float>> &a, |
| int lda, DeviceMemory<std::complex<float>> *x, |
| int incx) { |
| return DoBlasInternal(wrap::rocblas_ctbmv, stream, |
| /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda, |
| complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| uint64 k, const DeviceMemory<std::complex<double>> &a, |
| int lda, DeviceMemory<std::complex<double>> *x, |
| int incx) { |
| return DoBlasInternal(wrap::rocblas_ztbmv, stream, |
| /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda, |
| complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| uint64 k, const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx) { |
| return DoBlasInternal(wrap::rocblas_stbsv, stream, |
| /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda, |
| GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| uint64 k, const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx) { |
| return DoBlasInternal(wrap::rocblas_dtbsv, stream, |
| /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda, |
| GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| uint64 k, const DeviceMemory<std::complex<float>> &a, |
| int lda, DeviceMemory<std::complex<float>> *x, |
| int incx) { |
| return DoBlasInternal(wrap::rocblas_ctbsv, stream, |
| /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda, |
| complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| uint64 k, const DeviceMemory<std::complex<double>> &a, |
| int lda, DeviceMemory<std::complex<double>> *x, |
| int incx) { |
| return DoBlasInternal(wrap::rocblas_ztbsv, stream, |
| /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda, |
| complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<float> &ap, DeviceMemory<float> *x, |
| int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_stpmv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<double> &ap, |
| DeviceMemory<double> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_dtpmv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<float>> &ap, |
| DeviceMemory<std::complex<float>> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_ctpmv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<double>> &ap, |
| DeviceMemory<std::complex<double>> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_ztpmv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<float> &ap, DeviceMemory<float> *x, |
| int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_stpsv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<double> &ap, |
| DeviceMemory<double> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_dtpsv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<float>> &ap, |
| DeviceMemory<std::complex<float>> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_ctpsv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<double>> &ap, |
| DeviceMemory<std::complex<double>> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_ztpsv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_strmv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_dtrmv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| DeviceMemory<std::complex<float>> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_ctrmv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| DeviceMemory<std::complex<double>> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_ztrmv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_strsv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_dtrsv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| DeviceMemory<std::complex<float>> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_ctrsv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx); |
| } |
| |
| bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| DeviceMemory<std::complex<double>> *x, int incx) { |
| return DoBlasInternal( |
| wrap::rocblas_ztrsv, stream, /* pointer_mode_host = */ false, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), |
| ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx); |
| } |
| |
| port::Status ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, |
| blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, blas::DataType dtype, |
| const void *alpha, const DeviceMemoryBase &a, |
| int lda, const DeviceMemoryBase &b, int ldb, |
| const void *beta, DeviceMemoryBase *c, |
| int ldc) { |
| blas_log("DoBlasGemm"); |
| VLOG(1) << absl::StreamFormat( |
| "doing rocBLAS GEMM: at=%d bt=%d m=%u n=%u " |
| "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " |
| "c=%p ldc=%d", |
| static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha, |
| a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc); |
| if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) { |
| if (transa == blas::Transpose::kNoTranspose) { |
| if (lda < static_cast<int64>(m)) { |
| LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); " |
| "precondition violation"; |
| } |
| } else { |
| if (lda < static_cast<int64>(k)) { |
| LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k |
| << ") (transpose case); precondition violation"; |
| } |
| } |
| if (transb == blas::Transpose::kNoTranspose) { |
| if (ldb < static_cast<int64>(k)) { |
| LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k |
| << ") (no transpose case); precondition violation"; |
| } |
| } else { |
| if (ldb < static_cast<int64>(n)) { |
| LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); " |
| "precondition violation"; |
| } |
| } |
| } |
| |
| switch (dtype) { |
| case blas::DataType::kHalf: { |
| port::StatusOr<bool> maybe_hasXDLOPS = GpuDriver::GetMFMASupport(); |
| if (maybe_hasXDLOPS.ok() && maybe_hasXDLOPS.ValueOrDie()) { |
| VLOG(1) << "Using rocblas_gemm_ex"; |
| return DoBlasInternalStatus( |
| wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), |
| (rocblas_int)m, (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(), |
| rocblas_datatype_f16_r, lda, b.opaque(), rocblas_datatype_f16_r, |
| ldb, beta, c->opaque(), rocblas_datatype_f16_r, ldc, c->opaque(), |
| rocblas_datatype_f16_r, ldc, rocblas_datatype_f32_r, |
| rocblas_gemm_algo_standard, 0, 0); |
| } else { |
| VLOG(1) << "Using rocblas_hgemm"; |
| const Eigen::half alpha_half(*static_cast<const float *>(alpha)); |
| const Eigen::half beta_half(*static_cast<const float *>(beta)); |
| return DoBlasInternalStatus( |
| wrap::rocblas_hgemm, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, |
| reinterpret_cast<const rocblas_half *>(&alpha_half), |
| reinterpret_cast<const rocblas_half *>(a.opaque()), lda, |
| reinterpret_cast<const rocblas_half *>(b.opaque()), ldb, |
| reinterpret_cast<const rocblas_half *>(&beta_half), |
| reinterpret_cast<rocblas_half *>(c->opaque()), ldc); |
| } |
| } |
| case blas::DataType::kBF16: |
| return DoBlasInternalStatus( |
| wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), |
| (rocblas_int)m, (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(), |
| rocblas_datatype_bf16_r, lda, b.opaque(), rocblas_datatype_bf16_r, |
| ldb, beta, c->opaque(), rocblas_datatype_bf16_r, ldc, c->opaque(), |
| rocblas_datatype_bf16_r, ldc, rocblas_datatype_f32_r, |
| rocblas_gemm_algo_standard, 0, 0); |
| case blas::DataType::kFloat: |
| return DoBlasInternalStatus( |
| wrap::rocblas_sgemm, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, |
| static_cast<const float *>(alpha), |
| static_cast<const float *>(a.opaque()), lda, |
| static_cast<const float *>(b.opaque()), ldb, |
| static_cast<const float *>(beta), static_cast<float *>(c->opaque()), |
| ldc); |
| case blas::DataType::kDouble: |
| return DoBlasInternalStatus( |
| wrap::rocblas_dgemm, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, |
| static_cast<const double *>(alpha), |
| static_cast<const double *>(a.opaque()), lda, |
| static_cast<const double *>(b.opaque()), ldb, |
| static_cast<const double *>(beta), static_cast<double *>(c->opaque()), |
| ldc); |
| case blas::DataType::kComplexFloat: { |
| auto cb_alpha = |
| complex_cast(*static_cast<const std::complex<float> *>(alpha)); |
| auto cb_beta = |
| complex_cast(*static_cast<const std::complex<float> *>(beta)); |
| return DoBlasInternalStatus( |
| wrap::rocblas_cgemm, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, |
| cb_alpha, static_cast<const rocblas_float_complex *>(a.opaque()), lda, |
| static_cast<const rocblas_float_complex *>(b.opaque()), ldb, cb_beta, |
| static_cast<rocblas_float_complex *>(c->opaque()), ldc); |
| } |
| case blas::DataType::kComplexDouble: { |
| auto cb_alpha = |
| complex_cast(*static_cast<const std::complex<double> *>(alpha)); |
| auto cb_beta = |
| complex_cast(*static_cast<const std::complex<double> *>(beta)); |
| return DoBlasInternalStatus( |
| wrap::rocblas_zgemm, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, |
| cb_alpha, static_cast<const rocblas_double_complex *>(a.opaque()), |
| lda, static_cast<const rocblas_double_complex *>(b.opaque()), ldb, |
| cb_beta, static_cast<rocblas_double_complex *>(c->opaque()), ldc); |
| } |
| default: |
| return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ", |
| blas::DataTypeString(dtype))); |
| } |
| } |
| |
| bool ROCMBlas::DoBlasGemvWithProfiling( |
| Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, |
| const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, |
| int incx, float beta, DeviceMemory<float> *y, int incy, |
| blas::ProfileResult *output_profile_result) { |
| return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x, |
| incx, beta, y, incy, |
| output_profile_result); |
| } |
| |
| bool ROCMBlas::DoBlasGemvWithProfiling( |
| Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, |
| const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, |
| int incx, double beta, DeviceMemory<double> *y, int incy, |
| blas::ProfileResult *output_profile_result) { |
| return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x, |
| incx, beta, y, incy, |
| output_profile_result); |
| } |
| |
| bool ROCMBlas::DoBlasGemvWithProfiling( |
| Stream *stream, blas::Transpose trans, uint64 m, uint64 n, |
| std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, |
| int lda, const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, |
| blas::ProfileResult *output_profile_result) { |
| return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x, |
| incx, beta, y, incy, |
| output_profile_result); |
| } |
| |
| bool ROCMBlas::DoBlasGemvWithProfiling( |
| Stream *stream, blas::Transpose trans, uint64 m, uint64 n, |
| std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, |
| int lda, const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy, |
| blas::ProfileResult *output_profile_result) { |
| return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x, |
| incx, beta, y, incy, |
| output_profile_result); |
| } |
| |
| bool ROCMBlas::DoBlasGemmWithProfiling( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, |
| int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta, |
| DeviceMemory<Eigen::half> *c, int ldc, |
| blas::ProfileResult *output_profile_result) { |
| return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a, |
| lda, b, ldb, beta, c, ldc, |
| output_profile_result); |
| } |
| |
| bool ROCMBlas::DoBlasGemmWithProfiling( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, |
| int ldc, blas::ProfileResult *output_profile_result) { |
| return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a, |
| lda, b, ldb, beta, c, ldc, |
| output_profile_result); |
| } |
| |
| bool ROCMBlas::DoBlasGemmWithProfiling( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &b, int ldb, double beta, |
| DeviceMemory<double> *c, int ldc, |
| blas::ProfileResult *output_profile_result) { |
| return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a, |
| lda, b, ldb, beta, c, ldc, |
| output_profile_result); |
| } |
| |
| bool ROCMBlas::DoBlasGemmWithProfiling( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, |
| blas::ProfileResult *output_profile_result) { |
| return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a, |
| lda, b, ldb, beta, c, ldc, |
| output_profile_result); |
| } |
| |
| bool ROCMBlas::DoBlasGemmWithProfiling( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, |
| blas::ProfileResult *output_profile_result) { |
| return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a, |
| lda, b, ldb, beta, c, ldc, |
| output_profile_result); |
| } |
| |
| template <typename T> |
| bool ROCMBlas::DoBlasGemvWithProfilingImpl( |
| Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha, |
| const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx, |
| const T &beta, DeviceMemory<T> *y, int incy, |
| blas::ProfileResult *output_profile_result) { |
| // ROCM TODO: properly implement the interface |
| return false; |
| } |
| |
| template <typename T, typename ParamType> |
| bool ROCMBlas::DoBlasGemmWithProfilingImpl( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a, |
| int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta, |
| DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) { |
| // ROCM TODO: properly implement the interface |
| return false; |
| } |
| port::Status ROCMBlas::DoBlasGemmWithAlgorithm( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a, |
| blas::DataType type_a, int lda, const DeviceMemoryBase &b, |
| blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c, |
| blas::DataType type_c, int ldc, blas::ComputationType computation_type, |
| blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { |
| // ROCM TODO: properly implement the interface |
| return port::InternalError("Not implemented on ROCm"); |
| } |
| |
| port::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a, |
| blas::DataType type_a, int lda, int64 stride_a, const DeviceMemoryBase &b, |
| blas::DataType type_b, int ldb, int64 stride_b, const void *beta, |
| DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64 stride_c, |
| int batch_count, blas::ComputationType computation_type, |
| blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { |
| // ROCM TODO: properly implement the interface |
| return port::InternalError("Not implemented on ROCm"); |
| } |
| |
| bool ROCMBlas::GetBlasGemmAlgorithms( |
| std::vector<blas::AlgorithmType> *out_algorithms) { |
| // ROCM TODO: properly implement the interface |
| return true; |
| } |
| |
| // This copies from source memory: raw_ptrs[i] to target memory: |
| // device_memory_ptr at the interval of matrix_byte_size, or vice versa. |
| // The below algorithm tries to minimize the number of memcpy by consolidating |
| // neighboring memcpy into a single request |
| template <typename MAPPED_T> |
| port::Status ReorganizeMemory(Stream *stream, |
| DeviceMemory<MAPPED_T> *device_memory, |
| const std::vector<MAPPED_T *> &raw_ptrs, |
| int batch_count, uint64_t batch_stride, |
| bool gather) { |
| assert(batch_count > 0); |
| char *device_memory_ptr = static_cast<char *>(device_memory->opaque()); |
| char *src_ptr = reinterpret_cast<char *>(raw_ptrs[0]); |
| char *dst_ptr = device_memory_ptr; |
| size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T); |
| uint64_t cur_stride_size = matrix_byte_size; |
| |
| for (int i = 1; i < batch_count; ++i) { |
| if (reinterpret_cast<char *>(raw_ptrs[i]) == src_ptr + cur_stride_size) { |
| cur_stride_size += matrix_byte_size; |
| } else { |
| DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size); |
| DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size); |
| bool a_status = |
| gather |
| ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok() |
| : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok(); |
| if (!a_status) { |
| return port::Status( |
| port::error::INTERNAL, |
| "failed to copy device memory in ROCMBlas::DoBlasGemmBatched"); |
| } |
| src_ptr = reinterpret_cast<char *>(raw_ptrs[i]); |
| dst_ptr = device_memory_ptr + i * matrix_byte_size; |
| cur_stride_size = matrix_byte_size; |
| } |
| } |
| |
| DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size); |
| DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size); |
| bool a_status = |
| gather ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok() |
| : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok(); |
| if (!a_status) |
| return port::Status( |
| port::error::INTERNAL, |
| "failed to copy device memory in ROCMBlas::DoBlasGemmBatched"); |
| return port::Status::OK(); |
| } |
| |
| template <typename T> |
| port::Status ROCMBlas::AllocateStridedBuffer( |
| const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *> |
| &raw_ptrs, |
| int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator, |
| Stream *stream, |
| std::unique_ptr<TemporaryDeviceMemory< |
| typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory, |
| DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type> |
| *device_memory, |
| bool copy_data, bool &reallocated) { |
| assert(device_memory != nullptr); |
| |
| using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type; |
| |
| bool needs_allocate_strided = false; |
| for (int i = 1; i < batch_count; ++i) { |
| uint64_t tmp_batch_stride = raw_ptrs[i] - raw_ptrs[i - 1]; |
| if (tmp_batch_stride != batch_stride) { |
| needs_allocate_strided = true; |
| break; |
| } |
| } |
| |
| size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T); |
| size_t matrix_batch_byte_size = matrix_byte_size * batch_count; |
| |
| // No need to do re-allocation, take the short cut and return |
| if (!needs_allocate_strided) { |
| *device_memory = DeviceMemory<MAPPED_T>( |
| DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size)); |
| reallocated = false; |
| return port::Status::OK(); |
| } |
| |
| if (scratch_allocator != nullptr) { |
| SE_ASSIGN_OR_RETURN( |
| DeviceMemory<uint8> batch_matrix_bytes, |
| scratch_allocator->AllocateBytes(matrix_batch_byte_size)); |
| *device_memory = DeviceMemory<MAPPED_T>(batch_matrix_bytes); |
| } else { |
| assert(temp_memory != nullptr); |
| SE_ASSIGN_OR_RETURN(*temp_memory, stream->AllocateTemporaryArray<MAPPED_T>( |
| matrix_batch_byte_size)); |
| *device_memory = |
| DeviceMemory<MAPPED_T>(*(*temp_memory)->mutable_device_memory()); |
| } |
| |
| reallocated = true; |
| |
| if (copy_data) |
| return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count, |
| batch_stride, true); |
| return port::Status::OK(); |
| } |
| |
| template <typename T, typename FuncT> |
| port::Status ROCMBlas::DoBlasGemmBatchedInternal( |
| FuncT rocblas_func, Stream *stream, blas::Transpose transa, |
| blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, |
| const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda, |
| const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb, |
| T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers, |
| int ldc, int batch_count, ScratchAllocator *scratch_allocator) { |
| using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type; |
| |
| // Sanity checks before making any further progress |
| uint64_t batch_stride_a = 0; |
| uint64_t batch_stride_b = 0; |
| uint64_t batch_stride_c = 0; |
| |
| assert(ldc >= m); |
| batch_stride_c = ldc * n; |
| |
| if (ROCMBlasTranspose(transa) == rocblas_operation_none) { |
| assert(lda >= m); |
| batch_stride_a = lda * k; |
| } else { |
| assert(lda >= k); |
| batch_stride_a = lda * m; |
| } |
| |
| if (ROCMBlasTranspose(transb) == rocblas_operation_none) { |
| assert(ldb >= k); |
| batch_stride_b = ldb * n; |
| } else { |
| assert(ldb >= n); |
| batch_stride_b = ldb * k; |
| } |
| |
| // Allocate local vectors to hold device pointers to matrices |
| std::vector<MAPPED_T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs; |
| for (int i = 0; i < batch_count; ++i) { |
| // static_cast does work when converting Eigen::half* to rocblas_half*, |
| // hence the use of reinterpret_cast |
| a_raw_ptrs.push_back( |
| reinterpret_cast<MAPPED_T *>(a_ptrs_to_wrappers[i]->opaque())); |
| b_raw_ptrs.push_back( |
| reinterpret_cast<MAPPED_T *>(b_ptrs_to_wrappers[i]->opaque())); |
| c_raw_ptrs.push_back( |
| reinterpret_cast<MAPPED_T *>(c_ptrs_to_wrappers[i]->opaque())); |
| } |
| |
| DeviceMemory<MAPPED_T> a; |
| // Make sure the temporary memory are in-scope before the function returns |
| std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> a_temp; |
| bool reallocated_a, reallocated_b, reallocated_c; |
| port::Status a_allocation_status = AllocateStridedBuffer<T>( |
| a_raw_ptrs, batch_count, batch_stride_a, scratch_allocator, stream, |
| &a_temp, &a, true, reallocated_a); |
| if (a_allocation_status != port::Status::OK()) { |
| return a_allocation_status; |
| } |
| |
| DeviceMemory<MAPPED_T> b; |
| std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> b_temp; |
| port::Status b_allocation_status = AllocateStridedBuffer<T>( |
| b_raw_ptrs, batch_count, batch_stride_b, scratch_allocator, stream, |
| &b_temp, &b, true, reallocated_b); |
| if (b_allocation_status != port::Status::OK()) { |
| return b_allocation_status; |
| } |
| |
| DeviceMemory<MAPPED_T> c; |
| std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> c_temp; |
| port::Status c_allocation_status = AllocateStridedBuffer<T>( |
| c_raw_ptrs, batch_count, batch_stride_c, scratch_allocator, stream, |
| &c_temp, &c, true, reallocated_c); // can disable copy if beta=0 |
| if (c_allocation_status != port::Status::OK()) { |
| return c_allocation_status; |
| } |
| |
| MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha); |
| MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta); |
| |
| bool ok; |
| ok = DoBlasInternal(rocblas_func, stream, /* pointer_mode_host = */ true, |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, |
| n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda, |
| batch_stride_a, GpuMemory(b), ldb, batch_stride_b, |
| GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc, |
| batch_stride_c, batch_count); |
| if (!ok) |
| return port::Status(port::error::INTERNAL, |
| "failed BLAS call, see log for details"); |
| if (reallocated_c) |
| return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count, batch_stride_c, |
| false); |
| return port::Status::OK(); |
| } |
| |
| bool ROCMBlas::DoBlasGemmBatched( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, float alpha, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator) { |
| blas_log("DoBlasGemmBatched"); |
| const Eigen::half alpha_half(alpha); |
| const Eigen::half beta_half(beta); |
| |
| port::Status status = DoBlasGemmBatchedInternal( |
| wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k, |
| alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count, |
| scratch_allocator); |
| if (!status.ok()) { |
| LOG(ERROR) << status; |
| } |
| |
| return status.ok(); |
| } |
| |
| bool ROCMBlas::DoBlasGemmBatched( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, float alpha, |
| const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda, |
| const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta, |
| const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator) { |
| blas_log("DoBlasGemmBatched"); |
| port::Status status = DoBlasGemmBatchedInternal( |
| wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k, |
| alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, |
| scratch_allocator); |
| if (!status.ok()) { |
| LOG(ERROR) << status; |
| } |
| return status.ok(); |
| } |
| |
| bool ROCMBlas::DoBlasGemmBatched( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, double alpha, |
| const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda, |
| const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb, |
| double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array, |
| int ldc, int batch_count, ScratchAllocator *scratch_allocator) { |
| blas_log("DoBlasGemmBatched"); |
| port::Status status = DoBlasGemmBatchedInternal( |
| wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k, |
| alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, |
| scratch_allocator); |
| if (!status.ok()) { |
| LOG(ERROR) << status; |
| } |
| return status.ok(); |
| } |
| |
| bool ROCMBlas::DoBlasGemmBatched( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, std::complex<float> alpha, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array, |
| int lda, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array, |
| int ldb, std::complex<float> beta, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array, |
| int ldc, int batch_count, ScratchAllocator *scratch_allocator) { |
| blas_log("DoBlasGemmBatched"); |
| port::Status status = DoBlasGemmBatchedInternal( |
| wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k, |
| alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, |
| scratch_allocator); |
| if (!status.ok()) { |
| LOG(ERROR) << status; |
| } |
| return status.ok(); |
| } |
| |
| bool ROCMBlas::DoBlasGemmBatched( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, std::complex<double> alpha, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array, |
| int lda, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array, |
| int ldb, std::complex<double> beta, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array, |
| int ldc, int batch_count, ScratchAllocator *scratch_allocator) { |
| blas_log("DoBlasGemmBatched"); |
| port::Status status = DoBlasGemmBatchedInternal( |
| wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k, |
| alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, |
| scratch_allocator); |
| if (!status.ok()) { |
| LOG(ERROR) << status; |
| } |
| return status.ok(); |
| } |
| |
| bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, uint64 m, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| return DoBlasInternal(wrap::rocblas_chemm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha), |
| complex_cast(a), lda, complex_cast(b), ldb, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, uint64 m, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| return DoBlasInternal(wrap::rocblas_zhemm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha), |
| complex_cast(a), lda, complex_cast(b), ldb, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| float alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| float beta, DeviceMemory<std::complex<float>> *c, |
| int ldc) { |
| return DoBlasInternal(wrap::rocblas_cherk, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, |
| k, complex_cast(alpha), complex_cast(a), lda, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| double alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| double beta, DeviceMemory<std::complex<double>> *c, |
| int ldc) { |
| return DoBlasInternal(wrap::rocblas_zherk, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, |
| k, complex_cast(alpha), complex_cast(a), lda, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| float beta, DeviceMemory<std::complex<float>> *c, |
| int ldc) { |
| return DoBlasInternal( |
| wrap::rocblas_cher2k, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, |
| complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| double beta, DeviceMemory<std::complex<double>> *c, |
| int ldc) { |
| return DoBlasInternal( |
| wrap::rocblas_zher2k, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, |
| complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, uint64 m, uint64 n, |
| float alpha, const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &b, int ldb, float beta, |
| DeviceMemory<float> *c, int ldc) { |
| return DoBlasInternal( |
| wrap::rocblas_ssymm, stream, /* pointer_mode_host = */ true, |
| ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, &alpha, GpuMemory(a), |
| lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, uint64 m, uint64 n, |
| double alpha, const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &b, int ldb, double beta, |
| DeviceMemory<double> *c, int ldc) { |
| return DoBlasInternal( |
| wrap::rocblas_dsymm, stream, /* pointer_mode_host = */ true, |
| ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, &alpha, GpuMemory(a), |
| lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, uint64 m, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| return DoBlasInternal(wrap::rocblas_csymm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha), |
| complex_cast(a), lda, complex_cast(b), ldb, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, uint64 m, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| return DoBlasInternal(wrap::rocblas_zsymm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha), |
| complex_cast(a), lda, complex_cast(b), ldb, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| float alpha, const DeviceMemory<float> &a, int lda, |
| float beta, DeviceMemory<float> *c, int ldc) { |
| return DoBlasInternal( |
| wrap::rocblas_ssyrk, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha, |
| GpuMemory(a), lda, &beta, GpuMemoryMutable(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| double alpha, const DeviceMemory<double> &a, int lda, |
| double beta, DeviceMemory<double> *c, int ldc) { |
| return DoBlasInternal( |
| wrap::rocblas_dsyrk, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha, |
| GpuMemory(a), lda, &beta, GpuMemoryMutable(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| return DoBlasInternal(wrap::rocblas_csyrk, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, |
| k, complex_cast(alpha), complex_cast(a), lda, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| return DoBlasInternal(wrap::rocblas_zsyrk, stream, |
| /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, |
| k, complex_cast(alpha), complex_cast(a), lda, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| float alpha, const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &b, int ldb, float beta, |
| DeviceMemory<float> *c, int ldc) { |
| return DoBlasInternal( |
| wrap::rocblas_ssyr2k, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha, |
| GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| double alpha, const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &b, int ldb, double beta, |
| DeviceMemory<double> *c, int ldc) { |
| return DoBlasInternal( |
| wrap::rocblas_dsyr2k, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha, |
| GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| return DoBlasInternal( |
| wrap::rocblas_csyr2k, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, |
| complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, |
| blas::Transpose trans, uint64 n, uint64 k, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| return DoBlasInternal( |
| wrap::rocblas_zsyr2k, stream, /* pointer_mode_host = */ true, |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, |
| complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb, |
| complex_cast(beta), complex_cast(c), ldc); |
| } |
| |
| bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, blas::Transpose transa, |
| blas::Diagonal diag, uint64 m, uint64 n, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *b, int ldb) { |
| return DoBlasInternal(wrap::rocblas_strmm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), |
| ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda, |
| GpuMemoryMutable(b), ldb); |
| } |
| |
| bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, blas::Transpose transa, |
| blas::Diagonal diag, uint64 m, uint64 n, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *b, int ldb) { |
| return DoBlasInternal(wrap::rocblas_dtrmm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), |
| ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda, |
| GpuMemoryMutable(b), ldb); |
| } |
| |
| bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, blas::Transpose transa, |
| blas::Diagonal diag, uint64 m, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| DeviceMemory<std::complex<float>> *b, int ldb) { |
| return DoBlasInternal(wrap::rocblas_ctrmm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), |
| ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), |
| complex_cast(a), lda, complex_cast(b), ldb); |
| } |
| |
| bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, blas::Transpose transa, |
| blas::Diagonal diag, uint64 m, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| DeviceMemory<std::complex<double>> *b, int ldb) { |
| return DoBlasInternal(wrap::rocblas_ztrmm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), |
| ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), |
| complex_cast(a), lda, complex_cast(b), ldb); |
| } |
| |
| bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, blas::Transpose transa, |
| blas::Diagonal diag, uint64 m, uint64 n, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *b, int ldb) { |
| blas_log("DoBlasTrsm"); |
| return DoBlasInternal(wrap::rocblas_strsm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), |
| ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda, |
| GpuMemoryMutable(b), ldb); |
| } |
| |
| bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, blas::Transpose transa, |
| blas::Diagonal diag, uint64 m, uint64 n, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *b, int ldb) { |
| blas_log("DoBlasTrsm"); |
| return DoBlasInternal(wrap::rocblas_dtrsm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), |
| ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda, |
| GpuMemoryMutable(b), ldb); |
| } |
| |
| bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, blas::Transpose transa, |
| blas::Diagonal diag, uint64 m, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| DeviceMemory<std::complex<float>> *b, int ldb) { |
| return DoBlasInternal(wrap::rocblas_ctrsm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), |
| ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), |
| complex_cast(a), lda, complex_cast(b), ldb); |
| } |
| |
| bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, |
| blas::UpperLower uplo, blas::Transpose transa, |
| blas::Diagonal diag, uint64 m, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| DeviceMemory<std::complex<double>> *b, int ldb) { |
| return DoBlasInternal(wrap::rocblas_ztrsm, stream, |
| /* pointer_mode_host = */ true, ROCMBlasSide(side), |
| ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), |
| ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), |
| complex_cast(a), lda, complex_cast(b), ldb); |
| } |
| |
| port::Status ROCMBlas::DoBlasGemmStridedBatched( |
| Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, |
| uint64 n, uint64 k, blas::DataType dtype, const void *alpha, |
| const DeviceMemoryBase &a, int lda, int64 stride_a, |
| const DeviceMemoryBase &b, int ldb, int64 stride_b, const void *beta, |
| DeviceMemoryBase *c, int ldc, int64 stride_c, int batch_count) { |
| VLOG(1) << absl::StreamFormat( |
| "doing rocBLAS SGEMM Strided Batched<float>: at=%d bt=%d m=%u n=%u " |
| "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " |
| "c=%p ldc=%d", |
| static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha, |
| a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc); |
| |
| switch (dtype) { |
| case blas::DataType::kHalf: { |
| const Eigen::half alpha_half(*static_cast<const float *>(alpha)); |
| const Eigen::half beta_half(*static_cast<const float *>(beta)); |
| return DoBlasInternalStatus( |
| wrap::rocblas_hgemm_strided_batched, stream, |
| false, /* pointer_mode_host */ |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, |
| reinterpret_cast<const rocblas_half *>(&alpha_half), |
| reinterpret_cast<const rocblas_half *>(a.opaque()), lda, stride_a, |
| reinterpret_cast<const rocblas_half *>(b.opaque()), ldb, stride_b, |
| reinterpret_cast<const rocblas_half *>(&beta_half), |
| reinterpret_cast<rocblas_half *>(c->opaque()), ldc, stride_c, |
| batch_count); |
| } |
| case blas::DataType::kBF16: |
| return DoBlasInternalStatus( |
| wrap::rocblas_gemm_strided_batched_ex, stream, |
| false, /* pointer_mode_host */ |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, |
| alpha, |
| a.opaque(), rocblas_datatype_bf16_r, lda, stride_a, |
| b.opaque(), rocblas_datatype_bf16_r, ldb, stride_b, |
| beta, |
| c->opaque(), rocblas_datatype_bf16_r, ldc, stride_c, |
| c->opaque(), rocblas_datatype_bf16_r, ldc, stride_c, |
| batch_count, rocblas_datatype_f32_r, |
| rocblas_gemm_algo_standard, 0, 0); |
| case blas::DataType::kFloat: |
| return DoBlasInternalStatus( |
| wrap::rocblas_sgemm_strided_batched, stream, |
| false, /* pointer_mode_host */ |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, |
| reinterpret_cast<const float *>(alpha), |
| reinterpret_cast<const float *>(a.opaque()), lda, stride_a, |
| reinterpret_cast<const float *>(b.opaque()), ldb, stride_b, |
| reinterpret_cast<const float *>(beta), |
| reinterpret_cast<float *>(c->opaque()), ldc, stride_c, batch_count); |
| case blas::DataType::kDouble: |
| return DoBlasInternalStatus( |
| wrap::rocblas_dgemm_strided_batched, stream, |
| false, /* pointer_mode_host */ |
| ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, |
| reinterpret_cast<const double *>(alpha), |
| reinterpret_cast<const double *>(a.opaque()), lda, stride_a, |
| reinterpret_cast<const double *>(b.opaque()), ldb, stride_b, |
| reinterpret_cast<const double *>(beta), |
| reinterpret_cast<double *>(c->opaque()), ldc, stride_c, batch_count); |
| case blas::DataType::kComplexFloat: { |
| auto cb_alpha = |
| complex_cast(*static_cast<const std::complex<float> *>(alpha)); |
| auto cb_beta = |
| complex_cast(*static_cast<const std::complex<float> *>(beta)); |
| return DoBlasInternalStatus( |
| wrap::rocblas_cgemm_strided_batched, stream, |
| /* pointer_mode_host = */ true, ROCMBlasTranspose(transa), |
| ROCMBlasTranspose(transb), m, n, k, cb_alpha, |
| static_cast<const rocblas_float_complex *>(a.opaque()), lda, stride_a, |
| static_cast<const rocblas_float_complex *>(b.opaque()), ldb, stride_b, |
| cb_beta, static_cast<rocblas_float_complex *>(c->opaque()), ldc, |
| stride_c, batch_count); |
| } |
| case blas::DataType::kComplexDouble: { |
| auto cb_alpha = |
| complex_cast(*static_cast<const std::complex<double> *>(alpha)); |
| auto cb_beta = |
| complex_cast(*static_cast<const std::complex<double> *>(beta)); |
| return DoBlasInternalStatus( |
| wrap::rocblas_zgemm_strided_batched, stream, |
| /* pointer_mode_host = */ true, ROCMBlasTranspose(transa), |
| ROCMBlasTranspose(transb), m, n, k, cb_alpha, |
| static_cast<const rocblas_double_complex *>(a.opaque()), lda, |
| stride_a, static_cast<const rocblas_double_complex *>(b.opaque()), |
| ldb, stride_b, cb_beta, |
| static_cast<rocblas_double_complex *>(c->opaque()), ldc, stride_c, |
| batch_count); |
| } |
| default: |
| return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ", |
| blas::DataTypeString(dtype))); |
| } |
| } |
| |
| port::Status ROCMBlas::GetVersion(string *version) { |
| return port::UnimplementedError(""); |
| } |
| |
| port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>> |
| ROCMBlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) { |
| return port::Status( |
| port::error::UNIMPLEMENTED, |
| "CreateBlasLtMatmulPlan is not supported with this version of ROCM"); |
| } |
| |
| port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>> |
| ROCMBlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, |
| size_t max_workspace_size, |
| int max_algorithm_count) { |
| return port::Status( |
| port::error::UNIMPLEMENTED, |
| "GetBlasLtMatmulAlgorithms is not supported with this version of ROCM"); |
| } |
| |
| bool ROCMBlas::DoBlasLtMatmul( |
| Stream *stream, const blas::IBlasLtMatmulPlan *plan, |
| const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a, |
| DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta, |
| DeviceMemoryBase c, ScratchAllocator *scratch_allocator, |
| const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, |
| blas::ProfileResult *output_profile_result) { |
| return false; |
| } |
| |
| } // namespace gpu |
| |
| void initialize_rocblas() { |
| auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory( |
| rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin); |
| |
| if (!rocBlasAlreadyRegistered) { |
| port::Status status = |
| PluginRegistry::Instance() |
| ->RegisterFactory<PluginRegistry::BlasFactory>( |
| rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS", |
| [](internal::StreamExecutorInterface *parent) |
| -> blas::BlasSupport * { |
| gpu::GpuExecutor *rocm_executor = |
| dynamic_cast<gpu::GpuExecutor *>(parent); |
| if (rocm_executor == nullptr) { |
| LOG(ERROR) |
| << "Attempting to initialize an instance of the " |
| "rocBLAS " |
| << "support library with a non-ROCM StreamExecutor"; |
| return nullptr; |
| } |
| |
| gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor); |
| if (!blas->Init()) { |
| // Note: Init() will log a more specific error. |
| delete blas; |
| return nullptr; |
| } |
| return blas; |
| }); |
| |
| if (!status.ok()) { |
| LOG(ERROR) << "Unable to register rocBLAS factory: " |
| << status.error_message(); |
| } |
| |
| PluginRegistry::Instance()->SetDefaultFactory( |
| rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin); |
| } |
| } |
| |
| } // namespace stream_executor |
| |
| REGISTER_MODULE_INITIALIZER(register_rocblas, |
| { stream_executor::initialize_rocblas(); }); |