blob: 1c695b7a24c698b0240c4c6c88ec6f29624603c1 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "rocm/include/rocblas.h"
#include "tensorflow/stream_executor/rocm/rocm_blas.h"
#define EIGEN_USE_GPU
#include <assert.h>
#include <complex>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/gpu/gpu_timer.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/status_macros.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace stream_executor {
namespace gpu {
// Defines the plugin id `kRocBlasPlugin` through the plugin-registry macro.
// NOTE(review): presumably this is the id under which the rocBLAS backend is
// registered; the registration code is outside this view -- confirm.
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
namespace wrap {
// STREAM_EXECUTOR_ROCBLAS_WRAP(__name) declares, in this namespace, a function
// object that is itself named `__name`. Invoking it as
// `__name(parent, args...)` creates a gpu::ScopedActivateExecutorContext for
// `parent` and then calls the rocBLAS entry point `__name` with `args...`,
// returning its rocblas_status. The routine name is also exposed as the static
// string `kName`, which DoBlasInternalImpl uses when logging failures.
//
// STREAM_EXECUTOR_ROCBLAS_V2_WRAP is a plain alias for the same macro.
#ifdef PLATFORM_GOOGLE
// PLATFORM_GOOGLE variant: the shim calls the global symbol `::__name`
// directly.
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
struct WrapperShim__##__name { \
static const char *kName; \
template <typename... Args> \
rocblas_status operator()(GpuExecutor *parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return ::__name(args...); \
} \
} __name; \
const char *WrapperShim__##__name::kName = #__name;
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
#else
// Default variant: the symbol is looked up by name in the rocBLAS DSO handle.
// The lookup result is held in a function-local static inside DynLoad(), and
// LoadOrDie() CHECK-fails if the symbol cannot be found.
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char *kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void *GetDsoHandle() { \
auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \
return s.ValueOrDie(); \
} \
static FuncPtrT LoadOrDie() { \
void *f; \
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in rocblas DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
rocblas_status operator()(GpuExecutor *parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char *DynLoadShim__##__name::kName = #__name;
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
#endif
// clang-format off
// X-macro list: expands `__macro(name)` once for every rocBLAS routine that is
// currently enabled. Entries enclosed in /* ... */ are commented out, so
// `__macro` is not applied to them and no shim is generated for those
// routines.
#define ROCBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(rocblas_snrm2) \
__macro(rocblas_dnrm2) \
/*__macro(rocblas_scnrm2) \
__macro(rocblas_dznrm2) */ \
__macro(rocblas_sdot) \
__macro(rocblas_ddot) \
/*__macro(rocblas_cdotu) \
__macro(rocblas_cdotc) \
__macro(rocblas_zdotu) \
__macro(rocblas_zdotc) */ \
__macro(rocblas_sscal) \
__macro(rocblas_dscal) \
/*__macro(rocblas_cscal) \
__macro(rocblas_csscal) \
__macro(rocblas_zscal) \
__macro(rocblas_zdscal) */ \
__macro(rocblas_saxpy) \
__macro(rocblas_daxpy) \
/*__macro(rocblas_caxpy) \
__macro(rocblas_zaxpy) */ \
__macro(rocblas_scopy) \
__macro(rocblas_dcopy) \
/*__macro(rocblas_ccopy) \
__macro(rocblas_zcopy) */ \
__macro(rocblas_sswap) \
__macro(rocblas_dswap) \
/*__macro(rocblas_cswap) \
__macro(rocblas_zswap) */ \
__macro(rocblas_isamax) \
__macro(rocblas_idamax) \
/*__macro(rocblas_icamax) \
__macro(rocblas_izamax) */ \
__macro(rocblas_isamin) \
__macro(rocblas_idamin) \
/*__macro(rocblas_icamin) \
__macro(rocblas_izamin) */ \
__macro(rocblas_sasum) \
__macro(rocblas_dasum) \
/*__macro(rocblas_scasum) \
__macro(rocblas_dzasum) \
__macro(rocblas_srot) \
__macro(rocblas_drot) \
__macro(rocblas_crot) \
__macro(rocblas_csrot) \
__macro(rocblas_zrot) \
__macro(rocblas_zdrot) \
__macro(rocblas_srotg) \
__macro(rocblas_drotg) \
__macro(rocblas_Crotg) \
__macro(rocblas_crotg) \
__macro(rocblas_zrotm) \
__macro(rocblas_drotm) \
__macro(rocblas_srotmg) \
__macro(rocblas_drotmg) */ \
__macro(rocblas_sgemv) \
__macro(rocblas_dgemv) \
/*__macro(rocblas_cgemv) \
__macro(rocblas_zgemv) \
__macro(rocblas_sgbmv) \
__macro(rocblas_dgbmv) \
__macro(rocblas_cgbmv) \
__macro(rocblas_zgbmv) \
__macro(rocblas_strmv) \
__macro(rocblas_dtrmv) \
__macro(rocblas_ctrmv) \
__macro(rocblas_ztrmv) \
__macro(rocblas_stbmv) \
__macro(rocblas_dtbmv) \
__macro(rocblas_ctbmv) \
__macro(rocblas_ztbmv) \
__macro(rocblas_stpmv) \
__macro(rocblas_dtpmv) \
__macro(rocblas_ctpmv) \
__macro(rocblas_ztpmv) \
__macro(rocblas_strsv) \
__macro(rocblas_dtrsv) \
__macro(rocblas_ctrsv) \
__macro(rocblas_ztrsv) \
__macro(rocblas_stpsv) \
__macro(rocblas_dtpsv) \
__macro(rocblas_ctpsv) \
__macro(rocblas_ztpsv) \
__macro(rocblas_stbsv) \
__macro(rocblas_dtbsv) \
__macro(rocblas_ctbsv) \
__macro(rocblas_ztbsv) \
__macro(rocblas_ssymv) \
__macro(rocblas_dsymv) \
__macro(rocblas_csymv) \
__macro(rocblas_zsymv) \
__macro(rocblas_chemv) \
__macro(rocblas_zhemv) \
__macro(rocblas_ssbmv) \
__macro(rocblas_dsbmv) \
__macro(rocblas_chbmv) \
__macro(rocblas_zhbmv) \
__macro(rocblas_sspmv) \
__macro(rocblas_dspmv) \
__macro(rocblas_chpmv) \
__macro(rocblas_zhpmv) */ \
__macro(rocblas_sger) \
__macro(rocblas_dger) \
/*__macro(rocblas_cgeru) \
__macro(rocblas_cgerc) \
__macro(rocblas_zgeru) \
__macro(rocblas_zgerc) */ \
__macro(rocblas_ssyr) \
__macro(rocblas_dsyr) \
/*__macro(rocblas_csyr) \
__macro(rocblas_zsyr) \
__macro(rocblas_cher) \
__macro(rocblas_zher) \
__macro(rocblas_sspr) \
__macro(rocblas_dspr) \
__macro(rocblas_chpr) \
__macro(rocblas_zhpr) \
__macro(rocblas_ssyr2) \
__macro(rocblas_dsyr2) \
__macro(rocblas_csyr2) \
__macro(rocblas_zsyr2) \
__macro(rocblas_cher2) \
__macro(rocblas_zher2) \
__macro(rocblas_sspr2) \
__macro(rocblas_dspr2) \
__macro(rocblas_chpr2) \
__macro(rocblas_zhpr2) */ \
__macro(rocblas_sgemm) \
__macro(rocblas_dgemm) \
__macro(rocblas_hgemm) \
/*__macro(rocblas_cgemm) \
__macro(rocblas_zgemm) \
__macro(rocblas_ssyrk) \
__macro(rocblas_dsyrk) \
__macro(rocblas_csyrk) \
__macro(rocblas_zsyrk) \
__macro(rocblas_cherk) \
__macro(rocblas_zherk) \
__macro(rocblas_ssyr2k) \
__macro(rocblas_dsyr2k) \
__macro(rocblas_csyr2k) \
__macro(rocblas_zsyr2k) \
__macro(rocblas_cher2k) \
__macro(rocblas_zher2k) \
__macro(rocblas_ssyrkx) \
__macro(rocblas_dsyrkx) \
__macro(rocblas_csyrkx) \
__macro(rocblas_zsyrkx) \
__macro(rocblas_cherkx) \
__macro(rocblas_zherkx) \
__macro(rocblas_ssymm) \
__macro(rocblas_dsymm) \
__macro(rocblas_csymm) \
__macro(rocblas_zsymm) \
__macro(rocblas_chemm) \
__macro(rocblas_zhemm) */ \
__macro(rocblas_strsm) \
__macro(rocblas_dtrsm) \
/*__macro(rocblas_ctrsm) \
__macro(rocblas_ztrsm) \
__macro(rocblas_strmm) \
__macro(rocblas_dtrmm) \
__macro(rocblas_ctrmm) \
__macro(rocblas_ztrmm) */ \
__macro(rocblas_sgeam) \
__macro(rocblas_dgeam) \
/*__macro(rocblas_cgeam) \
__macro(rocblas_zgeam) \
__macro(rocblas_sdgmm) \
__macro(rocblas_ddgmm) \
__macro(rocblas_cdgmm) \
__macro(rocblas_zdgmm) */
// clang-format on
// Shims for handle creation/destruction and stream binding (used by
// ROCMBlas::Init, ~ROCMBlas and ROCMBlas::SetStream).
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_create_handle)
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_destroy_handle)
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_stream)
// The pointer-mode entry points below are commented out, so no shim exists
// for them in this file.
// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_pointer_mode)
// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_get_pointer_mode)
// Batched GEMM: only the strided-batched h/s/d variants are wrapped; the
// *_batched variants are commented out.
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched)
// One shim per routine enabled in ROCBLAS_BLAS_ROUTINE_EACH.
ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP)
} // namespace wrap
// Returns the enumerator name for `status`. Values that are not one of the
// statuses listed here are rendered as "<invalid rocBLAS status: N>".
static string ToString(rocblas_status status) {
  const char *known_name = nullptr;
  switch (status) {
    case rocblas_status_success:
      known_name = "rocblas_status_success";
      break;
    case rocblas_status_invalid_handle:
      known_name = "rocblas_status_invalid_handle";
      break;
    case rocblas_status_not_implemented:
      known_name = "rocblas_status_not_implemented";
      break;
    case rocblas_status_invalid_pointer:
      known_name = "rocblas_status_invalid_pointer";
      break;
    case rocblas_status_invalid_size:
      known_name = "rocblas_status_invalid_size";
      break;
    case rocblas_status_memory_error:
      known_name = "rocblas_status_memory_error";
      break;
    case rocblas_status_internal_error:
      known_name = "rocblas_status_internal_error";
      break;
    default:
      break;
  }
  if (known_name != nullptr) {
    return known_name;
  }
  return absl::StrCat("<invalid rocBLAS status: ", status, ">");
}
bool ROCMBlas::Init() {
rocblas_status ret = wrap::rocblas_create_handle(parent_, &blas_);
if (ret != rocblas_status_success) {
LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret);
return false;
}
return true;
}
// Binds this object to `parent`, which must be non-null (enforced by
// CHECK_NOTNULL). The rocBLAS handle starts out null; Init() creates it.
ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
    : parent_(CHECK_NOTNULL(parent)),
      blas_(nullptr) {
}
// Destroys the rocBLAS handle if one is held. The status returned by
// rocblas_destroy_handle is not inspected.
ROCMBlas::~ROCMBlas() {
  if (blas_ == nullptr) {
    return;  // No handle to release.
  }
  wrap::rocblas_destroy_handle(parent_, blas_);
}
// Points the rocBLAS handle at the GPU stream underlying `stream`.
// CHECK-fails if `stream`, its GPU stream value, or the handle is null.
// Returns true on success; otherwise logs the rocBLAS status and returns
// false.
bool ROCMBlas::SetStream(Stream *stream) {
  CHECK(stream != nullptr);
  CHECK(AsGpuStreamValue(stream) != nullptr);
  CHECK(blas_ != nullptr);
  const rocblas_status status =
      wrap::rocblas_set_stream(parent_, blas_, AsGpuStreamValue(stream));
  if (status == rocblas_status_success) {
    return true;
  }
  LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(status);
  return false;
}
namespace {
// Helper functions transforming blas arguments into rocBLAS arguments.
// Maps a blas::Transpose value onto the matching rocblas_operation.
// Any other value is a fatal error.
rocblas_operation ROCMBlasTranspose(blas::Transpose trans) {
  if (trans == blas::Transpose::kNoTranspose) {
    return rocblas_operation_none;
  }
  if (trans == blas::Transpose::kTranspose) {
    return rocblas_operation_transpose;
  }
  if (trans == blas::Transpose::kConjugateTranspose) {
    return rocblas_operation_conjugate_transpose;
  }
  LOG(FATAL) << "Invalid value of blas::Transpose.";
}
// Maps a blas::UpperLower value onto the matching rocblas_fill.
// Any other value is a fatal error.
rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) {
  if (uplo == blas::UpperLower::kUpper) {
    return rocblas_fill_upper;
  }
  if (uplo == blas::UpperLower::kLower) {
    return rocblas_fill_lower;
  }
  LOG(FATAL) << "Invalid value of blas::UpperLower.";
}
// Maps a blas::Diagonal value onto the matching rocblas_diagonal.
// Any other value is a fatal error.
rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) {
  if (diag == blas::Diagonal::kUnit) {
    return rocblas_diagonal_unit;
  }
  if (diag == blas::Diagonal::kNonUnit) {
    return rocblas_diagonal_non_unit;
  }
  LOG(FATAL) << "Invalid value of blas::Diagonal.";
}
// Maps a blas::Side value onto the matching rocblas_side.
// Any other value is a fatal error.
rocblas_side ROCMBlasSide(blas::Side side) {
  if (side == blas::Side::kLeft) {
    return rocblas_side_left;
  }
  if (side == blas::Side::kRight) {
    return rocblas_side_right;
  }
  LOG(FATAL) << "Invalid value of blas::Side.";
}
} // namespace
template <typename FuncT, typename... Args>
bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
bool pointer_mode_host, bool err_on_failure,
Args... args) {
absl::MutexLock lock{&mu_};
CHECK(blas_ != nullptr);
if (!SetStream(stream)) {
return false;
}
rocblas_status ret = rocblas_func(parent_, blas_, args...);
if (err_on_failure && ret != rocblas_status_success) {
LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": "
<< ToString(ret);
}
return ret == rocblas_status_success;
}
// float ASUM: forwards to rocblas_sasum via DoBlasInternal with the pointers
// obtained from GpuMemory(x) and GpuMemoryMutable(result). Returns
// DoBlasInternal's result.
bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *result) {
  auto x_ptr = GpuMemory(x);
  auto result_ptr = GpuMemoryMutable(result);
  return DoBlasInternal(wrap::rocblas_sasum, stream,
                        /*pointer_mode_host=*/false, elem_count, x_ptr, incx,
                        result_ptr);
}
// double ASUM: forwards to rocblas_dasum via DoBlasInternal with the pointers
// obtained from GpuMemory(x) and GpuMemoryMutable(result). Returns
// DoBlasInternal's result.
bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *result) {
  auto x_ptr = GpuMemory(x);
  auto result_ptr = GpuMemoryMutable(result);
  return DoBlasInternal(wrap::rocblas_dasum, stream,
                        /*pointer_mode_host=*/false, elem_count, x_ptr, incx,
                        result_ptr);
}
// complex<float> ASUM is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<float> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the ASUM operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> ASUM is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<double> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the ASUM operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float AXPY: forwards to rocblas_saxpy via DoBlasInternal. `alpha` is passed
// by address of the by-value parameter. Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *y, int incy) {
  auto x_ptr = GpuMemory(x);
  auto y_ptr = GpuMemoryMutable(y);
  return DoBlasInternal(wrap::rocblas_saxpy, stream,
                        /*pointer_mode_host=*/true, elem_count, &alpha, x_ptr,
                        incx, y_ptr, incy);
}
// double AXPY: forwards to rocblas_daxpy via DoBlasInternal. `alpha` is passed
// by address of the by-value parameter. Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *y, int incy) {
  auto x_ptr = GpuMemory(x);
  auto y_ptr = GpuMemoryMutable(y);
  return DoBlasInternal(wrap::rocblas_daxpy, stream,
                        /*pointer_mode_host=*/true, elem_count, &alpha, x_ptr,
                        incx, y_ptr, incy);
}
// complex<float> AXPY is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the AXPY operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> AXPY is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the AXPY operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float COPY: forwards to rocblas_scopy via DoBlasInternal with the pointers
// obtained from GpuMemory(x) and GpuMemoryMutable(y). Returns DoBlasInternal's
// result.
bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *y, int incy) {
  auto src_ptr = GpuMemory(x);
  auto dst_ptr = GpuMemoryMutable(y);
  return DoBlasInternal(wrap::rocblas_scopy, stream,
                        /*pointer_mode_host=*/true, elem_count, src_ptr, incx,
                        dst_ptr, incy);
}
// double COPY: forwards to rocblas_dcopy via DoBlasInternal with the pointers
// obtained from GpuMemory(x) and GpuMemoryMutable(y). Returns DoBlasInternal's
// result.
bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *y, int incy) {
  auto src_ptr = GpuMemory(x);
  auto dst_ptr = GpuMemoryMutable(y);
  return DoBlasInternal(wrap::rocblas_dcopy, stream,
                        /*pointer_mode_host=*/true, elem_count, src_ptr, incx,
                        dst_ptr, incy);
}
// complex<float> COPY is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the COPY operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> COPY is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the COPY operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float DOT: forwards to rocblas_sdot via DoBlasInternal with the pointers
// obtained from GpuMemory(x), GpuMemory(y) and GpuMemoryMutable(result).
// Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
                         const DeviceMemory<float> &x, int incx,
                         const DeviceMemory<float> &y, int incy,
                         DeviceMemory<float> *result) {
  auto x_ptr = GpuMemory(x);
  auto y_ptr = GpuMemory(y);
  auto result_ptr = GpuMemoryMutable(result);
  return DoBlasInternal(wrap::rocblas_sdot, stream,
                        /*pointer_mode_host=*/false, elem_count, x_ptr, incx,
                        y_ptr, incy, result_ptr);
}
// double DOT: forwards to rocblas_ddot via DoBlasInternal with the pointers
// obtained from GpuMemory(x), GpuMemory(y) and GpuMemoryMutable(result).
// Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
                         const DeviceMemory<double> &x, int incx,
                         const DeviceMemory<double> &y, int incy,
                         DeviceMemory<double> *result) {
  auto x_ptr = GpuMemory(x);
  auto y_ptr = GpuMemory(y);
  auto result_ptr = GpuMemoryMutable(result);
  return DoBlasInternal(wrap::rocblas_ddot, stream,
                        /*pointer_mode_host=*/false, elem_count, x_ptr, incx,
                        y_ptr, incy, result_ptr);
}
// complex<float> DOTC is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> DOTC is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<float> DOTU is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> DOTU is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float NRM2: forwards to rocblas_snrm2 via DoBlasInternal with the pointers
// obtained from GpuMemory(x) and GpuMemoryMutable(result). Returns
// DoBlasInternal's result.
bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *result) {
  auto x_ptr = GpuMemory(x);
  auto result_ptr = GpuMemoryMutable(result);
  return DoBlasInternal(wrap::rocblas_snrm2, stream,
                        /*pointer_mode_host=*/false, elem_count, x_ptr, incx,
                        result_ptr);
}
// double NRM2: forwards to rocblas_dnrm2 via DoBlasInternal with the pointers
// obtained from GpuMemory(x) and GpuMemoryMutable(result). Returns
// DoBlasInternal's result.
bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *result) {
  auto x_ptr = GpuMemory(x);
  auto result_ptr = GpuMemoryMutable(result);
  return DoBlasInternal(wrap::rocblas_dnrm2, stream,
                        /*pointer_mode_host=*/false, elem_count, x_ptr, incx,
                        result_ptr);
}
// complex<float> NRM2 is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<float> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> NRM2 is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<double> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float ROT is not implemented by this backend: the call logs an error and
// returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<float> *x, int incx,
                         DeviceMemory<float> *y, int incy, float c, float s) {
  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
             << "for the \"float\" datatype";
  return false;  // Nothing was dispatched.
}
// double ROT is not implemented by this backend: the call logs an error and
// returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<double> *x, int incx,
                         DeviceMemory<double> *y, int incy, double c,
                         double s) {
  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
             << "for the \"double\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<float> ROT is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<std::complex<float>> *x, int incx,
                         DeviceMemory<std::complex<float>> *y, int incy,
                         float c, float s) {
  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> ROT is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<std::complex<double>> *x, int incx,
                         DeviceMemory<std::complex<double>> *y, int incy,
                         double c, double s) {
  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float ROTG is not implemented by this backend: the call logs an error and
// returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
                          DeviceMemory<float> *b, DeviceMemory<float> *c,
                          DeviceMemory<float> *s) {
  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
             << "for the \"float\" datatype";
  return false;  // Nothing was dispatched.
}
// double ROTG is not implemented by this backend: the call logs an error and
// returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
                          DeviceMemory<double> *b, DeviceMemory<double> *c,
                          DeviceMemory<double> *s) {
  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
             << "for the \"double\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<float> ROTG is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
                          DeviceMemory<std::complex<float>> *b,
                          DeviceMemory<float> *c,
                          DeviceMemory<std::complex<float>> *s) {
  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> ROTG is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
                          DeviceMemory<std::complex<double>> *b,
                          DeviceMemory<double> *c,
                          DeviceMemory<std::complex<double>> *s) {
  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float ROTM is not implemented by this backend: the call logs an error and
// returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
                          DeviceMemory<float> *x, int incx,
                          DeviceMemory<float> *y, int incy,
                          const DeviceMemory<float> &param) {
  LOG(ERROR) << "rocBLAS does not currently support the ROTM operation "
             << "for the \"float\" datatype";
  return false;  // Nothing was dispatched.
}
// double ROTM is not implemented by this backend: the call logs an error and
// returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
                          DeviceMemory<double> *x, int incx,
                          DeviceMemory<double> *y, int incy,
                          const DeviceMemory<double> &param) {
  LOG(ERROR) << "rocBLAS does not currently support the ROTM operation "
             << "for the \"double\" datatype";
  return false;  // Nothing was dispatched.
}
// float ROTMG is not implemented by this backend: the call logs an error and
// returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
                           DeviceMemory<float> *d2, DeviceMemory<float> *x1,
                           const DeviceMemory<float> &y1,
                           DeviceMemory<float> *param) {
  LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation "
             << "for the \"float\" datatype";
  return false;  // Nothing was dispatched.
}
// double ROTMG is not implemented by this backend: the call logs an error and
// returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
                           DeviceMemory<double> *d2, DeviceMemory<double> *x1,
                           const DeviceMemory<double> &y1,
                           DeviceMemory<double> *param) {
  LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation "
             << "for the \"double\" datatype";
  return false;  // Nothing was dispatched.
}
// float SCAL: forwards to rocblas_sscal via DoBlasInternal. `alpha` is passed
// by address of the by-value parameter. Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                          DeviceMemory<float> *x, int incx) {
  auto x_ptr = GpuMemoryMutable(x);
  return DoBlasInternal(wrap::rocblas_sscal, stream,
                        /*pointer_mode_host=*/true, elem_count, &alpha, x_ptr,
                        incx);
}
// double SCAL: forwards to rocblas_dscal via DoBlasInternal. `alpha` is passed
// by address of the by-value parameter. Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
                          DeviceMemory<double> *x, int incx) {
  auto x_ptr = GpuMemoryMutable(x);
  return DoBlasInternal(wrap::rocblas_dscal, stream,
                        /*pointer_mode_host=*/true, elem_count, &alpha, x_ptr,
                        incx);
}
// SCAL of a complex<float> vector by a float scalar is not implemented by
// this backend: the call logs an error and returns false without invoking
// rocBLAS.
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// SCAL of a complex<double> vector by a double scalar is not implemented by
// this backend: the call logs an error and returns false without invoking
// rocBLAS.
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// SCAL of a complex<float> vector by a complex<float> scalar is not
// implemented by this backend: the call logs an error and returns false
// without invoking rocBLAS.
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
                          std::complex<float> alpha,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// SCAL of a complex<double> vector by a complex<double> scalar is not
// implemented by this backend: the call logs an error and returns false
// without invoking rocBLAS.
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
                          std::complex<double> alpha,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float SWAP: forwards to rocblas_sswap via DoBlasInternal with the pointers
// obtained from GpuMemoryMutable(x) and GpuMemoryMutable(y). Returns
// DoBlasInternal's result.
bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<float> *x, int incx,
                          DeviceMemory<float> *y, int incy) {
  auto x_ptr = GpuMemoryMutable(x);
  auto y_ptr = GpuMemoryMutable(y);
  return DoBlasInternal(wrap::rocblas_sswap, stream,
                        /*pointer_mode_host=*/true, elem_count, x_ptr, incx,
                        y_ptr, incy);
}
// double SWAP: forwards to rocblas_dswap via DoBlasInternal with the pointers
// obtained from GpuMemoryMutable(x) and GpuMemoryMutable(y). Returns
// DoBlasInternal's result.
bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<double> *x, int incx,
                          DeviceMemory<double> *y, int incy) {
  auto x_ptr = GpuMemoryMutable(x);
  auto y_ptr = GpuMemoryMutable(y);
  return DoBlasInternal(wrap::rocblas_dswap, stream,
                        /*pointer_mode_host=*/true, elem_count, x_ptr, incx,
                        y_ptr, incy);
}
// complex<float> SWAP is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<std::complex<float>> *x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the SWAP operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> SWAP is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<std::complex<double>> *x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the SWAP operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float IAMAX: forwards to rocblas_isamax via DoBlasInternal with the pointers
// obtained from GpuMemory(x) and GpuMemoryMutable(result). Returns
// DoBlasInternal's result.
bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<int> *result) {
  auto x_ptr = GpuMemory(x);
  auto index_ptr = GpuMemoryMutable(result);
  return DoBlasInternal(wrap::rocblas_isamax, stream,
                        /*pointer_mode_host=*/false, elem_count, x_ptr, incx,
                        index_ptr);
}
// double IAMAX: forwards to rocblas_idamax via DoBlasInternal with the
// pointers obtained from GpuMemory(x) and GpuMemoryMutable(result). Returns
// DoBlasInternal's result.
bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<int> *result) {
  auto x_ptr = GpuMemory(x);
  auto index_ptr = GpuMemoryMutable(result);
  return DoBlasInternal(wrap::rocblas_idamax, stream,
                        /*pointer_mode_host=*/false, elem_count, x_ptr, incx,
                        index_ptr);
}
// complex<float> IAMAX is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<int> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the AMAX operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> IAMAX is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x,
                           int incx, DeviceMemory<int> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the AMAX operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float IAMIN: forwards to rocblas_isamin via DoBlasInternal with the pointers
// obtained from GpuMemory(x) and GpuMemoryMutable(result). Returns
// DoBlasInternal's result.
//
// `x` holds real (float) data, so its pointer is passed straight through,
// exactly as the float DoBlasIamax overload does; no complex-pointer
// conversion is applied.
bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<int> *result) {
  return DoBlasInternal(wrap::rocblas_isamin, stream,
                        false /* = pointer_mode_host */, elem_count,
                        GpuMemory(x), incx, GpuMemoryMutable(result));
}
// double IAMIN: forwards to rocblas_idamin via DoBlasInternal with the
// pointers obtained from GpuMemory(x) and GpuMemoryMutable(result). Returns
// DoBlasInternal's result.
//
// `x` holds real (double) data, so its pointer is passed straight through,
// exactly as the double DoBlasIamax overload does; no complex-pointer
// conversion is applied.
bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<int> *result) {
  return DoBlasInternal(wrap::rocblas_idamin, stream,
                        false /* = pointer_mode_host */, elem_count,
                        GpuMemory(x), incx, GpuMemoryMutable(result));
}
// complex<float> IAMIN is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<int> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the AMIN operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> IAMIN is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x,
                           int incx, DeviceMemory<int> *result) {
  LOG(ERROR) << "rocBLAS does not currently support the AMIN operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float GBMV is not implemented by this backend: the call logs an error and
// returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
             << "for the \"float\" datatype";
  return false;  // Nothing was dispatched.
}
// double GBMV is not implemented by this backend: the call logs an error and
// returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
             << "for the \"double\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<float> GBMV is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> GBMV is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float GEMV: converts `trans` with ROCMBlasTranspose and forwards to
// rocblas_sgemv via DoBlasInternal. `alpha` and `beta` are passed by address
// of the by-value parameters. Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, float alpha, const DeviceMemory<float> &a,
                          int lda, const DeviceMemory<float> &x, int incx,
                          float beta, DeviceMemory<float> *y, int incy) {
  const rocblas_operation op = ROCMBlasTranspose(trans);
  auto a_ptr = GpuMemory(a);
  auto x_ptr = GpuMemory(x);
  auto y_ptr = GpuMemoryMutable(y);
  return DoBlasInternal(wrap::rocblas_sgemv, stream,
                        /*pointer_mode_host=*/true, op, m, n, &alpha, a_ptr,
                        lda, x_ptr, incx, &beta, y_ptr, incy);
}
// double GEMV: converts `trans` with ROCMBlasTranspose and forwards to
// rocblas_dgemv via DoBlasInternal. `alpha` and `beta` are passed by address
// of the by-value parameters. Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, double alpha, const DeviceMemory<double> &a,
                          int lda, const DeviceMemory<double> &x, int incx,
                          double beta, DeviceMemory<double> *y, int incy) {
  const rocblas_operation op = ROCMBlasTranspose(trans);
  auto a_ptr = GpuMemory(a);
  auto x_ptr = GpuMemory(x);
  auto y_ptr = GpuMemoryMutable(y);
  return DoBlasInternal(wrap::rocblas_dgemv, stream,
                        /*pointer_mode_host=*/true, op, m, n, &alpha, a_ptr,
                        lda, x_ptr, incx, &beta, y_ptr, incy);
}
// complex<float> GEMV is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// complex<double> GEMV is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
             << "for the \"complex<double>\" datatype";
  return false;  // Nothing was dispatched.
}
// float GER: forwards to rocblas_sger via DoBlasInternal. `alpha` is passed by
// address of the by-value parameter. Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
                         const DeviceMemory<float> &x, int incx,
                         const DeviceMemory<float> &y, int incy,
                         DeviceMemory<float> *a, int lda) {
  auto x_ptr = GpuMemory(x);
  auto y_ptr = GpuMemory(y);
  auto a_ptr = GpuMemoryMutable(a);
  return DoBlasInternal(wrap::rocblas_sger, stream,
                        /*pointer_mode_host=*/true, m, n, &alpha, x_ptr, incx,
                        y_ptr, incy, a_ptr, lda);
}
// double GER: forwards to rocblas_dger via DoBlasInternal. `alpha` is passed
// by address of the by-value parameter. Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
                         const DeviceMemory<double> &x, int incx,
                         const DeviceMemory<double> &y, int incy,
                         DeviceMemory<double> *a, int lda) {
  auto x_ptr = GpuMemory(x);
  auto y_ptr = GpuMemory(y);
  auto a_ptr = GpuMemoryMutable(a);
  return DoBlasInternal(wrap::rocblas_dger, stream,
                        /*pointer_mode_host=*/true, m, n, &alpha, x_ptr, incx,
                        y_ptr, incy, a_ptr, lda);
}
// complex<float> GERC is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
//
// The log message names the operation as "GERC" so that it can be told apart
// from the real-typed GER routines and matches the naming used by DoBlasGeru.
bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) {
  LOG(ERROR) << "rocBLAS does not currently support the GERC operation "
             << "for the \"complex<float>\" datatype";
  return false;
}
// complex<double> GERC is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
//
// The log message names the operation as "GERC" so that it can be told apart
// from the real-typed GER routines and matches the naming used by DoBlasGeru.
bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) {
  LOG(ERROR) << "rocBLAS does not currently support the GERC operation "
             << "for the \"complex<double>\" datatype";
  return false;
}
// complex<float> GERU is not implemented by this backend: the call logs an
// error and returns false without invoking rocBLAS.
bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) {
  LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
             << "for the \"complex<float>\" datatype";
  return false;  // Nothing was dispatched.
}
// GERU overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) {
  LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// HBMV overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// HBMV overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// HEMV overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// HEMV overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// HER overload for complex<float> operands (real float alpha). Unsupported in
// this backend: logs an error and returns false.
bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha,
                         const DeviceMemory<std::complex<float>> &x, int incx,
                         DeviceMemory<std::complex<float>> *a, int lda) {
  LOG(ERROR) << "rocBLAS does not currently support the HER operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// HER overload for complex<double> operands (real double alpha). Unsupported
// in this backend: logs an error and returns false.
bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha,
                         const DeviceMemory<std::complex<double>> &x, int incx,
                         DeviceMemory<std::complex<double>> *a, int lda) {
  LOG(ERROR) << "rocBLAS does not currently support the HER operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// HER2 overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) {
  LOG(ERROR) << "rocBLAS does not currently support the HER2 operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// HER2 overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) {
  LOG(ERROR) << "rocBLAS does not currently support the HER2 operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// HPMV overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &ap,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the HPMV operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// HPMV overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &ap,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the HPMV operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// HPR overload for complex<float> operands (real float alpha). Unsupported in
// this backend: logs an error and returns false.
bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha,
                         const DeviceMemory<std::complex<float>> &x, int incx,
                         DeviceMemory<std::complex<float>> *ap) {
  LOG(ERROR) << "rocBLAS does not currently support the HPR operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// HPR overload for complex<double> operands (real double alpha). Unsupported
// in this backend: logs an error and returns false.
bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha,
                         const DeviceMemory<std::complex<double>> &x, int incx,
                         DeviceMemory<std::complex<double>> *ap) {
  LOG(ERROR) << "rocBLAS does not currently support the HPR operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// HPR2 overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *ap) {
  LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// HPR2 overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *ap) {
  LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// SBMV overload for float operands. No rocBLAS call is issued: the request is
// reported with LOG(ERROR) and rejected. Returns false unconditionally.
bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, float alpha, const DeviceMemory<float> &a,
                          int lda, const DeviceMemory<float> &x, int incx,
                          float beta, DeviceMemory<float> *y, int incy) {
  // This overload operates on real float data, so the message names "float"
  // (it previously claimed "complex<float>").
  LOG(ERROR) << "rocBLAS does not currently support the SBMV operation "
             << "for the \"float\" datatype";
  return false;
}
// SBMV overload for double operands. No rocBLAS call is issued: the request
// is reported with LOG(ERROR) and rejected. Returns false unconditionally.
bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, double alpha, const DeviceMemory<double> &a,
                          int lda, const DeviceMemory<double> &x, int incx,
                          double beta, DeviceMemory<double> *y, int incy) {
  // This overload operates on real double data, so the message names
  // "double" (it previously claimed "complex<double>").
  LOG(ERROR) << "rocBLAS does not currently support the SBMV operation "
             << "for the \"double\" datatype";
  return false;
}
// SPMV overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &ap,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the SPMV operation "
                "for the \"float\" datatype";
  return false;
}
// SPMV overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &ap,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the SPMV operation "
                "for the \"double\" datatype";
  return false;
}
// SPR overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha, const DeviceMemory<float> &x, int incx,
                         DeviceMemory<float> *ap) {
  LOG(ERROR) << "rocBLAS does not currently support the SPR operation "
                "for the \"float\" datatype";
  return false;
}
// SPR overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha, const DeviceMemory<double> &x, int incx,
                         DeviceMemory<double> *ap) {
  LOG(ERROR) << "rocBLAS does not currently support the SPR operation "
                "for the \"double\" datatype";
  return false;
}
// SPR2 overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &x, int incx,
                          const DeviceMemory<float> &y, int incy,
                          DeviceMemory<float> *ap) {
  LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation "
                "for the \"float\" datatype";
  return false;
}
// SPR2 overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &x, int incx,
                          const DeviceMemory<double> &y, int incy,
                          DeviceMemory<double> *ap) {
  LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation "
                "for the \"double\" datatype";
  return false;
}
// SYMV overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the SYMV operation "
                "for the \"float\" datatype";
  return false;
}
// SYMV overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) {
  LOG(ERROR) << "rocBLAS does not currently support the SYMV operation "
                "for the \"double\" datatype";
  return false;
}
// Single-precision SYR. Converts uplo with ROCMBlasUpperLower and forwards
// the device pointers for x and a to rocblas_ssyr through DoBlasInternal,
// with alpha read from host memory (pointer_mode_host). Returns
// DoBlasInternal's result.
bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha, const DeviceMemory<float> &x, int incx,
                         DeviceMemory<float> *a, int lda) {
  const auto fill_mode = ROCMBlasUpperLower(uplo);
  return DoBlasInternal(wrap::rocblas_ssyr, stream,
                        /*pointer_mode_host=*/true, fill_mode, n, &alpha,
                        GpuMemory(x), incx, GpuMemoryMutable(a), lda);
}
// Double-precision SYR. Converts uplo with ROCMBlasUpperLower and forwards
// the device pointers for x and a to rocblas_dsyr through DoBlasInternal,
// with alpha read from host memory (pointer_mode_host). Returns
// DoBlasInternal's result.
bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha, const DeviceMemory<double> &x, int incx,
                         DeviceMemory<double> *a, int lda) {
  const auto fill_mode = ROCMBlasUpperLower(uplo);
  return DoBlasInternal(wrap::rocblas_dsyr, stream,
                        /*pointer_mode_host=*/true, fill_mode, n, &alpha,
                        GpuMemory(x), incx, GpuMemoryMutable(a), lda);
}
// SYR2 overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &x, int incx,
                          const DeviceMemory<float> &y, int incy,
                          DeviceMemory<float> *a, int lda) {
  LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation "
                "for the \"float\" datatype";
  return false;
}
// SYR2 overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &x, int incx,
                          const DeviceMemory<double> &y, int incy,
                          DeviceMemory<double> *a, int lda) {
  LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation "
                "for the \"double\" datatype";
  return false;
}
// TBMV overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
                "for the \"float\" datatype";
  return false;
}
// TBMV overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
                "for the \"double\" datatype";
  return false;
}
// TBMV overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<float>> &a,
                          int lda, DeviceMemory<std::complex<float>> *x,
                          int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// TBMV overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<double>> &a,
                          int lda, DeviceMemory<std::complex<double>> *x,
                          int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// TBSV overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
                "for the \"float\" datatype";
  return false;
}
// TBSV overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
                "for the \"double\" datatype";
  return false;
}
// TBSV overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<float>> &a,
                          int lda, DeviceMemory<std::complex<float>> *x,
                          int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// TBSV overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<double>> &a,
                          int lda, DeviceMemory<std::complex<double>> *x,
                          int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// TPMV overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                          int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
                "for the \"float\" datatype";
  return false;
}
// TPMV overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &ap,
                          DeviceMemory<double> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
                "for the \"double\" datatype";
  return false;
}
// TPMV overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &ap,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// TPMV overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &ap,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// TPSV overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                          int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
                "for the \"float\" datatype";
  return false;
}
// TPSV overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &ap,
                          DeviceMemory<double> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
                "for the \"double\" datatype";
  return false;
}
// TPSV overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &ap,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// TPSV overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &ap,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// TRMV overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
                "for the \"float\" datatype";
  return false;
}
// TRMV overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
                "for the \"double\" datatype";
  return false;
}
// TRMV overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// TRMV overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// TRSV overload for float operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
                "for the \"float\" datatype";
  return false;
}
// TRSV overload for double operands. Unsupported in this backend: logs an
// error and returns false without touching any argument.
bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
                "for the \"double\" datatype";
  return false;
}
// TRSV overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// TRSV overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// GEMM overload for Eigen::half matrices with float alpha/beta, dispatched to
// rocblas_hgemm through DoBlasInternal.
//
// A leading dimension smaller than the extent selected by the matching
// transpose flag is only reported with LOG(WARNING); the call is still
// issued. alpha and beta are converted to Eigen::half on the host and passed
// by pointer (pointer_mode_host). Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          float alpha, const DeviceMemory<Eigen::half> &a,
                          int lda, const DeviceMemory<Eigen::half> &b, int ldb,
                          float beta, DeviceMemory<Eigen::half> *c, int ldc) {
  // Labelled HGEMM because this overload calls rocblas_hgemm; the previous
  // "SGEMM" label made half-precision calls look like float calls in traces.
  VLOG(1) << absl::StreamFormat(
      "doing rocBLAS HGEMM: at=%d bt=%d m=%u n=%u "
      "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
      "c=%p ldc=%d",
      static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
      a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);

  // Leading-dimension checks for A: compared against m when A is used as-is,
  // against k otherwise.
  if (transa == blas::Transpose::kNoTranspose) {
    if (lda < static_cast<int64>(m)) {
      LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
                      "precondition violation";
    }
  } else {
    if (lda < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
                   << ") (transpose case); precondition violation";
    }
  }

  // Leading-dimension checks for B: compared against k when B is used as-is,
  // against n otherwise.
  if (transb == blas::Transpose::kNoTranspose) {
    if (ldb < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
                   << ") (no transpose case); precondition violation";
    }
  } else {
    if (ldb < static_cast<int64>(n)) {
      LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
                      "precondition violation";
    }
  }

  // rocblas_hgemm takes half-precision scalars, so narrow alpha/beta first.
  const Eigen::half alpha_half(alpha);
  const Eigen::half beta_half(beta);
  return DoBlasInternal(
      wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
      ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
      reinterpret_cast<const rocblas_half *>(&alpha_half),
      reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda,
      reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
      reinterpret_cast<const rocblas_half *>(&beta_half),
      reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
}
// GEMM overload for float matrices, dispatched to rocblas_sgemm through
// DoBlasInternal with alpha/beta read from host memory (pointer_mode_host).
//
// A leading dimension smaller than the extent selected by the matching
// transpose flag is only reported with LOG(WARNING); the call is still
// issued. Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &b, int ldb, float beta,
                          DeviceMemory<float> *c, int ldc) {
  VLOG(1) << absl::StreamFormat(
      "doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
      "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
      "c=%p ldc=%d",
      static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
      a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);

  const bool a_as_is = (transa == blas::Transpose::kNoTranspose);
  const bool b_as_is = (transb == blas::Transpose::kNoTranspose);

  // A: lda is checked against m when untransposed, against k otherwise.
  if (a_as_is && lda < static_cast<int64>(m)) {
    LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
                    "precondition violation";
  } else if (!a_as_is && lda < static_cast<int64>(k)) {
    LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
                 << ") (transpose case); precondition violation";
  }

  // B: ldb is checked against k when untransposed, against n otherwise.
  if (b_as_is && ldb < static_cast<int64>(k)) {
    LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
                 << ") (no transpose case); precondition violation";
  } else if (!b_as_is && ldb < static_cast<int64>(n)) {
    LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
                    "precondition violation";
  }

  return DoBlasInternal(wrap::rocblas_sgemm, stream,
                        /*pointer_mode_host=*/true, ROCMBlasTranspose(transa),
                        ROCMBlasTranspose(transb), m, n, k, &alpha,
                        GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
                        GpuMemoryMutable(c), ldc);
}
// GEMM overload for double matrices. Forwards directly to rocblas_dgemm
// through DoBlasInternal with alpha/beta read from host memory
// (pointer_mode_host); no leading-dimension checks or tracing are done here.
// Returns DoBlasInternal's result.
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &b, int ldb, double beta,
                          DeviceMemory<double> *c, int ldc) {
  return DoBlasInternal(wrap::rocblas_dgemm, stream,
                        /*pointer_mode_host=*/true, ROCMBlasTranspose(transa),
                        ROCMBlasTranspose(transb), m, n, k, &alpha,
                        GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
                        GpuMemoryMutable(c), ldc);
}
// GEMM overload for complex<float> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// GEMM overload for complex<double> operands. Unsupported in this backend:
// logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// Profiling GEMV entry point for float operands. A thin pass-through: every
// argument is handed unchanged to DoBlasGemvWithProfilingImpl and its result
// is returned.
bool ROCMBlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
    const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
    int incx, float beta, DeviceMemory<float> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  const bool succeeded = DoBlasGemvWithProfilingImpl(
      stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy,
      output_profile_result);
  return succeeded;
}
// Profiling GEMV entry point for double operands. A thin pass-through: every
// argument is handed unchanged to DoBlasGemvWithProfilingImpl and its result
// is returned.
bool ROCMBlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
    const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
    int incx, double beta, DeviceMemory<double> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  const bool succeeded = DoBlasGemvWithProfilingImpl(
      stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy,
      output_profile_result);
  return succeeded;
}
// Profiling GEMV entry point for complex<float> operands. A thin
// pass-through: every argument is handed unchanged to
// DoBlasGemvWithProfilingImpl and its result is returned.
bool ROCMBlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
    std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
    int lda, const DeviceMemory<std::complex<float>> &x, int incx,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  const bool succeeded = DoBlasGemvWithProfilingImpl(
      stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy,
      output_profile_result);
  return succeeded;
}
// Profiling GEMV entry point for complex<double> operands. A thin
// pass-through: every argument is handed unchanged to
// DoBlasGemvWithProfilingImpl and its result is returned.
bool ROCMBlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
    std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
    int lda, const DeviceMemory<std::complex<double>> &x, int incx,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  const bool succeeded = DoBlasGemvWithProfilingImpl(
      stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy,
      output_profile_result);
  return succeeded;
}
// Profiling GEMM entry point for Eigen::half matrices with float alpha/beta.
// A thin pass-through: every argument is handed unchanged to
// DoBlasGemmWithProfilingImpl and its result is returned.
bool ROCMBlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
    int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
    DeviceMemory<Eigen::half> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  const bool succeeded = DoBlasGemmWithProfilingImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      output_profile_result);
  return succeeded;
}
// Profiling GEMM entry point for float matrices. A thin pass-through: every
// argument is handed unchanged to DoBlasGemmWithProfilingImpl and its result
// is returned.
bool ROCMBlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
    int ldc, blas::ProfileResult *output_profile_result) {
  const bool succeeded = DoBlasGemmWithProfilingImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      output_profile_result);
  return succeeded;
}
// Profiling GEMM entry point for double matrices. A thin pass-through: every
// argument is handed unchanged to DoBlasGemmWithProfilingImpl and its result
// is returned.
bool ROCMBlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    const DeviceMemory<double> &b, int ldb, double beta,
    DeviceMemory<double> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  const bool succeeded = DoBlasGemmWithProfilingImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      output_profile_result);
  return succeeded;
}
// Profiling GEMM entry point for complex<float> matrices. A thin
// pass-through: every argument is handed unchanged to
// DoBlasGemmWithProfilingImpl and its result is returned.
bool ROCMBlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  const bool succeeded = DoBlasGemmWithProfilingImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      output_profile_result);
  return succeeded;
}
// Profiling GEMM entry point for complex<double> matrices. A thin
// pass-through: every argument is handed unchanged to
// DoBlasGemmWithProfilingImpl and its result is returned.
bool ROCMBlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  const bool succeeded = DoBlasGemmWithProfilingImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      output_profile_result);
  return succeeded;
}
// Shared implementation behind the DoBlasGemvWithProfiling overloads.
// Currently a placeholder: no work is performed, output_profile_result is
// left untouched, and false is returned for every call.
template <typename T>
bool ROCMBlas::DoBlasGemvWithProfilingImpl(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
    const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
    const T &beta, DeviceMemory<T> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  // ROCM TODO: properly implement the interface
  return false;
}
// Shared implementation behind the DoBlasGemmWithProfiling overloads.
// Currently a placeholder: no work is performed, output_profile_result is
// left untouched, and false is returned for every call.
template <typename T, typename ParamType>
bool ROCMBlas::DoBlasGemmWithProfilingImpl(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
    int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
    DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
  // ROCM TODO: properly implement the interface
  return false;
}
// Generic GEMM-with-algorithm implementation hook. Currently a placeholder:
// no work is performed, output_profile_result is left untouched, and false
// is returned for every call.
template <typename InT, typename OutT, typename CompT>
bool ROCMBlas::DoBlasGemmWithAlgorithmImpl(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
    const DeviceMemory<InT> &b, int ldb, const CompT &beta,
    DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  // ROCM TODO: properly implement the interface
  return false;
}
// Algorithm enumeration hook. Currently a placeholder: out_algorithms is not
// read or modified, and true is returned for every call.
bool ROCMBlas::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  // ROCM TODO: properly implement the interface
  return true;
}
// GEMM-with-algorithm overload for int8 inputs with int32 output.
// Unsupported in this backend: logs an error and returns false without
// touching any argument.
bool ROCMBlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
    const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
    const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c, int ldc,
    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
    blas::ProfileResult *output_profile_result) {
  LOG(ERROR)
      << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
         "for the \"int8\" datatype";
  return false;
}
// GEMM-with-algorithm overload for Eigen::half operands. Unsupported in this
// backend: logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
    const DeviceMemory<Eigen::half> &a, int lda,
    const DeviceMemory<Eigen::half> &b, int ldb,
    const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
    int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  LOG(ERROR)
      << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
         "for the \"half\" datatype";
  return false;
}
// GEMM-with-algorithm overload for float operands. Unsupported in this
// backend: logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
    const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
    int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
    int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  LOG(ERROR)
      << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
         "for the \"float\" datatype";
  return false;
}
// GEMM-with-algorithm overload for double operands. Unsupported in this
// backend: logs an error and returns false without touching any argument.
bool ROCMBlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
    const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
    int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
    int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  LOG(ERROR)
      << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
         "for the \"double\" datatype";
  return false;
}
// GEMM-with-algorithm overload for complex<float> operands. Unsupported in
// this backend: logs an error and returns false without touching any
// argument.
bool ROCMBlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    const HostOrDeviceScalar<std::complex<float>> &beta,
    DeviceMemory<std::complex<float>> *c, int ldc,
    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
    blas::ProfileResult *output_profile_result) {
  LOG(ERROR)
      << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
         "for the \"complex<float>\" datatype";
  return false;
}
// GEMM-with-algorithm overload for complex<double> operands. Unsupported in
// this backend: logs an error and returns false without touching any
// argument.
bool ROCMBlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    const HostOrDeviceScalar<std::complex<double>> &beta,
    DeviceMemory<std::complex<double>> *c, int ldc,
    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
    blas::ProfileResult *output_profile_result) {
  LOG(ERROR)
      << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
         "for the \"complex<double>\" datatype";
  return false;
}
// Produces a single strided device buffer that holds `batch_count` matrices of
// `batch_stride` elements each.
//
// raw_ptrs:          device pointers to the individual matrices.
// batch_count:       number of matrices in the batch.
// batch_stride:      number of elements occupied by one matrix.
// scratch_allocator: when non-null, used to allocate the staging buffer.
// stream:            stream used for the temporary allocation (when there is
//                    no scratch allocator) and for the gathering copies.
// temp_memory:       receives ownership of the temporary allocation when no
//                    scratch allocator is supplied; must then be non-null.
// device_memory:     output; the strided view of the batch. Must be non-null.
//
// If the input matrices are already laid out back to back with exactly
// `batch_stride` elements between consecutive pointers, no allocation or copy
// is performed and `*device_memory` simply aliases `raw_ptrs[0]`. Otherwise a
// new buffer is allocated and every matrix is copied into it.
//
// Returns OK on success, the allocator's error if allocation fails, or an
// INTERNAL error if a device-to-device copy cannot be enqueued.
template <typename T>
port::Status ROCMBlas::AllocateStridedBuffer(
    const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
        &raw_ptrs,
    int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator,
    Stream *stream,
    std::unique_ptr<TemporaryDeviceMemory<
        typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
    DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
        *device_memory) {
  assert(device_memory != nullptr);
  using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;

  // An empty batch has no first pointer to alias; hand back an empty buffer
  // instead of indexing into an empty vector.
  if (batch_count <= 0 || raw_ptrs.empty()) {
    *device_memory = DeviceMemory<MAPPED_T>();
    return port::Status::OK();
  }
  assert(raw_ptrs.size() >= static_cast<size_t>(batch_count));

  // The batch is already strided iff every consecutive pair of pointers is
  // exactly `batch_stride` elements apart.
  bool needs_allocate_strided = false;
  for (int i = 1; i < batch_count; ++i) {
    uint64_t tmp_batch_stride = raw_ptrs[i] - raw_ptrs[i - 1];
    if (tmp_batch_stride != batch_stride) {
      needs_allocate_strided = true;
      break;
    }
  }

  const size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
  const size_t matrix_batch_byte_size = matrix_byte_size * batch_count;

  // No need to do re-allocation, take the short cut and return
  if (!needs_allocate_strided) {
    *device_memory = DeviceMemory<MAPPED_T>(
        DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size));
    return port::Status::OK();
  }

  if (scratch_allocator != nullptr) {
    SE_ASSIGN_OR_RETURN(
        DeviceMemory<uint8> batch_matrix_bytes,
        scratch_allocator->AllocateBytes(matrix_batch_byte_size));
    *device_memory = DeviceMemory<MAPPED_T>(batch_matrix_bytes);
  } else {
    assert(temp_memory != nullptr);
    // AllocateTemporaryArray is sized in elements, not bytes; passing the byte
    // count here would over-allocate by a factor of sizeof(MAPPED_T).
    const uint64_t batch_element_count =
        batch_stride * static_cast<uint64_t>(batch_count);
    SE_ASSIGN_OR_RETURN(*temp_memory, stream->AllocateTemporaryArray<MAPPED_T>(
                                          batch_element_count));
    *device_memory =
        DeviceMemory<MAPPED_T>(*(*temp_memory)->mutable_device_memory());
  }

  // Gather each source matrix into its slot of the strided buffer.
  char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
  for (int i = 0; i < batch_count; ++i) {
    DeviceMemoryBase src_mem = DeviceMemoryBase(raw_ptrs[i], matrix_byte_size);
    DeviceMemoryBase target_mem = DeviceMemoryBase(
        device_memory_ptr + i * matrix_byte_size, matrix_byte_size);
    bool a_status =
        stream->ThenMemcpy(&target_mem, src_mem, matrix_byte_size).ok();
    if (!a_status) {
      return port::Status(
          port::error::INTERNAL,
          "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
    }
  }
  return port::Status::OK();
}
// Runs a batched GEMM through a rocBLAS strided-batched routine.
//
// The per-batch A, B and C matrices arrive as arrays of device-memory
// wrappers. Each operand is passed through AllocateStridedBuffer, which either
// aliases the caller's memory (when it is already laid out with the expected
// stride) or gathers the matrices into a freshly allocated strided buffer.
// `rocblas_func` is then invoked once on the strided operands. If C had to be
// staged into a new buffer, the results are copied back into the caller's
// per-batch C matrices afterwards.
//
// alpha / beta are passed to rocBLAS by host address.
// scratch_allocator may be null, in which case temporary stream memory is
// used for any staging buffers.
//
// Returns OK on success; otherwise the allocation error, or an INTERNAL error
// if the BLAS call or a device copy fails.
template <typename T, typename FuncT>
port::Status ROCMBlas::DoBlasGemmBatchedInternal(
    FuncT rocblas_func, Stream *stream, blas::Transpose transa,
    blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
    const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
    const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
    T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;

  // Sanity checks before making any further progress
  uint64_t batch_stride_a = 0;
  uint64_t batch_stride_b = 0;
  uint64_t batch_stride_c = 0;

  assert(ldc >= m);
  batch_stride_c = ldc * n;

  if (ROCMBlasTranspose(transa) == rocblas_operation_none) {
    assert(lda >= m);
    batch_stride_a = lda * k;
  } else {
    assert(lda >= k);
    batch_stride_a = lda * m;
  }

  if (ROCMBlasTranspose(transb) == rocblas_operation_none) {
    assert(ldb >= k);
    batch_stride_b = ldb * n;
  } else {
    assert(ldb >= n);
    batch_stride_b = ldb * k;
  }

  // Allocate local vectors to hold device pointers to matrices
  std::vector<MAPPED_T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
  for (int i = 0; i < batch_count; ++i) {
    // static_cast does not work when converting Eigen::half* to rocblas_half*,
    // hence the use of reinterpret_cast
    a_raw_ptrs.push_back(
        reinterpret_cast<MAPPED_T *>(a_ptrs_to_wrappers[i]->opaque()));
    b_raw_ptrs.push_back(
        reinterpret_cast<MAPPED_T *>(b_ptrs_to_wrappers[i]->opaque()));
    c_raw_ptrs.push_back(
        reinterpret_cast<MAPPED_T *>(c_ptrs_to_wrappers[i]->opaque()));
  }

  DeviceMemory<MAPPED_T> a;
  // Make sure the temporary memory are in-scope before the function returns
  std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> a_temp;
  port::Status a_allocation_status =
      AllocateStridedBuffer<T>(a_raw_ptrs, batch_count, batch_stride_a,
                               scratch_allocator, stream, &a_temp, &a);
  if (!a_allocation_status.ok()) {
    return a_allocation_status;
  }

  DeviceMemory<MAPPED_T> b;
  std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> b_temp;
  port::Status b_allocation_status =
      AllocateStridedBuffer<T>(b_raw_ptrs, batch_count, batch_stride_b,
                               scratch_allocator, stream, &b_temp, &b);
  if (!b_allocation_status.ok()) {
    return b_allocation_status;
  }

  DeviceMemory<MAPPED_T> c;
  std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> c_temp;
  port::Status c_allocation_status =
      AllocateStridedBuffer<T>(c_raw_ptrs, batch_count, batch_stride_c,
                               scratch_allocator, stream, &c_temp, &c);
  if (!c_allocation_status.ok()) {
    return c_allocation_status;
  }

  MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
  MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);

  bool ok = DoBlasInternal(rocblas_func, stream, true /* = pointer_mode_host */,
                           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb),
                           m, n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
                           batch_stride_a, GpuMemory(b), ldb, batch_stride_b,
                           GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc,
                           batch_stride_c, batch_count);
  if (!ok) {
    return port::Status(port::error::INTERNAL,
                        "failed BLAS call, see log for details");
  }

  // AllocateStridedBuffer aliases the first caller pointer when no staging
  // buffer was needed. If `c` points anywhere else, the GEMM wrote into a
  // staging buffer and the results must be scattered back to the caller's
  // per-batch C matrices, otherwise the output would be lost.
  if (batch_count > 0 && c.opaque() != static_cast<void *>(c_raw_ptrs[0])) {
    const size_t c_matrix_byte_size = batch_stride_c * sizeof(MAPPED_T);
    char *strided_c_ptr = static_cast<char *>(c.opaque());
    for (int i = 0; i < batch_count; ++i) {
      DeviceMemoryBase src_mem = DeviceMemoryBase(
          strided_c_ptr + i * c_matrix_byte_size, c_matrix_byte_size);
      DeviceMemoryBase target_mem =
          DeviceMemoryBase(c_raw_ptrs[i], c_matrix_byte_size);
      if (!stream->ThenMemcpy(&target_mem, src_mem, c_matrix_byte_size).ok()) {
        return port::Status(
            port::error::INTERNAL,
            "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
      }
    }
  }
  return port::Status::OK();
}
// Batched GEMM on Eigen::half matrices. The float scaling factors are
// converted to Eigen::half and the work is forwarded to
// DoBlasGemmBatchedInternal with rocblas_hgemm_strided_batched. A failure
// status is logged; the return value is whether the status was OK.
bool ROCMBlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) {
  const Eigen::half alpha_half(alpha);
  const Eigen::half beta_half(beta);

  const port::Status result = DoBlasGemmBatchedInternal(
      wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k,
      alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count,
      scratch_allocator);
  if (result.ok()) {
    return true;
  }
  LOG(ERROR) << result;
  return false;
}
// Batched GEMM on float matrices. Forwards to DoBlasGemmBatchedInternal with
// rocblas_sgemm_strided_batched. A failure status is logged; the return value
// is whether the status was OK.
bool ROCMBlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
    const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
    const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) {
  const port::Status result = DoBlasGemmBatchedInternal(
      wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
      alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
      scratch_allocator);
  if (result.ok()) {
    return true;
  }
  LOG(ERROR) << result;
  return false;
}
// Batched GEMM on double matrices. Forwards to DoBlasGemmBatchedInternal with
// rocblas_dgemm_strided_batched. A failure status is logged; the return value
// is whether the status was OK.
bool ROCMBlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha,
    const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
    const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
    double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  const port::Status result = DoBlasGemmBatchedInternal(
      wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
      alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
      scratch_allocator);
  if (result.ok()) {
    return true;
  }
  LOG(ERROR) << result;
  return false;
}
// Batched GEMM for "complex<float>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
    int lda,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
    int ldb, std::complex<float> beta,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
                "for the \"complex<float>\" datatype";
  return false;
}
// Batched GEMM for "complex<double>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
    int lda,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
    int ldb, std::complex<double> beta,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
                "for the \"complex<double>\" datatype";
  return false;
}
// HEMM for "complex<float>" is not implemented by this backend: the arguments
// are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the HEMM operation for "
                "the \"complex<float>\" datatype";
  return false;
}
// HEMM for "complex<double>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the HEMM operation for "
                "the \"complex<double>\" datatype";
  return false;
}
// HERK for "complex<float>" is not implemented by this backend: the arguments
// are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          float alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          float beta, DeviceMemory<std::complex<float>> *c,
                          int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the HERK operation for "
                "the \"complex<float>\" datatype";
  return false;
}
// HERK for "complex<double>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          double alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          double beta, DeviceMemory<std::complex<double>> *c,
                          int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the HERK operation for "
                "the \"complex<double>\" datatype";
  return false;
}
// HER2K for "complex<float>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           float beta, DeviceMemory<std::complex<float>> *c,
                           int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the HER2K operation for "
                "the \"complex<float>\" datatype";
  return false;
}
// HER2K for "complex<double>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           double beta, DeviceMemory<std::complex<double>> *c,
                           int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the HER2K operation for "
                "the \"complex<double>\" datatype";
  return false;
}
// SYMM for "float" is not implemented by this backend: the arguments are
// ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &b, int ldb, float beta,
                          DeviceMemory<float> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation for "
                "the \"float\" datatype";
  return false;
}
// SYMM for "double" is not implemented by this backend: the arguments are
// ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &b, int ldb, double beta,
                          DeviceMemory<double> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation for "
                "the \"double\" datatype";
  return false;
}
// SYMM for "complex<float>" is not implemented by this backend: the arguments
// are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation for "
                "the \"complex<float>\" datatype";
  return false;
}
// SYMM for "complex<double>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation for "
                "the \"complex<double>\" datatype";
  return false;
}
// SYRK for "float" is not implemented by this backend: the arguments are
// ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          float beta, DeviceMemory<float> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation for "
                "the \"float\" datatype";
  return false;
}
// SYRK for "double" is not implemented by this backend: the arguments are
// ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          double beta, DeviceMemory<double> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation for "
                "the \"double\" datatype";
  return false;
}
// SYRK for "complex<float>" is not implemented by this backend: the arguments
// are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation for "
                "the \"complex<float>\" datatype";
  return false;
}
// SYRK for "complex<double>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation for "
                "the \"complex<double>\" datatype";
  return false;
}
// SYR2K for "float" is not implemented by this backend: the arguments are
// ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           float alpha, const DeviceMemory<float> &a, int lda,
                           const DeviceMemory<float> &b, int ldb, float beta,
                           DeviceMemory<float> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation for "
                "the \"float\" datatype";
  return false;
}
// SYR2K for "double" is not implemented by this backend: the arguments are
// ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           double alpha, const DeviceMemory<double> &a, int lda,
                           const DeviceMemory<double> &b, int ldb, double beta,
                           DeviceMemory<double> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation for "
                "the \"double\" datatype";
  return false;
}
// SYR2K for "complex<float>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation for "
                "the \"complex<float>\" datatype";
  return false;
}
// SYR2K for "complex<double>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *c, int ldc) {
  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation for "
                "the \"complex<double>\" datatype";
  return false;
}
// TRMM for "float" is not implemented by this backend: the arguments are
// ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *b, int ldb) {
  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation for "
                "the \"float\" datatype";
  return false;
}
// TRMM for "double" is not implemented by this backend: the arguments are
// ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *b, int ldb) {
  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation for "
                "the \"double\" datatype";
  return false;
}
// TRMM for "complex<float>" is not implemented by this backend: the arguments
// are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *b, int ldb) {
  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation for "
                "the \"complex<float>\" datatype";
  return false;
}
// TRMM for "complex<double>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *b, int ldb) {
  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation for "
                "the \"complex<double>\" datatype";
  return false;
}
// TRSM for "float": forwards to rocblas_strsm through DoBlasInternal and
// returns its result. `alpha` is passed by host address.
bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *b, int ldb) {
  // The wrapped routine is called with a mutable pointer to A, so constness
  // is stripped here (presumably A is not modified - confirm against the
  // rocBLAS API).
  float *a_ptr = const_cast<float *>(GpuMemory(a));
  float *b_ptr = GpuMemoryMutable(b);
  return DoBlasInternal(wrap::rocblas_strsm, stream,
                        true /* = pointer_mode_host */, ROCMBlasSide(side),
                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
                        ROCMBlasDiagonal(diag), m, n, &alpha, a_ptr, lda,
                        b_ptr, ldb);
}
// TRSM for "double": forwards to rocblas_dtrsm through DoBlasInternal and
// returns its result. `alpha` is passed by host address.
bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *b, int ldb) {
  // The wrapped routine is called with a mutable pointer to A, so constness
  // is stripped here (presumably A is not modified - confirm against the
  // rocBLAS API).
  double *a_ptr = const_cast<double *>(GpuMemory(a));
  double *b_ptr = GpuMemoryMutable(b);
  return DoBlasInternal(wrap::rocblas_dtrsm, stream,
                        true /* = pointer_mode_host */, ROCMBlasSide(side),
                        ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
                        ROCMBlasDiagonal(diag), m, n, &alpha, a_ptr, lda,
                        b_ptr, ldb);
}
// TRSM for "complex<float>" is not implemented by this backend: the arguments
// are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *b, int ldb) {
  LOG(ERROR) << "rocBLAS does not currently support the TRSM operation for "
                "the \"complex<float>\" datatype";
  return false;
}
// TRSM for "complex<double>" is not implemented by this backend: the
// arguments are ignored, an error is logged, and false is returned.
bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *b, int ldb) {
  LOG(ERROR) << "rocBLAS does not currently support the TRSM operation for "
                "the \"complex<double>\" datatype";
  return false;
}
// Strided-batched GEMM on Eigen::half matrices. The float scaling factors are
// converted to Eigen::half, all half pointers are reinterpreted as
// rocblas_half, and the call is forwarded to rocblas_hgemm_strided_batched
// through DoBlasInternal, whose result is returned.
bool ROCMBlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
    int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
    int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
    int64 stride_c, int batch_count) {
  const Eigen::half alpha_half(alpha);
  const Eigen::half beta_half(beta);

  const rocblas_half *alpha_ptr =
      reinterpret_cast<const rocblas_half *>(&alpha_half);
  const rocblas_half *beta_ptr =
      reinterpret_cast<const rocblas_half *>(&beta_half);
  const rocblas_half *a_ptr =
      reinterpret_cast<const rocblas_half *>(GpuMemory(a));
  const rocblas_half *b_ptr =
      reinterpret_cast<const rocblas_half *>(GpuMemory(b));
  rocblas_half *c_ptr = reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c));

  // NOTE(review): alpha/beta are host addresses, yet pointer_mode_host is
  // false here while DoBlasGemmBatchedInternal passes true for the same
  // routine - confirm which value DoBlasInternal expects.
  return DoBlasInternal(wrap::rocblas_hgemm_strided_batched, stream,
                        false, /* pointer_mode_host */
                        ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
                        n, k, alpha_ptr, a_ptr, lda, stride_a, b_ptr, ldb,
                        stride_b, beta_ptr, c_ptr, ldc, stride_c, batch_count);
}
// Strided-batched GEMM on float matrices: forwards to
// rocblas_sgemm_strided_batched through DoBlasInternal and returns its result.
bool ROCMBlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
    float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
    int batch_count) {
  const float *a_ptr = GpuMemory(a);
  const float *b_ptr = GpuMemory(b);
  float *c_ptr = GpuMemoryMutable(c);

  // NOTE(review): alpha/beta are host addresses, yet pointer_mode_host is
  // false here while DoBlasGemmBatchedInternal passes true for the same
  // routine - confirm which value DoBlasInternal expects.
  return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
                        false, /* pointer_mode_host */
                        ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
                        n, k, &alpha, a_ptr, lda, stride_a, b_ptr, ldb,
                        stride_b, &beta, c_ptr, ldc, stride_c, batch_count);
}
// Strided-batched GEMM on double matrices: forwards to
// rocblas_dgemm_strided_batched through DoBlasInternal and returns its result.
bool ROCMBlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
    double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
    int batch_count) {
  const double *a_ptr = GpuMemory(a);
  const double *b_ptr = GpuMemory(b);
  double *c_ptr = GpuMemoryMutable(c);

  // NOTE(review): alpha/beta are host addresses, yet pointer_mode_host is
  // false here while DoBlasGemmBatchedInternal passes true for the same
  // routine - confirm which value DoBlasInternal expects.
  return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
                        false, /* pointer_mode_host */
                        ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
                        n, k, &alpha, a_ptr, lda, stride_a, b_ptr, ldb,
                        stride_b, &beta, c_ptr, ldc, stride_c, batch_count);
}
// Strided-batched GEMM for "complex<float>" is not implemented by this
// backend: the arguments are ignored, an error is logged, and false is
// returned.
bool ROCMBlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
    const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    int64 stride_c, int batch_count) {
  LOG(ERROR) << "rocBLAS does not currently support the "
                "DoBlasGemmStridedBatched operation for the "
                "\"complex<float>\" datatype";
  return false;
}
// Strided-batched GEMM for "complex<double>" is not implemented by this
// backend: the arguments are ignored, an error is logged, and false is
// returned.
bool ROCMBlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
    const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    int64 stride_c, int batch_count) {
  LOG(ERROR) << "rocBLAS does not currently support the "
                "DoBlasGemmStridedBatched operation for the "
                "\"complex<double>\" datatype";
  return false;
}
// Version query is not implemented for this backend. `version` is left
// untouched and an Unimplemented status is returned; the status now carries a
// message so that callers logging it can tell what was missing.
port::Status ROCMBlas::GetVersion(string *version) {
  return port::UnimplementedError(
      "ROCMBlas::GetVersion is not implemented for rocBLAS");
}
} // namespace gpu
void initialize_rocblas() {
auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
if (!rocBlasAlreadyRegistered) {
port::Status status =
PluginRegistry::Instance()
->RegisterFactory<PluginRegistry::BlasFactory>(
rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
[](internal::StreamExecutorInterface *parent)
-> blas::BlasSupport * {
gpu::GpuExecutor *rocm_executor =
dynamic_cast<gpu::GpuExecutor *>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the "
"rocBLAS "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
if (!blas->Init()) {
// Note: Init() will log a more specific error.
delete blas;
return nullptr;
}
return blas;
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register rocBLAS factory: "
<< status.error_message();
}
PluginRegistry::Instance()->SetDefaultFactory(
rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
}
}
} // namespace stream_executor
// Declares a module initializer named register_rocblas whose body calls
// stream_executor::initialize_rocblas().
REGISTER_MODULE_INITIALIZER(register_rocblas,
{ stream_executor::initialize_rocblas(); });