[SE] Implement GemmStridedBatchedWithAlgorithm

Uses the extended GEMM API (cublasGemmStridedBatchedEx), which supports
more data types (notably int8 x int8 -> int32)

PiperOrigin-RevId: 374783812
Change-Id: I39ffbdc332aed8e59e00b793a0c9bba6a710d4cc
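
For reference, a minimal usage sketch (RunInt8BatchedGemm and its arguments
are hypothetical; an algorithm would typically be picked from
GetBlasGemmAlgorithms, and profiling is elided). It multiplies batch_count
pairs of column-major int8 matrices, accumulating into int32:

  port::Status RunInt8BatchedGemm(Stream *stream,
                                  const DeviceMemory<int8> &a,
                                  const DeviceMemory<int8> &b,
                                  DeviceMemory<int32> *c, int m, int n, int k,
                                  int batch_count,
                                  blas::AlgorithmType algorithm) {
    // Alpha and beta take the output type (int32), so the computation type
    // below must be kI32 to pass the type check in stream.h.
    int32 alpha = 1;
    int32 beta = 0;
    return stream->ThenBlasGemmStridedBatchedWithAlgorithm(
        blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k,
        alpha, a, /*lda=*/m, /*stride_a=*/m * k, b, /*ldb=*/k,
        /*stride_b=*/k * n, beta, c, /*ldc=*/m, /*stride_c=*/m * n,
        batch_count, blas::ComputationType::kI32, algorithm,
        /*output_profile_result=*/nullptr);
  }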
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index aa9289f..f4d20d4 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -1113,6 +1113,15 @@
       ComputationType computation_type, AlgorithmType algorithm,
       ProfileResult *output_profile_result) = 0;
 
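+  // Batched variant of DoBlasGemmWithAlgorithm above: computes batch_count
+  // matrix products in one call, with consecutive matrices in a, b, and c
+  // separated by stride_a, stride_b, and stride_c elements respectively.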
+  virtual port::Status DoBlasGemmStridedBatchedWithAlgorithm(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
+      DataType type_a, int lda, int64 stride_a, const DeviceMemoryBase &b,
+      DataType type_b, int ldb, int64 stride_b, const void *beta,
+      DeviceMemoryBase *c, DataType type_c, int ldc, int64 stride_c,
+      int batch_count, ComputationType computation_type,
+      AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;
+
   // Computes a batch of matrix-matrix product with general matrices.
   // This is a batched version of DoBlasGemm.
   // The batched GEMM computes matrix product for each input/output in a, b,
@@ -2035,6 +2044,15 @@
       const DeviceMemoryBase &a, int lda, int64 stride_a,                      \
       const DeviceMemoryBase &b, int ldb, int64 stride_b, const void *beta,    \
       DeviceMemoryBase *c, int ldc, int64 stride_c, int batch_count);          \
+  port::Status DoBlasGemmStridedBatchedWithAlgorithm(                          \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
+      uint64 m, uint64 n, uint64 k, const void *alpha,                         \
+      const DeviceMemoryBase &a, blas::DataType type_a, int lda,               \
+      int64 stride_a, const DeviceMemoryBase &b, blas::DataType type_b,        \
+      int ldb, int64 stride_b, const void *beta, DeviceMemoryBase *c,          \
+      blas::DataType type_c, int ldc, int64 stride_c, int batch_count,         \
+      blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
+      blas::ProfileResult *output_profile_result) override;                    \
   bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                   uint64 m, uint64 n, std::complex<float> alpha,               \
                   const DeviceMemory<std::complex<float>> &a, int lda,         \
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index e4b2b63..f54f332 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2034,13 +2034,9 @@
 #endif
 }
 
-port::Status CUDABlas::DoBlasGemmWithAlgorithm(
-    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
-    uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
-    blas::DataType type_a, int lda, const DeviceMemoryBase &b,
-    blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
-    blas::DataType type_c, int ldc, blas::ComputationType computation_type,
-    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+static port::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
+    Stream *stream, blas::AlgorithmType algorithm, blas::DataType type_a,
+    blas::DataType type_b) {
   if (type_a != type_b) {
     return port::InternalError("Types of inputs mismatch");
   }
@@ -2091,15 +2087,6 @@
     }
   }
 
-  std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
-  if (output_profile_result != nullptr) {
-    timer.reset(new GpuTimer(parent_));
-    if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
-      return port::InternalError(
-          "output_profile_result given, but unable to create a GpuTimer");
-    }
-  }
-
   // Return false if we might be hitting a cuBLAS bug that produces the wrong
   // result. See nvbugs/2156201, b/79126339.
 #if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
@@ -2110,23 +2097,27 @@
         "<9.2 bug with m, n, or k >= 2097153.  See b/79126339.");
   }
 #endif
+  return math_type;
+}
 
-  cudaDataType_t cuda_in_type = GetCUDADataType(type_a);
-  // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
-  // we do the following compile-time check on the default value:
-  static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
-  // If 'alpha' and 'beta' are host scalars and CompT is Eigen::half, we
-  // essentially reinterpet_cast to __half, which is safe because Eigen::half
-  // inherits from __half.
-  port::Status st = DoBlasInternalImpl(
-      AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type,
-      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, alpha,
-      a.opaque(), cuda_in_type, lda, b.opaque(), cuda_in_type, ldb, beta,
-      c->opaque(), GetCUDADataType(type_c), ldc,
-      CUDAComputationType(computation_type),
-      static_cast<cublasGemmAlgo_t>(algorithm));
+static port::StatusOr<std::unique_ptr<GpuTimer, GpuTimerDeleter>>
+StartGpuTimerForProfile(Stream *stream, GpuExecutor *executor,
+                        blas::ProfileResult *output_profile_result) {
+  std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
+  if (output_profile_result) {
+    timer.reset(new GpuTimer(executor));
+    if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
+      return port::InternalError(
+          "output_profile_result given, but unable to create a GpuTimer");
+    }
+  }
+  return timer;
+}
 
-  if (timer != nullptr && st.ok()) {
+static port::Status PopulateProfileFromTimer(
+    GpuTimer *timer, blas::AlgorithmType algorithm,
+    blas::ProfileResult *output_profile_result, Stream *stream) {
+  if (timer) {
     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
     // state.
     if (!timer->Stop(AsGpuStream(stream))) {
@@ -2137,6 +2128,62 @@
     output_profile_result->set_elapsed_time_in_ms(
         timer->GetElapsedMilliseconds());
   }
+  return port::Status::OK();
+}
+
+port::Status CUDABlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
+    blas::DataType type_a, int lda, const DeviceMemoryBase &b,
+    blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
+    blas::DataType type_c, int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
+                      GetMathTypeForGemmEx(stream, algorithm, type_a, type_b));
+
+  TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
+                                      stream, parent_, output_profile_result));
+
+  // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
+  // we do the following compile-time check on the default value:
+  static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
+
+  TF_RETURN_IF_ERROR(DoBlasInternalImpl(
+      AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, alpha,
+      a.opaque(), GetCUDADataType(type_a), lda, b.opaque(),
+      GetCUDADataType(type_b), ldb, beta, c->opaque(), GetCUDADataType(type_c),
+      ldc, CUDAComputationType(computation_type),
+      static_cast<cublasGemmAlgo_t>(algorithm)));
+  TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
+                                              output_profile_result, stream));
+  return port::Status::OK();
+}
+
+port::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
+    blas::DataType type_a, int lda, int64 stride_a, const DeviceMemoryBase &b,
+    blas::DataType type_b, int ldb, int64 stride_b, const void *beta,
+    DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64 stride_c,
+    int batch_count, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
+                      GetMathTypeForGemmEx(stream, algorithm, type_a, type_b));
+
+  TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
+                                      stream, parent_, output_profile_result));
+
+  cudaDataType_t cuda_in_type = GetCUDADataType(type_a);
+  port::Status st = DoBlasInternalImpl(
+      AS_LAMBDA(cublasGemmStridedBatchedEx), stream, /*pointer_mode_host=*/true,
+      math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+      alpha, a.opaque(), cuda_in_type, lda, stride_a, b.opaque(), cuda_in_type,
+      ldb, stride_b, beta, c->opaque(), GetCUDADataType(type_c), ldc, stride_c,
+      batch_count, CUDAComputationType(computation_type),
+      static_cast<cublasGemmAlgo_t>(algorithm));
+  TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
+                                              output_profile_result, stream));
   return st;
 }
 
diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc
index 391ae14..74f56f1 100644
--- a/tensorflow/stream_executor/rocm/rocm_blas.cc
+++ b/tensorflow/stream_executor/rocm/rocm_blas.cc
@@ -1827,6 +1827,18 @@
   return port::InternalError("Not implemented on ROCm");
 }
 
+port::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
+    blas::DataType type_a, int lda, int64 stride_a, const DeviceMemoryBase &b,
+    blas::DataType type_b, int ldb, int64 stride_b, const void *beta,
+    DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64 stride_c,
+    int batch_count, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  // ROCM TODO: properly implement the interface
+  return port::InternalError("Not implemented on ROCm");
+}
+
 bool ROCMBlas::GetBlasGemmAlgorithms(
     std::vector<blas::AlgorithmType> *out_algorithms) {
   // ROCM TODO: properly implement the interface
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index d29e238..c95383d 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1470,35 +1470,10 @@
       DeviceMemory<OutputType> *c, int ldc,
       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
       blas::ProfileResult *output_profile_result) {
-    static_assert(std::is_same<InputType, Eigen::half>::value ||
-                      std::is_same<InputType, float>::value ||
-                      std::is_same<InputType, double>::value ||
-                      std::is_same<InputType, int8>::value ||
-                      std::is_same<InputType, std::complex<float>>::value ||
-                      std::is_same<InputType, std::complex<double>>::value,
-                  "The only buffer types supported are: Eigen::half, float, "
-                  "double, int8, std::complex<float> and std::complex<double>");
-    static_assert(
-        std::is_same<InputType, OutputType>::value ||
-            (std::is_same<InputType, int8>::value &&
-             std::is_same<OutputType, int32>::value),
-        "Input and output buffer types should be the same unless input is "
-        "int8 and output is int32");
-    static_assert(std::is_same<ConstantType, OutputType>::value ||
-                      (std::is_same<ConstantType, float>::value &&
-                       std::is_same<OutputType, Eigen::half>::value),
-                  "Constant and output types should match");
-    blas::ComputationType expected_computation_type =
-        blas::ToComputationType<ConstantType>::value;
-    if (expected_computation_type != computation_type &&
-        !(computation_type == blas::ComputationType::kF32 &&
-          expected_computation_type == blas::ComputationType::kF16)) {
-      return port::InternalError(absl::StrCat(
-          "Alpha/beta type and computation type have to match, got ",
-          blas::ComputationTypeString(computation_type),
-          " for computation type, expected: ",
-          blas::ComputationTypeString(expected_computation_type)));
-    }
+    TF_RETURN_IF_ERROR(
+        CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
+            computation_type));
+
     blas::BlasSupport *blas = parent()->AsBlas();
     if (!blas) {
       return port::InternalError(
@@ -1525,6 +1500,38 @@
     return st;
   }
 
+  template <typename InputType, typename OutputType, typename ConstantType>
+  port::Status ThenBlasGemmStridedBatchedWithAlgorithm(
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
+      int64 stride_a, const DeviceMemory<InputType> &b, int ldb, int64 stride_b,
+      ConstantType beta, DeviceMemory<OutputType> *c, int ldc, int64 stride_c,
+      int batch_count, blas::ComputationType computation_type,
+      blas::AlgorithmType algorithm,
+      blas::ProfileResult *output_profile_result) {
+    TF_RETURN_IF_ERROR(
+        CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
+            computation_type));
+
+    blas::BlasSupport *blas = parent()->AsBlas();
+    if (!blas) {
+      return port::InternalError(
+          "Attempting to perform BLAS operation using "
+          "StreamExecutor without BLAS support");
+    }
+    port::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm(
+        this, transa, transb, m, n, k, &alpha, a,
+        blas::ToDataType<InputType>::value, lda, stride_a, b,
+        blas::ToDataType<InputType>::value, ldb, stride_b, &beta, c,
+        blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count,
+        computation_type, algorithm, output_profile_result);
+    if (output_profile_result) {
+      // The error is recorded in the profile.
+      return port::Status::OK();
+    }
+    return st;
+  }
+
   // See BlasSupport::DoBlasGemmBatched.
   Stream &ThenBlasGemmBatched(
       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
@@ -2162,6 +2169,42 @@
   friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
   friend class ocl::CLBlas;    // for parent_.
 
+  // Checks that the input, output, and constant types are consistent before
+  // dispatching to an extended BLAS routine.
+  template <typename InputType, typename OutputType, typename ConstantType>
+  port::Status CheckTypesForExtendedBlas(
+      blas::ComputationType computation_type) {
+    static_assert(std::is_same<InputType, Eigen::half>::value ||
+                      std::is_same<InputType, float>::value ||
+                      std::is_same<InputType, double>::value ||
+                      std::is_same<InputType, int8>::value ||
+                      std::is_same<InputType, std::complex<float>>::value ||
+                      std::is_same<InputType, std::complex<double>>::value,
+                  "The only buffer types supported are: Eigen::half, float, "
+                  "double, int8, std::complex<float> and std::complex<double>");
+    static_assert(
+        std::is_same<InputType, OutputType>::value ||
+            (std::is_same<InputType, int8>::value &&
+             std::is_same<OutputType, int32>::value),
+        "Input and output buffer types should be the same unless input is "
+        "int8 and output is int32");
+    static_assert(std::is_same<ConstantType, OutputType>::value ||
+                      (std::is_same<ConstantType, float>::value &&
+                       std::is_same<OutputType, Eigen::half>::value),
+                  "Constant and output types should match");
+    blas::ComputationType expected_computation_type =
+        blas::ToComputationType<ConstantType>::value;
+    if (expected_computation_type != computation_type &&
+        !(computation_type == blas::ComputationType::kF32 &&
+          expected_computation_type == blas::ComputationType::kF16)) {
+      return port::InternalError(absl::StrCat(
+          "Alpha/beta type and computation type have to match, got ",
+          blas::ComputationTypeString(computation_type),
+          " for computation type, expected: ",
+          blas::ComputationTypeString(expected_computation_type)));
+    }
+    return port::Status::OK();
+  }
+
   bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
     absl::ReaderMutexLock lock(&mu_);
     return !status_.ok();