Rename cu* names to gpu* names in cusolver_context.{h,cc} and their call sites
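
This unifies the CUDA and ROCm code paths: the backend-specific entry
points (cusolverDn* vs. rocblas_*) are hidden behind gpusolver* macros
and type aliases in cusolver_context.cc, the handle type is exposed as
gpusolverHandle_t in cusolver_context.h, and the class itself becomes
GpusolverContext. Call sites only see the backend-neutral names; a
minimal usage sketch (the stream, n, and lda variables here are
hypothetical placeholders, not part of this patch):

    // Create a context bound to an se::Stream* (may be nullptr).
    TF_ASSIGN_OR_RETURN(GpusolverContext context,
                        GpusolverContext::Create(stream));
    // Query the workspace size for a Cholesky factorization.
    TF_ASSIGN_OR_RETURN(
        int64 workspace_size,
        context.PotrfBufferSize(F32, se::blas::UpperLower::kLower, n, lda));

Passing /*stream=*/nullptr to Create() is still supported for
buffer-size-only queries, as cusolver_rewriter.cc does.
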
diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc
index 6c6a4a7..766e718 100644
--- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc
@@ -33,7 +33,7 @@
 
 static tensorflow::mutex contexts_mu(tensorflow::LINKER_INITIALIZED);
 static auto contexts =
-    new absl::flat_hash_map<se::Stream*, CusolverContext> TF_GUARDED_BY(
+    new absl::flat_hash_map<se::Stream*, GpusolverContext> TF_GUARDED_BY(
         contexts_mu);
 
 CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
@@ -61,13 +61,13 @@
        << " workspace=" << workspace_buffer_.ToString()
        << " info=" << info_buffer_.ToString();
 
-  CusolverContext* context;
+  GpusolverContext* context;
   {
     tensorflow::mutex_lock lock(contexts_mu);
-    auto result = contexts->emplace(params.stream, CusolverContext());
+    auto result = contexts->emplace(params.stream, GpusolverContext());
     if (result.second) {
       TF_ASSIGN_OR_RETURN(result.first->second,
-                          CusolverContext::Create(params.stream));
+                          GpusolverContext::Create(params.stream));
     }
     context = &result.first->second;
   }
diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc
index 63f0800..9ee7a53 100644
--- a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc
+++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc
@@ -17,10 +17,6 @@
 
 #include "tensorflow/compiler/xla/util.h"
 
-#if defined(TENSORFLOW_USE_ROCM)
-#include "rocm/include/hip/hip_complex.h"
-#endif
-
 namespace xla {
 namespace gpu {
 
@@ -28,36 +24,51 @@
 
-// Type traits to get CUDA complex types from std::complex<T>.
+// Type traits to get GPU complex types from std::complex<T>.
 template <typename T>
-struct CUDAComplexT {
+struct GpuComplexT {
   typedef T type;
 };
 #if !defined(TENSORFLOW_USE_ROCM)
+
+using gpuStream_t = cudaStream_t;
+
+#define gpusolverCreate cusolverDnCreate
+#define gpusolverSetStream cusolverDnSetStream
+#define gpusolverDestroy cusolverDnDestroy
+
 template <>
-struct CUDAComplexT<std::complex<float>> {
+struct GpuComplexT<std::complex<float>> {
   typedef cuComplex type;
 };
 template <>
-struct CUDAComplexT<std::complex<double>> {
+struct GpuComplexT<std::complex<double>> {
   typedef cuDoubleComplex type;
 };
+
 #else
+
+using gpuStream_t = hipStream_t;
+
+#define gpusolverCreate rocblas_create_handle
+#define gpusolverSetStream rocblas_set_stream
+#define gpusolverDestroy rocblas_destroy_handle
+
 template <>
-struct CUDAComplexT<std::complex<float>> {
+struct GpuComplexT<std::complex<float>> {
   typedef rocblas_float_complex type;
 };
 template <>
-struct CUDAComplexT<std::complex<double>> {
+struct GpuComplexT<std::complex<double>> {
   typedef rocblas_double_complex type;
 };
 #endif
 
 template <typename T>
-inline typename CUDAComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
-  return static_cast<typename CUDAComplexT<T>::type*>(p.opaque());
+inline typename GpuComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
+  return static_cast<typename GpuComplexT<T>::type*>(p.opaque());
 }
 
 #if !defined(TENSORFLOW_USE_ROCM)
-cublasFillMode_t CUDABlasUpperLower(se::blas::UpperLower uplo) {
+cublasFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) {
   switch (uplo) {
     case se::blas::UpperLower::kUpper:
       return CUBLAS_FILL_MODE_UPPER;
@@ -69,7 +80,7 @@
 }
 
 // Converts a cuSolver status to a Status.
-Status CusolverStatusToStatus(cusolverStatus_t status) {
+Status GpusolverStatusToStatus(cusolverStatus_t status) {
   switch (status) {
     case CUSOLVER_STATUS_SUCCESS:
       return Status::OK();
@@ -100,7 +111,7 @@
   }
 }
 #else
-rocblas_fill CUDABlasUpperLower(se::blas::UpperLower uplo) {
+rocblas_fill GpuBlasUpperLower(se::blas::UpperLower uplo) {
   switch (uplo) {
     case se::blas::UpperLower::kUpper:
       return rocblas_fill_upper;
@@ -112,7 +123,7 @@
 }
 
-// Converts a cuSolver status to a Status.
+// Converts a rocBLAS status to a Status.
-Status CusolverStatusToStatus(rocblas_status status) {
+Status GpusolverStatusToStatus(rocblas_status status) {
   switch (status) {
     case rocblas_status_success:
       return Status::OK();
@@ -148,77 +159,47 @@
 
 }  // namespace
 
-#if !defined(TENSORFLOW_USE_ROCM)
-StatusOr<CusolverContext> CusolverContext::Create(se::Stream* stream) {
-  cusolverDnHandle_t handle;
-  TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCreate(&handle)));
-  CusolverContext context(stream, handle);
+StatusOr<GpusolverContext> GpusolverContext::Create(se::Stream* stream) {
+  gpusolverHandle_t handle;
+  TF_RETURN_IF_ERROR(GpusolverStatusToStatus(gpusolverCreate(&handle)));
+  GpusolverContext context(stream, handle);
 
   if (stream) {
-    // StreamExecutor really should just expose the Cuda stream to clients...
+    // StreamExecutor really should just expose the GPU stream to clients...
-    const cudaStream_t* cuda_stream =
-        CHECK_NOTNULL(reinterpret_cast<const cudaStream_t*>(
+    const gpuStream_t* gpu_stream =
+        CHECK_NOTNULL(reinterpret_cast<const gpuStream_t*>(
             stream->implementation()->GpuStreamMemberHack()));
     TF_RETURN_IF_ERROR(
-        CusolverStatusToStatus(cusolverDnSetStream(handle, *cuda_stream)));
+        GpusolverStatusToStatus(gpusolverSetStream(handle, *gpu_stream)));
   }
 
   return std::move(context);
 }
-#else
-StatusOr<CusolverContext> CusolverContext::Create(se::Stream* stream) {
-  cusolverDnHandle_t handle;
-  TF_RETURN_IF_ERROR(CusolverStatusToStatus(rocblas_create_handle(&handle)));
-  CusolverContext context(stream, handle);
 
-  if (stream) {
-    // StreamExecutor really should just expose the Cuda stream to clients...
-    const hipStream_t* hip_stream =
-        CHECK_NOTNULL(reinterpret_cast<const hipStream_t*>(
-            stream->implementation()->GpuStreamMemberHack()));
-    TF_RETURN_IF_ERROR(
-        CusolverStatusToStatus(rocblas_set_stream(handle, *hip_stream)));
-  }
-
-  return std::move(context);
-}
-#endif
-
-CusolverContext::CusolverContext(se::Stream* stream, cusolverDnHandle_t handle)
+GpusolverContext::GpusolverContext(se::Stream* stream, gpusolverHandle_t handle)
     : stream_(stream), handle_(handle) {}
 
-CusolverContext::CusolverContext(CusolverContext&& other) {
+GpusolverContext::GpusolverContext(GpusolverContext&& other) {
   handle_ = other.handle_;
   stream_ = other.stream_;
   other.handle_ = nullptr;
   other.stream_ = nullptr;
 }
 
-CusolverContext& CusolverContext::operator=(CusolverContext&& other) {
+GpusolverContext& GpusolverContext::operator=(GpusolverContext&& other) {
   std::swap(handle_, other.handle_);
   std::swap(stream_, other.stream_);
   return *this;
 }
 
-#if !defined(TENSORFLOW_USE_ROCM)
-CusolverContext::~CusolverContext() {
+GpusolverContext::~GpusolverContext() {
   if (handle_) {
-    Status status = CusolverStatusToStatus(cusolverDnDestroy(handle_));
+    Status status = GpusolverStatusToStatus(gpusolverDestroy(handle_));
     if (!status.ok()) {
-      LOG(ERROR) << "cusolverDnDestroy failed: " << status;
+      LOG(ERROR) << "gpusolverDestroy failed: " << status;
     }
   }
 }
-#else
-CusolverContext::~CusolverContext() {
-  if (handle_) {
-    Status status = CusolverStatusToStatus(rocblas_destroy_handle(handle_));
-    if (!status.ok()) {
-      LOG(ERROR) << "cusolverDnDestroy failed: " << status;
-    }
-  }
-}
-#endif
 
 #if !defined(TENSORFLOW_USE_ROCM)
 #define CALL_LAPACK_TYPES(m) \
@@ -238,30 +219,30 @@
 // Note: NVidia have promised that it is safe to pass 'nullptr' as the argument
 // buffers to cuSolver buffer size methods and this will be a documented
 // behavior in a future cuSolver release.
-StatusOr<int64> CusolverContext::PotrfBufferSize(PrimitiveType type,
-                                                 se::blas::UpperLower uplo,
-                                                 int n, int lda) {
+StatusOr<int64> GpusolverContext::PotrfBufferSize(PrimitiveType type,
+                                                  se::blas::UpperLower uplo,
+                                                  int n, int lda) {
 #if !defined(TENSORFLOW_USE_ROCM)
   int size = -1;
   switch (type) {
     case F32: {
-      TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnSpotrf_bufferSize(
-          handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
+      TF_RETURN_IF_ERROR(GpusolverStatusToStatus(cusolverDnSpotrf_bufferSize(
+          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
       break;
     }
     case F64: {
-      TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnDpotrf_bufferSize(
-          handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
+      TF_RETURN_IF_ERROR(GpusolverStatusToStatus(cusolverDnDpotrf_bufferSize(
+          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
       break;
     }
     case C64: {
-      TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCpotrf_bufferSize(
-          handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
+      TF_RETURN_IF_ERROR(GpusolverStatusToStatus(cusolverDnCpotrf_bufferSize(
+          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
       break;
     }
     case C128: {
-      TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnZpotrf_bufferSize(
-          handle(), CUDABlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
+      TF_RETURN_IF_ERROR(GpusolverStatusToStatus(cusolverDnZpotrf_bufferSize(
+          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
       break;
     }
     default:
@@ -277,22 +258,22 @@
 #if !defined(TENSORFLOW_USE_ROCM)
 #define POTRF_INSTANCE(T, type_prefix)                                    \
   template <>                                                             \
-  Status CusolverContext::Potrf<T>(                                       \
+  Status GpusolverContext::Potrf<T>(                                      \
       se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda,   \
       se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
-    return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)(       \
-        handle(), CUDABlasUpperLower(uplo), n, ToDevicePointer(A), lda,   \
+    return GpusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)(      \
+        handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(A), lda,    \
         ToDevicePointer(workspace), workspace.ElementCount(),             \
         ToDevicePointer(lapack_info)));                                   \
   }
 #else
 #define POTRF_INSTANCE(T, type_prefix)                                    \
   template <>                                                             \
-  Status CusolverContext::Potrf<T>(                                       \
+  Status GpusolverContext::Potrf<T>(                                      \
       se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda,   \
       se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
-    return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)(       \
-        handle(), CUDABlasUpperLower(uplo), n, ToDevicePointer(A), lda,   \
+    return GpusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)(      \
+        handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(A), lda,    \
         ToDevicePointer(lapack_info)));                                   \
   }
 #endif
diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.h b/tensorflow/compiler/xla/service/gpu/cusolver_context.h
index 65b0451..4dbb7a2 100644
--- a/tensorflow/compiler/xla/service/gpu/cusolver_context.h
+++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.h
@@ -20,9 +20,10 @@
 
 #if !TENSORFLOW_USE_ROCM
 #include "third_party/gpus/cuda/include/cusolverDn.h"
+using gpusolverHandle_t = cusolverDnHandle_t;
 #else
 #include "tensorflow/stream_executor/rocm/rocsolver_wrapper.h"
-typedef rocsolver_handle cusolverDnHandle_t;
+using gpusolverHandle_t = rocblas_handle;
 #endif
 
 #include "tensorflow/compiler/xla/statusor.h"
@@ -35,18 +36,18 @@
 namespace xla {
 namespace gpu {
 
-class CusolverContext {
+class GpusolverContext {
  public:
   // stream may be nullptr, in which case the context can only be used for
   // buffer size queries.
-  static StatusOr<CusolverContext> Create(se::Stream* stream);
-  CusolverContext() = default;
-  ~CusolverContext();
+  static StatusOr<GpusolverContext> Create(se::Stream* stream);
+  GpusolverContext() = default;
+  ~GpusolverContext();
 
-  CusolverContext(const CusolverContext&) = delete;
-  CusolverContext(CusolverContext&&);
-  CusolverContext& operator=(const CusolverContext&) = delete;
-  CusolverContext& operator=(CusolverContext&&);
+  GpusolverContext(const GpusolverContext&) = delete;
+  GpusolverContext(GpusolverContext&&);
+  GpusolverContext& operator=(const GpusolverContext&) = delete;
+  GpusolverContext& operator=(GpusolverContext&&);
 
   // Computes the Cholesky factorization A = L * L^T for a single matrix.
   // Returns Status::OK() if the kernel was launched successfully. See:
@@ -66,19 +67,19 @@
                                   int n, int lda);
 
  private:
-  CusolverContext(se::Stream* stream, cusolverDnHandle_t handle);
+  GpusolverContext(se::Stream* stream, gpusolverHandle_t handle);
 
-  cusolverDnHandle_t handle() const { return handle_; }
+  gpusolverHandle_t handle() const { return handle_; }
 
   se::Stream* stream_ = nullptr;
-  cusolverDnHandle_t handle_ = nullptr;
+  gpusolverHandle_t handle_ = nullptr;
 };
 
 #define CALL_LAPACK_TYPES(m) \
   m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
 #define POTRF_INSTANCE(T, type_prefix)                                  \
   template <>                                                           \
-  Status CusolverContext::Potrf<T>(                                     \
+  Status GpusolverContext::Potrf<T>(                                    \
       se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda, \
       se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace);
 CALL_LAPACK_TYPES(POTRF_INSTANCE);
diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc
index 73a7bcc..ccc3392 100644
--- a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc
@@ -46,7 +46,7 @@
             shape->mutable_layout()->mutable_minor_to_major()->at(1));
 }
 
-StatusOr<HloInstruction*> CreateCholesky(CusolverContext* context,
+StatusOr<HloInstruction*> CreateCholesky(GpusolverContext* context,
                                          HloInstruction* operand,
                                          const CholeskyOptions& options,
                                          const OpMetadata& metadata) {
@@ -129,7 +129,7 @@
 }  // namespace
 
-// Tries to rewrite a single convolution into a call to cudnn.
+// Tries to rewrite a single Cholesky operation into a call to cusolver.
-StatusOr<bool> RunOnInstruction(CusolverContext* context,
+StatusOr<bool> RunOnInstruction(GpusolverContext* context,
                                 HloInstruction* instruction) {
   if (instruction->opcode() != HloOpcode::kCholesky) {
     return false;
@@ -162,8 +162,8 @@
     return false;
   }
 
-  TF_ASSIGN_OR_RETURN(CusolverContext context,
-                      CusolverContext::Create(/*stream=*/nullptr));
+  TF_ASSIGN_OR_RETURN(GpusolverContext context,
+                      GpusolverContext::Create(/*stream=*/nullptr));
 
   bool changed = false;
   for (HloInstruction* instruction : cusolver_calls) {