Rename cu* identifiers to gpu* identifiers in cusolver_context.{h,cc}, unifying the CUDA and ROCm code paths behind gpusolver* aliases.
diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc
index 6c6a4a7..766e718 100644
--- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc
@@ -33,7 +33,7 @@
static tensorflow::mutex contexts_mu(tensorflow::LINKER_INITIALIZED);
static auto contexts =
- new absl::flat_hash_map<se::Stream*, CusolverContext> TF_GUARDED_BY(
+ new absl::flat_hash_map<se::Stream*, GpusolverContext> TF_GUARDED_BY(
contexts_mu);
CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
@@ -61,13 +61,13 @@
<< " workspace=" << workspace_buffer_.ToString()
<< " info=" << info_buffer_.ToString();
- CusolverContext* context;
+ GpusolverContext* context;
{
tensorflow::mutex_lock lock(contexts_mu);
- auto result = contexts->emplace(params.stream, CusolverContext());
+ auto result = contexts->emplace(params.stream, GpusolverContext());
if (result.second) {
TF_ASSIGN_OR_RETURN(result.first->second,
- CusolverContext::Create(params.stream));
+ GpusolverContext::Create(params.stream));
}
context = &result.first->second;
}
diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc
index 63f0800..9ee7a53 100644
--- a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc
+++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc
@@ -17,10 +17,6 @@
#include "tensorflow/compiler/xla/util.h"
-#if defined(TENSORFLOW_USE_ROCM)
-#include "rocm/include/hip/hip_complex.h"
-#endif
-
namespace xla {
namespace gpu {
@@ -28,36 +24,51 @@
// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
-struct CUDAComplexT {
+struct GpuComplexT {
typedef T type;
};
#if !defined(TENSORFLOW_USE_ROCM)
+
+using gpuStream_t = cudaStream_t;
+
+#define gpusolverCreate cusolverDnCreate
+#define gpusolverSetStream cusolverDnSetStream
+#define gpusolverDestroy cusolverDnDestroy
+
template <>
-struct CUDAComplexT<std::complex<float>> {
+struct GpuComplexT<std::complex<float>> {
typedef cuComplex type;
};
template <>
-struct CUDAComplexT<std::complex<double>> {
+struct GpuComplexT<std::complex<double>> {
typedef cuDoubleComplex type;
};
+
#else
+
+using gpuStream_t = hipStream_t;
+
+#define gpusolverCreate rocblas_create_handle
+#define gpusolverSetStream rocblas_set_stream
+#define gpusolverDestroy rocblas_destroy_handle
+
template <>
-struct CUDAComplexT<std::complex<float>> {
+struct GpuComplexT<std::complex<float>> {
typedef rocblas_float_complex type;
};
template <>
-struct CUDAComplexT<std::complex<double>> {
+struct GpuComplexT<std::complex<double>> {
typedef rocblas_double_complex type;
};
#endif
template <typename T>
-inline typename CUDAComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
- return static_cast<typename CUDAComplexT<T>::type*>(p.opaque());
+inline typename GpuComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
+ return static_cast<typename GpuComplexT<T>::type*>(p.opaque());
}
#if !defined(TENSORFLOW_USE_ROCM)
-cublasFillMode_t CUDABlasUpperLower(se::blas::UpperLower uplo) {
+cublasFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) {
switch (uplo) {
case se::blas::UpperLower::kUpper:
return CUBLAS_FILL_MODE_UPPER;
@@ -69,7 +80,7 @@
}
// Converts a cuSolver status to a Status.
-Status CusolverStatusToStatus(cusolverStatus_t status) {
+Status GpusolverStatusToStatus(cusolverStatus_t status) {
switch (status) {
case CUSOLVER_STATUS_SUCCESS:
return Status::OK();
@@ -100,7 +111,7 @@
}
}
#else
-rocblas_fill CUDABlasUpperLower(se::blas::UpperLower uplo) {
+rocblas_fill GpuBlasUpperLower(se::blas::UpperLower uplo) {
switch (uplo) {
case se::blas::UpperLower::kUpper:
return rocblas_fill_upper;
@@ -112,7 +123,7 @@
}
// Converts a cuSolver status to a Status.
-Status CusolverStatusToStatus(rocblas_status status) {
+Status GpusolverStatusToStatus(rocblas_status status) {
switch (status) {
case rocblas_status_success:
return Status::OK();
@@ -148,77 +159,47 @@
} // namespace
-#if !defined(TENSORFLOW_USE_ROCM)
-StatusOr<CusolverContext> CusolverContext::Create(se::Stream* stream) {
- cusolverDnHandle_t handle;
- TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCreate(&handle)));
- CusolverContext context(stream, handle);
+StatusOr<GpusolverContext> GpusolverContext::Create(se::Stream* stream) {
+ gpusolverHandle_t handle;
+ TF_RETURN_IF_ERROR(GpusolverStatusToStatus(gpusolverCreate(&handle)));
+ GpusolverContext context(stream, handle);
if (stream) {
// StreamExecutor really should just expose the Cuda stream to clients...
- const cudaStream_t* cuda_stream =
- CHECK_NOTNULL(reinterpret_cast<const cudaStream_t*>(
+ const gpuStream_t* gpu_stream =
+ CHECK_NOTNULL(reinterpret_cast<const gpuStream_t*>(
stream->implementation()->GpuStreamMemberHack()));
TF_RETURN_IF_ERROR(
- CusolverStatusToStatus(cusolverDnSetStream(handle, *cuda_stream)));
+ GpusolverStatusToStatus(gpusolverSetStream(handle, *gpu_stream)));
}
return std::move(context);
}
-#else
-StatusOr<CusolverContext> CusolverContext::Create(se::Stream* stream) {
- cusolverDnHandle_t handle;
- TF_RETURN_IF_ERROR(CusolverStatusToStatus(rocblas_create_handle(&handle)));
- CusolverContext context(stream, handle);
- if (stream) {
- // StreamExecutor really should just expose the Cuda stream to clients...
- const hipStream_t* hip_stream =
- CHECK_NOTNULL(reinterpret_cast<const hipStream_t*>(
- stream->implementation()->GpuStreamMemberHack()));
- TF_RETURN_IF_ERROR(
- CusolverStatusToStatus(rocblas_set_stream(handle, *hip_stream)));
- }
-
- return std::move(context);
-}
-#endif
-
-CusolverContext::CusolverContext(se::Stream* stream, cusolverDnHandle_t handle)
+GpusolverContext::GpusolverContext(se::Stream* stream, gpusolverHandle_t handle)
: stream_(stream), handle_(handle) {}
-CusolverContext::CusolverContext(CusolverContext&& other) {
+GpusolverContext::GpusolverContext(GpusolverContext&& other) {
handle_ = other.handle_;
stream_ = other.stream_;
other.handle_ = nullptr;
other.stream_ = nullptr;
}
-CusolverContext& CusolverContext::operator=(CusolverContext&& other) {
+GpusolverContext& GpusolverContext::operator=(GpusolverContext&& other) {
std::swap(handle_, other.handle_);
std::swap(stream_, other.stream_);
return *this;
}
-#if !defined(TENSORFLOW_USE_ROCM)
-CusolverContext::~CusolverContext() {
+GpusolverContext::~GpusolverContext() {
if (handle_) {
- Status status = CusolverStatusToStatus(cusolverDnDestroy(handle_));
+ Status status = GpusolverStatusToStatus(gpusolverDestroy(handle_));
if (!status.ok()) {
- LOG(ERROR) << "cusolverDnDestroy failed: " << status;
+ LOG(ERROR) << "gpusolverDestroy failed: " << status;
}
}
}
-#else
-CusolverContext::~CusolverContext() {
- if (handle_) {
- Status status = CusolverStatusToStatus(rocblas_destroy_handle(handle_));
- if (!status.ok()) {
- LOG(ERROR) << "cusolverDnDestroy failed: " << status;
- }
- }
-}
-#endif
#if !defined(TENSORFLOW_USE_ROCM)
#define CALL_LAPACK_TYPES(m) \
@@ -238,30 +219,30 @@
// Note: NVidia have promised that it is safe to pass 'nullptr' as the argument
// buffers to cuSolver buffer size methods and this will be a documented
// behavior in a future cuSolver release.
-StatusOr<int64> CusolverContext::PotrfBufferSize(PrimitiveType type,
- se::blas::UpperLower uplo,
- int n, int lda) {
+StatusOr<int64> GpusolverContext::PotrfBufferSize(PrimitiveType type,
+ se::blas::UpperLower uplo,
+ int n, int lda) {
#if !defined(TENSORFLOW_USE_ROCM)
int size = -1;
switch (type) {
case F32: {
- TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnSpotrf_bufferSize(
- handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
+ TF_RETURN_IF_ERROR(GpusolverStatusToStatus(cusolverDnSpotrf_bufferSize(
+ handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
break;
}
case F64: {
- TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnDpotrf_bufferSize(
- handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
+ TF_RETURN_IF_ERROR(GpusolverStatusToStatus(cusolverDnDpotrf_bufferSize(
+ handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
break;
}
case C64: {
- TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCpotrf_bufferSize(
- handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
+ TF_RETURN_IF_ERROR(GpusolverStatusToStatus(cusolverDnCpotrf_bufferSize(
+ handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
break;
}
case C128: {
- TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnZpotrf_bufferSize(
- handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
+ TF_RETURN_IF_ERROR(GpusolverStatusToStatus(cusolverDnZpotrf_bufferSize(
+ handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
break;
}
default:
@@ -277,22 +258,22 @@
#if !defined(TENSORFLOW_USE_ROCM)
#define POTRF_INSTANCE(T, type_prefix) \
template <> \
- Status CusolverContext::Potrf<T>( \
+ Status GpusolverContext::Potrf<T>( \
se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda, \
se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
- return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)( \
- handle(), CUDABlasUpperLower(uplo), n, ToDevicePointer(A), lda, \
+ return GpusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)( \
+ handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(A), lda, \
ToDevicePointer(workspace), workspace.ElementCount(), \
ToDevicePointer(lapack_info))); \
}
#else
#define POTRF_INSTANCE(T, type_prefix) \
template <> \
- Status CusolverContext::Potrf<T>( \
+ Status GpusolverContext::Potrf<T>( \
se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda, \
se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
- return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)( \
- handle(), CUDABlasUpperLower(uplo), n, ToDevicePointer(A), lda, \
+ return GpusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)( \
+ handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(A), lda, \
ToDevicePointer(lapack_info))); \
}
#endif
diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.h b/tensorflow/compiler/xla/service/gpu/cusolver_context.h
index 65b0451..4dbb7a2 100644
--- a/tensorflow/compiler/xla/service/gpu/cusolver_context.h
+++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.h
@@ -20,9 +20,10 @@
#if !TENSORFLOW_USE_ROCM
#include "third_party/gpus/cuda/include/cusolverDn.h"
+using gpusolverHandle_t = cusolverDnHandle_t;
#else
#include "tensorflow/stream_executor/rocm/rocsolver_wrapper.h"
-typedef rocsolver_handle cusolverDnHandle_t;
+using gpusolverHandle_t = rocblas_handle;
#endif
#include "tensorflow/compiler/xla/statusor.h"
@@ -35,18 +36,18 @@
namespace xla {
namespace gpu {
-class CusolverContext {
+class GpusolverContext {
public:
// stream may be nullptr, in which case the context can only be used for
// buffer size queries.
- static StatusOr<CusolverContext> Create(se::Stream* stream);
- CusolverContext() = default;
- ~CusolverContext();
+ static StatusOr<GpusolverContext> Create(se::Stream* stream);
+ GpusolverContext() = default;
+ ~GpusolverContext();
- CusolverContext(const CusolverContext&) = delete;
- CusolverContext(CusolverContext&&);
- CusolverContext& operator=(const CusolverContext&) = delete;
- CusolverContext& operator=(CusolverContext&&);
+ GpusolverContext(const GpusolverContext&) = delete;
+ GpusolverContext(GpusolverContext&&);
+ GpusolverContext& operator=(const GpusolverContext&) = delete;
+ GpusolverContext& operator=(GpusolverContext&&);
// Computes the Cholesky factorization A = L * L^T for a single matrix.
// Returns Status::OK() if the kernel was launched successfully. See:
@@ -66,19 +67,19 @@
int n, int lda);
private:
- CusolverContext(se::Stream* stream, cusolverDnHandle_t handle);
+ GpusolverContext(se::Stream* stream, gpusolverHandle_t handle);
- cusolverDnHandle_t handle() const { return handle_; }
+ gpusolverHandle_t handle() const { return handle_; }
se::Stream* stream_ = nullptr;
- cusolverDnHandle_t handle_ = nullptr;
+ gpusolverHandle_t handle_ = nullptr;
};
#define CALL_LAPACK_TYPES(m) \
m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define POTRF_INSTANCE(T, type_prefix) \
template <> \
- Status CusolverContext::Potrf<T>( \
+ Status GpusolverContext::Potrf<T>( \
se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda, \
se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace);
CALL_LAPACK_TYPES(POTRF_INSTANCE);
diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc
index 73a7bcc..ccc3392 100644
--- a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc
@@ -46,7 +46,7 @@
shape->mutable_layout()->mutable_minor_to_major()->at(1));
}
-StatusOr<HloInstruction*> CreateCholesky(CusolverContext* context,
+StatusOr<HloInstruction*> CreateCholesky(GpusolverContext* context,
HloInstruction* operand,
const CholeskyOptions& options,
const OpMetadata& metadata) {
@@ -129,7 +129,7 @@
} // namespace
// Tries to rewrite a single convolution into a call to cudnn.
-StatusOr<bool> RunOnInstruction(CusolverContext* context,
+StatusOr<bool> RunOnInstruction(GpusolverContext* context,
HloInstruction* instruction) {
if (instruction->opcode() != HloOpcode::kCholesky) {
return false;
@@ -162,8 +162,8 @@
return false;
}
- TF_ASSIGN_OR_RETURN(CusolverContext context,
- CusolverContext::Create(/*stream=*/nullptr));
+ TF_ASSIGN_OR_RETURN(GpusolverContext context,
+ GpusolverContext::Create(/*stream=*/nullptr));
bool changed = false;
for (HloInstruction* instruction : cusolver_calls) {