blob: 0497b917c955b4aae79422c6a9e073683c4ed2bc [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// ROCm-specific support for BLAS functionality -- this wraps the rocBLAS
// library capabilities and is only included into ROCm implementation code, so
// it will not introduce ROCm headers into other code.
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
#include <memory>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"
namespace stream_executor {
class Stream;
namespace gpu {
// Maps an element type used by StreamExecutor callers to the element type that
// rocBLAS expects. The primary template is the identity mapping; the only
// conversion currently provided is the Eigen::half -> rocblas_half
// specialization that follows.
template <class T>
struct RocBlasTypeConversionHelper {
  // By default the type is handed to rocBLAS unchanged.
  using mapped_type = T;
};
// Half-precision specialization: Eigen::half is mapped onto rocBLAS's own
// half-precision element type.
template <>
struct RocBlasTypeConversionHelper<Eigen::half> {
  using mapped_type = rocblas_half;  // rocBLAS-side counterpart of Eigen::half
};
// Opaque and unique identifier for the rocBLAS plugin.
extern const PluginId kRocBlasPlugin;
class GpuExecutor;
// BLAS plugin for the ROCm platform, implemented via the rocBLAS library.
//
// This satisfies the platform-agnostic BlasSupport interface.
//
// Note that the rocBLAS handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent GpuExecutor is tied
// to. This simply happens as an artifact of creating the rocBLAS handle when a
// ROCm context is active.
//
// Thread-safe post-initialization: the rocBLAS handle is guarded by mu_.
class ROCMBlas : public blas::BlasSupport {
 public:
  explicit ROCMBlas(GpuExecutor *parent);

  // Allocates a rocBLAS handle.
  bool Init();

  // Releases the rocBLAS handle, if present.
  ~ROCMBlas() override;

  TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES

 private:
  // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream.
  //
  // A rocBLAS handle is stateful and can only be associated with one stream
  // (in order to enqueue dispatch) at a given time. As a result, this
  // generally must be invoked before calling into rocBLAS.
  bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // A helper function that calls the real rocBLAS function together with error
  // handling.
  //
  // rocblas_func: rocBLAS function pointer.
  // stream: Stream to enqueue the BLAS operation onto.
  // pointer_mode_host: Indicates whether the pointer to a scalar value refers
  //   to host memory (true) or device memory (false).
  // err_on_failure: Whether to print an error if the rocBLAS function fails.
  // args: Arguments forwarded to the rocBLAS function.
  template <typename FuncT, typename... Args>
  bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
                          bool pointer_mode_host, bool err_on_failure,
                          Args... args);

  // Convenience wrapper around DoBlasInternalImpl that reports failures
  // (err_on_failure=true).
  template <typename FuncT, typename... Args>
  bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
                      bool pointer_mode_host, Args... args) {
    return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/true, args...);
  }

  // Convenience wrapper around DoBlasInternalImpl for calls whose failure is
  // acceptable (err_on_failure=false).
  template <typename FuncT, typename... Args>
  bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream,
                               bool pointer_mode_host, Args... args) {
    return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/false, args...);
  }

  // A helper allocation function that converts a batch of matrices referenced
  // by raw pointers (raw_ptrs, batch_count entries) into the strided-batched
  // memory layout, with batch_stride between consecutive matrices.
  // scratch_allocator and stream are used for the allocation/copy;
  // *temp_memory and *device_memory receive the resulting buffer.
  // NOTE(review): whether batch_stride is counted in elements or bytes, and
  // who owns the buffer when scratch_allocator is used, is defined by the
  // implementation -- confirm there.
  template <typename T>
  port::Status AllocateStridedBuffer(
      const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
          &raw_ptrs,
      int batch_count, uint64_t batch_stride,
      ScratchAllocator *scratch_allocator, Stream *stream,
      std::unique_ptr<TemporaryDeviceMemory<
          typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
      DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
          *device_memory);

  // A helper function to implement DoBlasGemmBatched interfaces for generic
  // types.
  //
  // Note: This function is implemented using the gemm_strided_batched
  // interface, NOT the gemm_batched interface, because rocBLAS did not support
  // it when this was written. As a result, if the passed-in batch matrices are
  // not allocated in strided-batched format, it might end up doing a
  // non-trivial amount of memory allocation and copying. To avoid this, always
  // prefer the DoBlasGemmStridedBatched interface.
  //
  // In most use cases, batch matrices do get allocated in a strided manner,
  // making this interface equivalent to DoBlasGemmStridedBatched. The only
  // use case seen so far that violates this observation is when the batch
  // matrices are created by broadcasting from a smaller matrix. When that
  // happens, this takes advantage of the AllocateStridedBuffer subroutine to
  // reallocate the memory layout to be strided batched.
  template <typename T, typename FuncT>
  port::Status DoBlasGemmBatchedInternal(
      FuncT rocblas_func, Stream *stream, blas::Transpose transa,
      blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
      const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
      const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
      T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator);

  // Helper function for implementing DoBlasGemmWithAlgorithm.
  //
  // We take alpha and beta by const reference because CompT might be
  // Eigen::half, and we want to avoid pulling in a dependency on Eigen here.
  // NOTE(review): an earlier version of this comment (carried over from the
  // CUDA backend) said the references are reinterpret_cast to __half; on ROCm
  // the corresponding type is presumably rocblas_half (see
  // RocBlasTypeConversionHelper) -- confirm against the implementation.
  template <typename InT, typename OutT, typename CompT>
  bool DoBlasGemmWithAlgorithmImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
      int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
      DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemmWithProfiling.
  template <typename T, typename ParamType>
  bool DoBlasGemmWithProfilingImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
      int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
      DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemvWithProfiling.
  template <typename T>
  bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
                                   uint64 m, uint64 n, const T &alpha,
                                   const DeviceMemory<T> &a, int lda,
                                   const DeviceMemory<T> &x, int incx,
                                   const T &beta, DeviceMemory<T> *y, int incy,
                                   blas::ProfileResult *output_profile_result);

  // Mutex that guards the rocBLAS handle for this device.
  absl::Mutex mu_;

  // GpuExecutor which instantiated this ROCMBlas.
  // Immutable post-initialization.
  GpuExecutor *parent_;

  // rocBLAS library handle on the device. Defaults to null so that the
  // destructor's "if present" check never reads an indeterminate value, even
  // if Init() was never called or failed.
  rocblas_handle blas_ GUARDED_BY(mu_) = nullptr;

  SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas);
};
} // namespace gpu
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_