Replace CreateBlasLtMatmulPlan args with struct
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 456b4be..ac5a45b 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -555,23 +555,23 @@
GetBlasComputationType(dtype, allow_tf32, &computation_type),
errors::Internal("Unsupported dtype for batched matmul"));
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
- stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
- /*ab_type=*/blas_dtype,
- /*cd_type=*/blas_dtype, computation_type,
- se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
- blas_transpose_b, blas_transpose_a, n, m, k, batch_size,
- /*lda=*/in_y.dim_size(2), b_stride,
- /*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);
+ stream->parent()->CreateBlasLtMatmulPlan(
+ {/*ab_type=*/blas_dtype,
+ /*c_type=*/blas_dtype, computation_type,
+ se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
+ blas_transpose_b, blas_transpose_a, n, m, k,
+ /*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n,
+ batch_size, b_stride, a_stride, c_stride});
OP_REQUIRES(
context, plan,
- errors::Internal(
- "CreateBlasLtMatmulPlanStridedBatched failed : a.shape=(",
- in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
- in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), ", ",
- in_y.dim_size(1), ", ", in_y.dim_size(2), "), m=", m, ", n=", n,
- ", k=", k, ", batch_size=", batch_size, ", adjoint_a=", adj_x,
- ", adjoint_b=", adj_x, ", dtype=", dtype,
- ", computation_type=", computation_type));
+ errors::Internal("CreateBlasLtMatmulPlan failed : a.shape=(",
+ in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
+ in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
+ ", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
+ "), m=", m, ", n=", n, ", k=", k,
+ ", batch_size=", batch_size, ", adjoint_a=", adj_x,
+ ", adjoint_b=", adj_y, ", dtype=", dtype,
+ ", computation_type=", computation_type));
std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>
algorithms;
OP_REQUIRES(
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index ae5b485..411f6f1 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -242,6 +242,27 @@
virtual size_t workspace_size() const = 0;
};
+// Parameters for the CreateBlasLtMatmulPlan method.
+struct BlasLtMatmulPlanParams {
+ DataType ab_type;
+ DataType c_type;
+ ComputationType computation_type;
+ PointerMode pointer_mode;
+ Epilogue epilogue;
+ Transpose transa;
+ Transpose transb;
+ uint64 m;
+ uint64 n;
+ uint64 k;
+ int64 lda;
+ int64 ldb;
+ int64 ldc;
+ int batch_count = 1;
+ int64 stride_a = 0;
+ int64 stride_b = 0;
+ int64 stride_c = 0;
+};
+
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has an BLAS library implementation available. See
// StreamExecutor::AsBlas().
@@ -1466,25 +1487,8 @@
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure.
- std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
- blas::DataType ab_type, blas::DataType c_type,
- blas::ComputationType computation_type, blas::PointerMode pointer_mode,
- blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
- uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
- return CreateBlasLtMatmulPlanStridedBatched(
- ab_type, c_type, computation_type, pointer_mode, epilogue, transa,
- transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0);
- }
-
- // A more general version of CreateBlasLtMatmulPlan supporting
- // batched operations.
- virtual std::unique_ptr<blas::IBlasLtMatmulPlan>
- CreateBlasLtMatmulPlanStridedBatched(
- blas::DataType ab_type, blas::DataType c_type,
- blas::ComputationType computation_type, blas::PointerMode pointer_mode,
- blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
- uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
- int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0;
+ virtual std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
+ const blas::BlasLtMatmulPlanParams& params) = 0;
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an
@@ -2372,14 +2376,8 @@
uint64 n, std::complex<double> alpha, \
const DeviceMemory<std::complex<double>> &a, int lda, \
DeviceMemory<std::complex<double>> *b, int ldb) override; \
- std::unique_ptr<blas::IBlasLtMatmulPlan> \
- CreateBlasLtMatmulPlanStridedBatched( \
- blas::DataType ab_type, blas::DataType cd_type, \
- blas::ComputationType computation_type, blas::PointerMode pointer_mode, \
- blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, \
- int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) \
- override; \
+ std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( \
+ const blas::BlasLtMatmulPlanParams& params) override; \
bool GetBlasLtMatmulAlgorithms( \
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
int max_algorithm_count, \
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 1d95b00..f2bc79e 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -3231,13 +3231,7 @@
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
public:
- CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
- blas::ComputationType compute_type,
- blas::PointerMode pointer_mode, blas::Epilogue epilogue,
- blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, int batch_count, int64 lda,
- int64 stride_a, int64 ldb, int64 stride_b, int64 ldc,
- int64 stride_c, int64 ldd, int64 stride_d);
+ CUDABlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params);
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
@@ -3280,39 +3274,34 @@
};
CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
- blas::DataType ab_type, blas::DataType cd_type,
- blas::ComputationType computation_type, blas::PointerMode pointer_mode,
- blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
- uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
- int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
- int64 stride_d)
+ const blas::BlasLtMatmulPlanParams& p)
: op_desc_(CreateCublasLtOperationDesc(
- computation_type, GetScaleType(cd_type, computation_type),
- pointer_mode, epilogue, transa, transb)),
+ p.computation_type, GetScaleType(p.c_type, p.computation_type),
+ p.pointer_mode, p.epilogue, p.transa, p.transb)),
a_desc_(nullptr),
b_desc_(nullptr),
- c_desc_(
- CreateCublasLtLayoutDesc(cd_type, m, n, ldc, stride_c, batch_count)),
- d_desc_(
- CreateCublasLtLayoutDesc(cd_type, m, n, ldd, stride_d, batch_count)),
- ab_type_(ab_type),
- cd_type_(cd_type),
- scale_type_(GetScaleType(cd_type, computation_type)),
- pointer_mode_(pointer_mode),
- epilogue_(epilogue),
- batch_count_(batch_count),
- stride_a_(stride_a),
- stride_b_(stride_b),
- stride_c_(stride_c),
- stride_d_(stride_d) {
- uint64 rows_a = transa == blas::Transpose::kNoTranspose ? m : k;
- uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m;
- uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n;
- uint64 cols_b = transb == blas::Transpose::kNoTranspose ? n : k;
- a_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_a, cols_a, lda, stride_a,
- batch_count);
- b_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_b, cols_b, ldb, stride_b,
- batch_count);
+ c_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+ p.batch_count)),
+ d_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+ p.batch_count)),
+ ab_type_(p.ab_type),
+ cd_type_(p.c_type),
+ scale_type_(GetScaleType(p.c_type, p.computation_type)),
+ pointer_mode_(p.pointer_mode),
+ epilogue_(p.epilogue),
+ batch_count_(p.batch_count),
+ stride_a_(p.stride_a),
+ stride_b_(p.stride_b),
+ stride_c_(p.stride_c),
+ stride_d_(p.stride_c) {
+ uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
+ uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
+ uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
+ uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
+ a_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
+ p.stride_a, p.batch_count);
+ b_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
+ p.stride_b, p.batch_count);
}
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
@@ -3395,18 +3384,10 @@
#endif // CUDA_VERSION >= 11000
-std::unique_ptr<blas::IBlasLtMatmulPlan>
-CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
- blas::DataType ab_type, blas::DataType cd_type,
- blas::ComputationType computation_type, blas::PointerMode pointer_mode,
- blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
- uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
- int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
+std::unique_ptr<blas::IBlasLtMatmulPlan> CUDABlas::CreateBlasLtMatmulPlan(
+ const blas::BlasLtMatmulPlanParams& params) {
#if CUDA_VERSION >= 11000
- auto result = std::make_unique<CUDABlasLtMatmulPlan>(
- ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
- transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c,
- ldc, stride_c);
+ auto result = std::make_unique<CUDABlasLtMatmulPlan>(params);
if (!result->ok()) {
result.reset();
}
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index d75c1bc..d40b6ad 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -337,34 +337,12 @@
}
std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
- blas::DataType ab_type, blas::DataType cd_type,
- blas::ComputationType computation_type, blas::PointerMode pointer_mode,
- blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
- uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
+ const blas::BlasLtMatmulPlanParams& params) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return nullptr;
}
- return blas_support->CreateBlasLtMatmulPlan(
- ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
- transb, m, n, k, lda, ldb, ldc);
-}
-
-std::unique_ptr<blas::IBlasLtMatmulPlan>
-StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
- blas::DataType ab_type, blas::DataType cd_type,
- blas::ComputationType computation_type, blas::PointerMode pointer_mode,
- blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
- uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a,
- int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
- blas::BlasSupport *blas_support = AsBlas();
- if (!blas_support) {
- return nullptr;
- }
- return blas_support->CreateBlasLtMatmulPlanStridedBatched(
- ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
- transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc,
- stride_c);
+ return blas_support->CreateBlasLtMatmulPlan(params);
}
bool StreamExecutor::GetBlasLtMatmulAlgorithms(
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index b40c0c2..ce801bf 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -399,19 +399,7 @@
// created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
- blas::DataType ab_type, blas::DataType cd_type,
- blas::ComputationType computation_type, blas::PointerMode pointer_mode,
- blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
- uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc);
-
- // A more general version of CreateBlasLtMatmulPlan supporting
- // batched operations.
- std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
- blas::DataType ab_type, blas::DataType cd_type,
- blas::ComputationType computation_type, blas::PointerMode pointer_mode,
- blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
- uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda,
- int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c);
+ const blas::BlasLtMatmulPlanParams& params);
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an