[SE] Implement GemmStridedBatchedWithAlgorithm
Uses the extended GEMM API, which supports more data types (notably int8 x int8 -> int32).
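
For reference, a minimal sketch of how a caller might exercise the new entry point for an int8 x int8 -> int32 strided-batched GEMM. The buffer names, dimensions, and the algorithm value are illustrative assumptions, not part of this change; a real caller would pick an algorithm returned by GetBlasGemmAlgorithms and profile each candidate.

  // Hypothetical call site (not part of this change). Assumes it lives inside
  // namespace stream_executor, that `stream` is a valid Stream*, that `a`/`b`
  // are DeviceMemory<int8> and `c` is DeviceMemory<int32> allocated elsewhere,
  // and that `algorithm` came from GetBlasGemmAlgorithms().
  blas::ProfileResult profile;
  port::Status st = stream->ThenBlasGemmStridedBatchedWithAlgorithm(
      blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose,
      /*m=*/64, /*n=*/64, /*k=*/64, /*alpha=*/static_cast<int32>(1),
      a, /*lda=*/64, /*stride_a=*/64 * 64,
      b, /*ldb=*/64, /*stride_b=*/64 * 64,
      /*beta=*/static_cast<int32>(0), &c, /*ldc=*/64, /*stride_c=*/64 * 64,
      /*batch_count=*/16, blas::ComputationType::kI32, algorithm, &profile);
  // On success, `profile` records whether the algorithm ran and its elapsed
  // time; `st` carries any non-profiling error.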
PiperOrigin-RevId: 374783812
Change-Id: I39ffbdc332aed8e59e00b793a0c9bba6a710d4cc
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index aa9289f..f4d20d4 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -1113,6 +1113,15 @@
ComputationType computation_type, AlgorithmType algorithm,
ProfileResult *output_profile_result) = 0;
+ virtual port::Status DoBlasGemmStridedBatchedWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
+ DataType type_a, int lda, int64 stride_a, const DeviceMemoryBase &b,
+ DataType type_b, int ldb, int64 stride_b, const void *beta,
+ DeviceMemoryBase *c, DataType type_c, int ldc, int64 stride_c,
+ int batch_count, ComputationType computation_type,
+ AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;
+
// Computes a batch of matrix-matrix product with general matrices.
// This is a batched version of DoBlasGemm.
// The batched GEMM computes matrix product for each input/output in a, b,
@@ -2035,6 +2044,15 @@
const DeviceMemoryBase &a, int lda, int64 stride_a, \
const DeviceMemoryBase &b, int ldb, int64 stride_b, const void *beta, \
DeviceMemoryBase *c, int ldc, int64 stride_c, int batch_count); \
+ port::Status DoBlasGemmStridedBatchedWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, const void *alpha, \
+ const DeviceMemoryBase &a, blas::DataType type_a, int lda, \
+ int64 stride_a, const DeviceMemoryBase &b, blas::DataType type_b, \
+ int ldb, int64 stride_b, const void *beta, DeviceMemoryBase *c, \
+ blas::DataType type_c, int ldc, int64 stride_c, int batch_count, \
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
uint64 m, uint64 n, std::complex<float> alpha, \
const DeviceMemory<std::complex<float>> &a, int lda, \
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index e4b2b63..f54f332 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2034,13 +2034,9 @@
#endif
}
-port::Status CUDABlas::DoBlasGemmWithAlgorithm(
- Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
- blas::DataType type_a, int lda, const DeviceMemoryBase &b,
- blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
- blas::DataType type_c, int ldc, blas::ComputationType computation_type,
- blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+static port::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
+ Stream *stream, blas::AlgorithmType algorithm, blas::DataType type_a,
+ blas::DataType type_b) {
if (type_a != type_b) {
return port::InternalError("Types of inputs mismatch");
}
@@ -2091,15 +2087,6 @@
}
}
- std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
- if (output_profile_result != nullptr) {
- timer.reset(new GpuTimer(parent_));
- if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
- return port::InternalError(
- "output_profile_result given, but unable to create a GpuTimer");
- }
- }
-
// Return false if we might be hitting a cuBLAS bug that produces the wrong
// result. See nvbugs/2156201, b/79126339.
#if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
@@ -2110,23 +2097,27 @@
"<9.2 bug with m, n, or k >= 2097153. See b/79126339.");
}
#endif
+ return math_type;
+}
- cudaDataType_t cuda_in_type = GetCUDADataType(type_a);
- // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
- // we do the following compile-time check on the default value:
- static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
- // If 'alpha' and 'beta' are host scalars and CompT is Eigen::half, we
- // essentially reinterpet_cast to __half, which is safe because Eigen::half
- // inherits from __half.
- port::Status st = DoBlasInternalImpl(
- AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type,
- CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, alpha,
- a.opaque(), cuda_in_type, lda, b.opaque(), cuda_in_type, ldb, beta,
- c->opaque(), GetCUDADataType(type_c), ldc,
- CUDAComputationType(computation_type),
- static_cast<cublasGemmAlgo_t>(algorithm));
+static port::StatusOr<std::unique_ptr<GpuTimer, GpuTimerDeleter>>
+StartGpuTimerForProfile(Stream *stream, GpuExecutor *executor,
+ blas::ProfileResult *output_profile_result) {
+ std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
+ if (output_profile_result) {
+ timer.reset(new GpuTimer(executor));
+ if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
+ return port::InternalError(
+ "output_profile_result given, but unable to create a GpuTimer");
+ }
+ }
+ return timer;
+}
- if (timer != nullptr && st.ok()) {
+static port::Status PopulateProfileFromTimer(
+ GpuTimer *timer, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result, Stream *stream) {
+ if (timer) {
// GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
// state.
if (!timer->Stop(AsGpuStream(stream))) {
@@ -2137,6 +2128,62 @@
output_profile_result->set_elapsed_time_in_ms(
timer->GetElapsedMilliseconds());
}
+ return port::Status::OK();
+}
+
+port::Status CUDABlas::DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
+ blas::DataType type_a, int lda, const DeviceMemoryBase &b,
+ blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
+ blas::DataType type_c, int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
+ GetMathTypeForGemmEx(stream, algorithm, type_a, type_b));
+
+ TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
+ stream, parent_, output_profile_result));
+
+ // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
+ // we do the following compile-time check on the default value:
+ static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
+
+ TF_RETURN_IF_ERROR(DoBlasInternalImpl(
+ AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, alpha,
+ a.opaque(), GetCUDADataType(type_a), lda, b.opaque(),
+ GetCUDADataType(type_b), ldb, beta, c->opaque(), GetCUDADataType(type_c),
+ ldc, CUDAComputationType(computation_type),
+ static_cast<cublasGemmAlgo_t>(algorithm)));
+ TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
+ output_profile_result, stream));
+ return port::Status::OK();
+}
+
+port::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
+ blas::DataType type_a, int lda, int64 stride_a, const DeviceMemoryBase &b,
+ blas::DataType type_b, int ldb, int64 stride_b, const void *beta,
+ DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64 stride_c,
+ int batch_count, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
+ GetMathTypeForGemmEx(stream, algorithm, type_a, type_b));
+
+ TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
+ stream, parent_, output_profile_result));
+
+ cudaDataType_t cuda_in_type = GetCUDADataType(type_a);
+ port::Status st = DoBlasInternalImpl(
+ AS_LAMBDA(cublasGemmStridedBatchedEx), stream, /*pointer_mode_host=*/true,
+ math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ alpha, a.opaque(), cuda_in_type, lda, stride_a, b.opaque(), cuda_in_type,
+ ldb, stride_b, beta, c->opaque(), GetCUDADataType(type_c), ldc, stride_c,
+ batch_count, CUDAComputationType(computation_type),
+ static_cast<cublasGemmAlgo_t>(algorithm));
+ TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
+ output_profile_result, stream));
return st;
}
diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc
index 391ae14..74f56f1 100644
--- a/tensorflow/stream_executor/rocm/rocm_blas.cc
+++ b/tensorflow/stream_executor/rocm/rocm_blas.cc
@@ -1827,6 +1827,18 @@
return port::InternalError("Not implemented on ROCm");
}
+port::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
+ blas::DataType type_a, int lda, int64 stride_a, const DeviceMemoryBase &b,
+ blas::DataType type_b, int ldb, int64 stride_b, const void *beta,
+ DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64 stride_c,
+ int batch_count, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ // ROCM TODO: properly implement the interface
+ return port::InternalError("Not implemented on ROCm");
+}
+
bool ROCMBlas::GetBlasGemmAlgorithms(
std::vector<blas::AlgorithmType> *out_algorithms) {
// ROCM TODO: properly implement the interface
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index d29e238..c95383d 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1470,35 +1470,10 @@
DeviceMemory<OutputType> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result) {
- static_assert(std::is_same<InputType, Eigen::half>::value ||
- std::is_same<InputType, float>::value ||
- std::is_same<InputType, double>::value ||
- std::is_same<InputType, int8>::value ||
- std::is_same<InputType, std::complex<float>>::value ||
- std::is_same<InputType, std::complex<double>>::value,
- "The only buffer types supported are: Eigen::half, float, "
- "double, int8, std::complex<float> and std::complex<double>");
- static_assert(
- std::is_same<InputType, OutputType>::value ||
- (std::is_same<InputType, int8>::value &&
- std::is_same<OutputType, int32>::value),
- "Input and output buffer types should be the same unless input is "
- "int8 and output is int32");
- static_assert(std::is_same<ConstantType, OutputType>::value ||
- (std::is_same<ConstantType, float>::value &&
- std::is_same<OutputType, Eigen::half>::value),
- "Constant and output types should match");
- blas::ComputationType expected_computation_type =
- blas::ToComputationType<ConstantType>::value;
- if (expected_computation_type != computation_type &&
- !(computation_type == blas::ComputationType::kF32 &&
- expected_computation_type == blas::ComputationType::kF16)) {
- return port::InternalError(absl::StrCat(
- "Alpha/beta type and computation type have to match, got ",
- blas::ComputationTypeString(computation_type),
- " for computation type, expected: ",
- blas::ComputationTypeString(expected_computation_type)));
- }
+ TF_RETURN_IF_ERROR(
+ CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
+ computation_type));
+
blas::BlasSupport *blas = parent()->AsBlas();
if (!blas) {
return port::InternalError(
@@ -1525,6 +1500,38 @@
return st;
}
+ template <typename InputType, typename OutputType, typename ConstantType>
+ port::Status ThenBlasGemmStridedBatchedWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
+ int64 stride_a, const DeviceMemory<InputType> &b, int ldb, int64 stride_b,
+ ConstantType beta, DeviceMemory<OutputType> *c, int ldc, int64 stride_c,
+ int batch_count, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
+ TF_RETURN_IF_ERROR(
+ CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
+ computation_type));
+
+ blas::BlasSupport *blas = parent()->AsBlas();
+ if (!blas) {
+ return port::InternalError(
+ "Attempting to perform BLAS operation using "
+ "StreamExecutor without BLAS support");
+ }
+ port::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm(
+ this, transa, transb, m, n, k, &alpha, a,
+ blas::ToDataType<InputType>::value, lda, stride_a, b,
+ blas::ToDataType<InputType>::value, ldb, stride_b, &beta, c,
+ blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count,
+ computation_type, algorithm, output_profile_result);
+ if (output_profile_result) {
+ // The error is recorded in the profile.
+ return port::Status::OK();
+ }
+ return st;
+ }
+
// See BlasSupport::DoBlasGemmBatched.
Stream &ThenBlasGemmBatched(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
@@ -2162,6 +2169,42 @@
friend struct ThenBlasImpl; // for implementing ThenBlasXXX.
friend class ocl::CLBlas; // for parent_.
+ // Checks whether the types are consistent before a call to an extended BLAS API.
+ template <typename InputType, typename OutputType, typename ConstantType>
+ port::Status CheckTypesForExtendedBlas(
+ blas::ComputationType computation_type) {
+ static_assert(std::is_same<InputType, Eigen::half>::value ||
+ std::is_same<InputType, float>::value ||
+ std::is_same<InputType, double>::value ||
+ std::is_same<InputType, int8>::value ||
+ std::is_same<InputType, std::complex<float>>::value ||
+ std::is_same<InputType, std::complex<double>>::value,
+ "The only buffer types supported are: Eigen::half, float, "
+ "double, int8, std::complex<float> and std::complex<double>");
+ static_assert(
+ std::is_same<InputType, OutputType>::value ||
+ (std::is_same<InputType, int8>::value &&
+ std::is_same<OutputType, int32>::value),
+ "Input and output buffer types should be the same unless input is "
+ "int8 and output is int32");
+ static_assert(std::is_same<ConstantType, OutputType>::value ||
+ (std::is_same<ConstantType, float>::value &&
+ std::is_same<OutputType, Eigen::half>::value),
+ "Constant and output types should match");
+ blas::ComputationType expected_computation_type =
+ blas::ToComputationType<ConstantType>::value;
+ if (expected_computation_type != computation_type &&
+ !(computation_type == blas::ComputationType::kF32 &&
+ expected_computation_type == blas::ComputationType::kF16)) {
+ return port::InternalError(absl::StrCat(
+ "Alpha/beta type and computation type have to match, got ",
+ blas::ComputationTypeString(computation_type),
+ " for computation type, expected: ",
+ blas::ComputationTypeString(expected_computation_type)));
+ }
+ return port::Status::OK();
+ }
+
bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
absl::ReaderMutexLock lock(&mu_);
return !status_.ok();