Replace port::ArraySlice<T> with absl::Span<const T> and port::MutableArraySlice<T> with absl::Span<T> in tensorflow/stream_executor and tensorflow/compiler/xla/stream_executor
PiperOrigin-RevId: 468600178
diff --git a/tensorflow/compiler/xla/stream_executor/BUILD b/tensorflow/compiler/xla/stream_executor/BUILD
index c34138d..bd53f40 100644
--- a/tensorflow/compiler/xla/stream_executor/BUILD
+++ b/tensorflow/compiler/xla/stream_executor/BUILD
@@ -128,6 +128,7 @@
deps = [
"//tensorflow/compiler/xla/stream_executor/lib",
"//tensorflow/compiler/xla/stream_executor/platform",
+ "@com_google_absl//absl/types:span",
],
)
@@ -263,6 +264,7 @@
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -334,6 +336,7 @@
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -430,6 +433,7 @@
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -461,6 +465,7 @@
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:span",
],
)
@@ -476,6 +481,7 @@
"//tensorflow/compiler/xla/stream_executor/lib",
"//tensorflow/compiler/xla/stream_executor/platform",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -540,6 +546,7 @@
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/stream_executor/blas.h b/tensorflow/compiler/xla/stream_executor/blas.h
index 6aaf0fe..2ea0790 100644
--- a/tensorflow/compiler/xla/stream_executor/blas.h
+++ b/tensorflow/compiler/xla/stream_executor/blas.h
@@ -43,10 +43,10 @@
#include <complex>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/stream_executor/data_type.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.pb.h"
-#include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
@@ -65,9 +65,6 @@
template <typename ElemT>
class HostOrDeviceScalar;
-template <typename T>
-using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>; // non-absl ok
-
namespace blas {
// Specifies whether the input matrix will be transposed or
@@ -1084,43 +1081,39 @@
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb,
uint64_t m, uint64_t n, uint64 k, float alpha,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, // non-absl ok
- int lda,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, // non-absl ok
- int ldb, float beta,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, // non-absl ok
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
+ const absl::Span<DeviceMemory<Eigen::half> *const> a, int lda,
+ const absl::Span<DeviceMemory<Eigen::half> *const> b, int ldb, float beta,
+ const absl::Span<DeviceMemory<Eigen::half> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb,
uint64_t m, uint64_t n, uint64 k, float alpha,
- const port::ArraySlice<DeviceMemory<float> *> &a, int lda, // non-absl ok
- const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, // non-absl ok
- float beta,
- const port::ArraySlice<DeviceMemory<float> *> &c, // non-absl ok
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
+ const absl::Span<DeviceMemory<float> *const> a, int lda,
+ const absl::Span<DeviceMemory<float> *const> b, int ldb, float beta,
+ const absl::Span<DeviceMemory<float> *const> c, int ldc, int batch_count,
+ ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb,
uint64_t m, uint64_t n, uint64 k, double alpha,
- const port::ArraySlice<DeviceMemory<double> *> &a, // non-absl ok
- int lda,
- const port::ArraySlice<DeviceMemory<double> *> &b, // non-absl ok
- int ldb, double beta,
- const port::ArraySlice<DeviceMemory<double> *> &c, // non-absl ok
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
+ const absl::Span<DeviceMemory<double> *const> a, int lda,
+ const absl::Span<DeviceMemory<double> *const> b, int ldb, double beta,
+ const absl::Span<DeviceMemory<double> *const> c, int ldc, int batch_count,
+ ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb,
uint64_t m, uint64_t n, uint64 k, std::complex<float> alpha,
- const DeviceMemorySlice<std::complex<float>> &a, int lda,
- const DeviceMemorySlice<std::complex<float>> &b, int ldb,
- std::complex<float> beta, const DeviceMemorySlice<std::complex<float>> &c,
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
+ const absl::Span<DeviceMemory<std::complex<float>> *const> a, int lda,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> b, int ldb,
+ std::complex<float> beta,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb,
uint64_t m, uint64_t n, uint64 k, std::complex<double> alpha,
- const DeviceMemorySlice<std::complex<double>> &a, int lda,
- const DeviceMemorySlice<std::complex<double>> &b, int ldb,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> a, int lda,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> b, int ldb,
std::complex<double> beta,
- const DeviceMemorySlice<std::complex<double>> &c, int ldc,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator) = 0;
// Batched gemm with strides instead of pointer arrays.
@@ -1934,39 +1927,39 @@
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64_t m, uint64 n, uint64 k, float alpha, \
- const DeviceMemorySlice<Eigen::half> &a, int lda, \
- const DeviceMemorySlice<Eigen::half> &b, int ldb, float beta, \
- const DeviceMemorySlice<Eigen::half> &c, int ldc, int batch_count, \
- ScratchAllocator *scratch_allocator) override; \
+ const absl::Span<DeviceMemory<Eigen::half> *const> a, int lda, \
+ const absl::Span<DeviceMemory<Eigen::half> *const> b, int ldb, \
+ float beta, const absl::Span<DeviceMemory<Eigen::half> *const> c, \
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64_t m, uint64 n, uint64 k, float alpha, \
- const DeviceMemorySlice<float> &a, int lda, \
- const DeviceMemorySlice<float> &b, int ldb, float beta, \
- const DeviceMemorySlice<float> &c, int ldc, int batch_count, \
- ScratchAllocator *scratch_allocator) override; \
+ const absl::Span<DeviceMemory<float> *const> a, int lda, \
+ const absl::Span<DeviceMemory<float> *const> b, int ldb, float beta, \
+ const absl::Span<DeviceMemory<float> *const> c, int ldc, \
+ int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64_t m, uint64 n, uint64 k, double alpha, \
- const DeviceMemorySlice<double> &a, int lda, \
- const DeviceMemorySlice<double> &b, int ldb, double beta, \
- const DeviceMemorySlice<double> &c, int ldc, int batch_count, \
- ScratchAllocator *scratch_allocator) override; \
+ const absl::Span<DeviceMemory<double> *const> a, int lda, \
+ const absl::Span<DeviceMemory<double> *const> b, int ldb, double beta, \
+ const absl::Span<DeviceMemory<double> *const> c, int ldc, \
+ int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64_t m, uint64 n, uint64 k, std::complex<float> alpha, \
- const DeviceMemorySlice<std::complex<float>> &a, int lda, \
- const DeviceMemorySlice<std::complex<float>> &b, int ldb, \
+ const absl::Span<DeviceMemory<std::complex<float>> *const> a, int lda, \
+ const absl::Span<DeviceMemory<std::complex<float>> *const> b, int ldb, \
std::complex<float> beta, \
- const DeviceMemorySlice<std::complex<float>> &c, int ldc, \
+ const absl::Span<DeviceMemory<std::complex<float>> *const> c, int ldc, \
int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64_t m, uint64 n, uint64 k, std::complex<double> alpha, \
- const DeviceMemorySlice<std::complex<double>> &a, int lda, \
- const DeviceMemorySlice<std::complex<double>> &b, int ldb, \
+ const absl::Span<DeviceMemory<std::complex<double>> *const> a, int lda, \
+ const absl::Span<DeviceMemory<std::complex<double>> *const> b, int ldb, \
std::complex<double> beta, \
- const DeviceMemorySlice<std::complex<double>> &c, int ldc, \
+ const absl::Span<DeviceMemory<std::complex<double>> *const> c, int ldc, \
int batch_count, ScratchAllocator *scratch_allocator) override; \
port::Status DoBlasGemmStridedBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
diff --git a/tensorflow/compiler/xla/stream_executor/cuda/BUILD b/tensorflow/compiler/xla/stream_executor/cuda/BUILD
index f1f9fff..a76e40e 100644
--- a/tensorflow/compiler/xla/stream_executor/cuda/BUILD
+++ b/tensorflow/compiler/xla/stream_executor/cuda/BUILD
@@ -314,6 +314,7 @@
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:span",
"//third_party/eigen3",
"@local_config_cuda//cuda:cuda_headers",
"//tensorflow/compiler/xla:status_macros",
@@ -444,6 +445,7 @@
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@cudnn_frontend_archive//:cudnn_frontend",
"//third_party/eigen3",
"@local_config_cuda//cuda:cuda_headers",
diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc
index 4a2e6fb..a04477b 100644
--- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc
@@ -2172,10 +2172,10 @@
port::Status CUDABlas::DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha,
- const DeviceMemorySlice<T> &a_ptrs_to_wrappers, int lda,
- const DeviceMemorySlice<T> &b_ptrs_to_wrappers, int ldb, Scalar beta,
- const DeviceMemorySlice<T> &c_ptrs_to_wrappers, int ldc, int batch_count,
- ScratchAllocator *scratch_allocator) {
+ const absl::Span<DeviceMemory<T> *const> a_ptrs_to_wrappers, int lda,
+ const absl::Span<DeviceMemory<T> *const> b_ptrs_to_wrappers, int ldb,
+ Scalar beta, const absl::Span<DeviceMemory<T> *const> c_ptrs_to_wrappers,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
for (int i = 0; i < batch_count; ++i) {
a_raw_ptrs.push_back(static_cast<T *>(a_ptrs_to_wrappers[i]->opaque()));
@@ -2309,10 +2309,10 @@
bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, float alpha,
- const DeviceMemorySlice<Eigen::half> &a_array, int lda,
- const DeviceMemorySlice<Eigen::half> &b_array, int ldb, float beta,
- const DeviceMemorySlice<Eigen::half> &c_array, int ldc, int batch_count,
- ScratchAllocator *scratch_allocator) {
+ const absl::Span<DeviceMemory<Eigen::half> *const> a_array, int lda,
+ const absl::Span<DeviceMemory<Eigen::half> *const> b_array, int ldb,
+ float beta, const absl::Span<DeviceMemory<Eigen::half> *const> c_array,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
// Note: The func passed here (cublasSgemmBatched) is not actually called,
// due to special handling of fp16 inside DoBlasGemmBatchedInternal.
port::Status status = DoBlasGemmBatchedInternal(
@@ -2326,10 +2326,11 @@
bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
- uint64_t n, uint64 k, float alpha, const DeviceMemorySlice<float> &a_array,
- int lda, const DeviceMemorySlice<float> &b_array, int ldb, float beta,
- const DeviceMemorySlice<float> &c_array, int ldc, int batch_count,
- ScratchAllocator *scratch_allocator) {
+ uint64_t n, uint64 k, float alpha,
+ const absl::Span<DeviceMemory<float> *const> a_array, int lda,
+ const absl::Span<DeviceMemory<float> *const> b_array, int ldb, float beta,
+ const absl::Span<DeviceMemory<float> *const> c_array, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
port::Status status = DoBlasGemmBatchedInternal(
cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
@@ -2342,10 +2343,10 @@
bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, double alpha,
- const DeviceMemorySlice<double> &a_array, int lda,
- const DeviceMemorySlice<double> &b_array, int ldb, double beta,
- const DeviceMemorySlice<double> &c_array, int ldc, int batch_count,
- ScratchAllocator *scratch_allocator) {
+ const absl::Span<DeviceMemory<double> *const> a_array, int lda,
+ const absl::Span<DeviceMemory<double> *const> b_array, int ldb, double beta,
+ const absl::Span<DeviceMemory<double> *const> c_array, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
port::Status status = DoBlasGemmBatchedInternal(
cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
@@ -2358,10 +2359,10 @@
bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, std::complex<float> alpha,
- const DeviceMemorySlice<std::complex<float>> &a_array, int lda,
- const DeviceMemorySlice<std::complex<float>> &b_array, int ldb,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> a_array, int lda,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> b_array, int ldb,
std::complex<float> beta,
- const DeviceMemorySlice<std::complex<float>> &c_array, int ldc,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> c_array, int ldc,
int batch_count, ScratchAllocator *scratch_allocator) {
port::Status status = DoBlasGemmBatchedInternal(
cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
@@ -2375,11 +2376,12 @@
bool CUDABlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, std::complex<double> alpha,
- const DeviceMemorySlice<std::complex<double>> &a_array, int lda,
- const DeviceMemorySlice<std::complex<double>> &b_array, int ldb,
- std::complex<double> beta,
- const DeviceMemorySlice<std::complex<double>> &c_array, int ldc,
- int batch_count, ScratchAllocator *scratch_allocator) {
+ const absl::Span<DeviceMemory<std::complex<double>> *const> a_array,
+ int lda,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> b_array,
+ int ldb, std::complex<double> beta,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> c_array,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
port::Status status = DoBlasGemmBatchedInternal(
cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h b/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h
index 7e97383..0314f54 100644
--- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h
@@ -41,8 +41,6 @@
// Opaque and unique identifier for the cuBLAS plugin.
extern const PluginId kCuBlasPlugin;
-template <typename T>
-using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>; // non-absl ok
// BLAS plugin for CUDA platform via cuBLAS library.
//
@@ -109,10 +107,10 @@
port::Status DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha,
- const DeviceMemorySlice<T> &a_array, int lda,
- const DeviceMemorySlice<T> &b_array, int ldb, Scalar beta,
- const DeviceMemorySlice<T> &c_array, int ldc, int batch_count,
- ScratchAllocator *scratch_allocator);
+ const absl::Span<DeviceMemory<T> *const> a_array, int lda,
+ const absl::Span<DeviceMemory<T> *const> b_array, int ldb, Scalar beta,
+ const absl::Span<DeviceMemory<T> *const> c_array, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator);
// Helper function for implementing DoBlasGemmWithProfiling.
template <typename T, typename ParamType>
diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc
index 01382c7..f7cbd54 100644
--- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc
@@ -6428,10 +6428,10 @@
return IsStatusOk(status, /*report_error=*/true);
}
-bool CudnnSupport::DoDepthConcatenate(Stream* stream,
- BatchDescriptorSlice input_dimensions,
- DeviceMemorySlice<float> input_data,
- DeviceMemory<float>* output_data) {
+bool CudnnSupport::DoDepthConcatenate(
+ Stream* stream, absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
+ DeviceMemory<float>* output_data) {
CHECK_EQ(input_dimensions.size(), input_data.size());
for (const auto& dimensions : input_dimensions) {
@@ -6486,11 +6486,10 @@
return true;
}
-bool CudnnSupport::DoElementwiseOperate(Stream*, dnn::ElementwiseOperation,
- BatchDescriptorSlice,
- DeviceMemorySlice<float>,
- const dnn::BatchDescriptor&,
- DeviceMemory<float>*) {
+bool CudnnSupport::DoElementwiseOperate(
+ Stream*, dnn::ElementwiseOperation, absl::Span<const dnn::BatchDescriptor>,
+ absl::Span<const DeviceMemory<float>* const>, const dnn::BatchDescriptor&,
+ DeviceMemory<float>*) {
LOG(FATAL) << "not yet implemented"; // TODO(leary)
return false;
}
diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h
index 4d6d20e..312539c 100644
--- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h
@@ -20,6 +20,7 @@
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
#include "absl/base/thread_annotations.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
@@ -38,13 +39,6 @@
// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;
-using BatchDescriptorSlice =
- port::ArraySlice<dnn::BatchDescriptor>; // non-absl ok
-
-template <typename T>
-using DeviceMemorySlice =
- port::ArraySlice<const DeviceMemory<T>*>; // non-absl ok
-
// cudnn-library based DNN support. For details on overridden interface
// functions, see dnn.h.
class CudnnSupport : public dnn::DnnSupport {
@@ -468,15 +462,17 @@
DeviceMemory<float>* raw_variable_gradient,
ScratchAllocator* workspace_allocator) override;
- bool DoDepthConcatenate(Stream* stream, BatchDescriptorSlice input_dimensions,
- DeviceMemorySlice<float> input_data,
- DeviceMemory<float>* output_data) override;
+ bool DoDepthConcatenate(
+ Stream* stream, absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
+ DeviceMemory<float>* output_data) override;
- bool DoElementwiseOperate(Stream* stream, dnn::ElementwiseOperation operation,
- BatchDescriptorSlice input_dimensions,
- DeviceMemorySlice<float> input_data,
- const dnn::BatchDescriptor& output_dimensions,
- DeviceMemory<float>* output_data) override;
+ bool DoElementwiseOperate(
+ Stream* stream, dnn::ElementwiseOperation operation,
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) override;
bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
const DeviceMemory<float>& input_data, int64_t left_pad,
diff --git a/tensorflow/compiler/xla/stream_executor/dnn.cc b/tensorflow/compiler/xla/stream_executor/dnn.cc
index fa1a3bc..f0ed6f9 100644
--- a/tensorflow/compiler/xla/stream_executor/dnn.cc
+++ b/tensorflow/compiler/xla/stream_executor/dnn.cc
@@ -595,7 +595,7 @@
}
BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor(
- port::ArraySlice<dnn::BatchDescriptor> inputs) { // non-absl ok
+ absl::Span<const dnn::BatchDescriptor> inputs) {
if (inputs.empty()) {
return BatchDescriptor();
}
diff --git a/tensorflow/compiler/xla/stream_executor/dnn.h b/tensorflow/compiler/xla/stream_executor/dnn.h
index c383ab3..66e2179 100644
--- a/tensorflow/compiler/xla/stream_executor/dnn.h
+++ b/tensorflow/compiler/xla/stream_executor/dnn.h
@@ -37,7 +37,6 @@
#include "tensorflow/compiler/xla/stream_executor/device_description.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.pb.h"
-#include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
@@ -353,7 +352,7 @@
// dimensions, except possibly for feature_map_count(), though this
// function does not verify that.
static BatchDescriptor DepthConcatenateOutputDescriptor(
- port::ArraySlice<dnn::BatchDescriptor> inputs); // non-absl ok
+ absl::Span<const dnn::BatchDescriptor> inputs);
private:
absl::Span<const int64_t> spatial_size() const {
@@ -1707,9 +1706,8 @@
// output_data: un-owned device memory region in which to place the
// depth concatenate result.
virtual bool DoDepthConcatenate(
- Stream* stream,
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float>*> input_data, // non-absl ok
+ Stream* stream, absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
DeviceMemory<float>* output_data) = 0;
// Concatenates several layers into one, by concatenating each in the
@@ -1734,9 +1732,8 @@
// concat_direction: either dnn:SpaceConcatenateMode::XDirection or
// dnn::SpaceConcatenateMode::YDirection.
virtual bool DoSpaceConcatenate(
- Stream* stream,
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float>*> input_data, // non-absl ok
+ Stream* stream, absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
DeviceMemory<float>* output_data,
dnn::SpaceConcatenateMode concat_direction) {
return false;
@@ -1854,8 +1851,8 @@
// operation result.
virtual bool DoElementwiseOperate(
Stream* stream, ElementwiseOperation operation,
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float>*> input_data, // non-absl ok
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) = 0;
@@ -1882,10 +1879,9 @@
// operation result.
virtual bool DoElementwiseOperateScaledQuantized(
Stream* stream, ElementwiseOperation operation,
- port::ArraySlice<int> input_multiplicands, // non-absl ok
- int output_divisor,
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float>*> input_data, // non-absl ok
+ absl::Span<const int> input_multiplicands, int output_divisor,
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) {
return false;
diff --git a/tensorflow/compiler/xla/stream_executor/kernel.h b/tensorflow/compiler/xla/stream_executor/kernel.h
index e7f5119..8eab4a5 100644
--- a/tensorflow/compiler/xla/stream_executor/kernel.h
+++ b/tensorflow/compiler/xla/stream_executor/kernel.h
@@ -76,9 +76,9 @@
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/kernel_cache_config.h"
-#include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/core/platform/logging.h"
@@ -361,8 +361,7 @@
virtual uint64_t number_of_shared_bytes() const = 0;
// Gets the list of argument addresses.
- virtual port::ArraySlice<const void *> argument_addresses() // non-absl ok
- const = 0;
+ virtual absl::Span<const void *const> argument_addresses() const = 0;
// Gets an iterator to the arguments in the array.
virtual KernelArgIterator arg_iterator() const = 0;
@@ -448,10 +447,9 @@
}
// Gets the list of argument addresses.
- port::ArraySlice<const void *> argument_addresses() // non-absl ok
- const override {
- return port::ArraySlice<const void *>( // non-absl ok
- argument_addresses_.data(), number_of_argument_addresses_);
+ absl::Span<const void *const> argument_addresses() const override {
+ return absl::Span<const void *const>(argument_addresses_.data(),
+ number_of_argument_addresses_);
}
// Gets an iterator to the arguments in the array.
diff --git a/tensorflow/compiler/xla/stream_executor/module_spec.h b/tensorflow/compiler/xla/stream_executor/module_spec.h
index 0775603..44dcb45 100644
--- a/tensorflow/compiler/xla/stream_executor/module_spec.h
+++ b/tensorflow/compiler/xla/stream_executor/module_spec.h
@@ -16,7 +16,7 @@
#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_MODULE_SPEC_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_MODULE_SPEC_H_
-#include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
@@ -30,7 +30,7 @@
class MultiModuleLoaderSpec {
public:
bool has_cuda_cubin_in_memory() const { return has_cuda_cubin_in_memory_; }
- port::ArraySlice<const uint8> cuda_cubin_in_memory() const { // non-absl ok
+ absl::Span<const uint8> cuda_cubin_in_memory() const {
CHECK(has_cuda_cubin_in_memory());
return {cuda_cubin_in_memory_.data(), cuda_cubin_in_memory_.size()};
}
@@ -41,8 +41,7 @@
return cuda_ptx_in_memory_;
}
- void AddCudaCubinInMemory(
- port::ArraySlice<const uint8> cubin_bytes) { // non-absl ok
+ void AddCudaCubinInMemory(absl::Span<const uint8> cubin_bytes) {
CHECK(!cubin_bytes.empty());
has_cuda_cubin_in_memory_ = true;
cuda_cubin_in_memory_ = cubin_bytes;
@@ -55,7 +54,7 @@
}
private:
- port::ArraySlice<const uint8> cuda_cubin_in_memory_; // non-absl ok
+ absl::Span<const uint8> cuda_cubin_in_memory_;
bool has_cuda_cubin_in_memory_ = false;
const char* cuda_ptx_in_memory_;
bool has_cuda_ptx_in_memory_ = false;
diff --git a/tensorflow/compiler/xla/stream_executor/stream.cc b/tensorflow/compiler/xla/stream_executor/stream.cc
index 43933d8..138a31c 100644
--- a/tensorflow/compiler/xla/stream_executor/stream.cc
+++ b/tensorflow/compiler/xla/stream_executor/stream.cc
@@ -135,7 +135,7 @@
std::string ToVlogString(double d) { return absl::StrCat(d); }
template <class T>
-std::string ToVlogString(port::ArraySlice<T> elements) { // non-absl ok
+std::string ToVlogString(absl::Span<const T> elements) {
std::string str = absl::StrCat(
ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
elements.size(), "]{");
@@ -161,8 +161,8 @@
}
template <class T>
-std::string ToVlogString(port::MutableArraySlice<T> elements) { // non-absl ok
- return ToVlogString(port::ArraySlice<T>(elements)); // non-absl ok
+std::string ToVlogString(absl::Span<T> elements) {
+ return ToVlogString(absl::Span<const T>(elements));
}
std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
@@ -670,8 +670,8 @@
}
Stream &Stream::ThenDepthConcatenate(
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float> *const> input_data,
DeviceMemory<float> *output_data) {
VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
@@ -698,8 +698,8 @@
}
Stream &Stream::ThenSpaceConcatenate(
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float> *const> input_data,
DeviceMemory<float> *output_data,
dnn::SpaceConcatenateMode concat_direction) {
VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
@@ -798,8 +798,8 @@
Stream &Stream::ThenElementwiseOperate(
dnn::ElementwiseOperation operation,
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float> *const> input_data,
const dnn::BatchDescriptor &output_dimensions,
DeviceMemory<float> *output_data) {
VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
@@ -817,10 +817,9 @@
Stream &Stream::ThenElementwiseOperateScaledQuantized(
dnn::ElementwiseOperation operation,
- port::ArraySlice<int> input_multiplicands, // non-absl ok
- int output_divisor,
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
+ absl::Span<const int> input_multiplicands, int output_divisor,
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float> *const> input_data,
const dnn::BatchDescriptor &output_dimensions,
DeviceMemory<float> *output_data) {
VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
@@ -3504,12 +3503,10 @@
Stream &Stream::ThenBlasGemmBatched(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, float alpha,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, // non-absl ok
- int lda,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, // non-absl ok
- int ldb, float beta,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, // non-absl ok
- int ldc, int batch_count) {
+ const absl::Span<DeviceMemory<Eigen::half> *const> a, int lda,
+ const absl::Span<DeviceMemory<Eigen::half> *const> b, int ldb, float beta,
+ const absl::Span<DeviceMemory<Eigen::half> *const> c, int ldc,
+ int batch_count) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr);
@@ -3518,24 +3515,19 @@
Stream &Stream::ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, float alpha,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, // non-absl ok
- int lda,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, // non-absl ok
- int ldb, float beta,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, // non-absl ok
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+ const absl::Span<DeviceMemory<Eigen::half> *const> a, int lda,
+ const absl::Span<DeviceMemory<Eigen::half> *const> b, int ldb, float beta,
+ const absl::Span<DeviceMemory<Eigen::half> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
- ThenBlasImpl<
- blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64, float,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &, // non-absl ok
- int,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &, // non-absl ok
- int, float,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &, // non-absl ok
- int, int, ScratchAllocator *>
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
+ float, const absl::Span<DeviceMemory<Eigen::half> *const>, int,
+ const absl::Span<DeviceMemory<Eigen::half> *const>, int, float,
+ const absl::Span<DeviceMemory<Eigen::half> *const>, int, int,
+ ScratchAllocator *>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
@@ -3544,12 +3536,10 @@
Stream &Stream::ThenBlasGemmBatched(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
- uint64_t k, float alpha,
- const port::ArraySlice<DeviceMemory<float> *> &a, // non-absl ok
- int lda, const port::ArraySlice<DeviceMemory<float> *> &b, // non-absl ok
- int ldb, float beta,
- const port::ArraySlice<DeviceMemory<float> *> &c, // non-absl ok
- int ldc, int batch_count) {
+ uint64_t k, float alpha, const absl::Span<DeviceMemory<float> *const> a,
+ int lda, const absl::Span<DeviceMemory<float> *const> b, int ldb,
+ float beta, const absl::Span<DeviceMemory<float> *const> c, int ldc,
+ int batch_count) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr);
@@ -3557,22 +3547,19 @@
Stream &Stream::ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
- uint64_t k, float alpha,
- const port::ArraySlice<DeviceMemory<float> *> &a, // non-absl ok
- int lda, const port::ArraySlice<DeviceMemory<float> *> &b, // non-absl ok
- int ldb, float beta,
- const port::ArraySlice<DeviceMemory<float> *> &c, // non-absl ok
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+ uint64_t k, float alpha, const absl::Span<DeviceMemory<float> *const> a,
+ int lda, const absl::Span<DeviceMemory<float> *const> b, int ldb,
+ float beta, const absl::Span<DeviceMemory<float> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
- ThenBlasImpl<
- blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64, float,
- const port::ArraySlice<DeviceMemory<float> *> &, int, // non-absl ok
- const port::ArraySlice<DeviceMemory<float> *> &, int, // non-absl ok
- float, const port::ArraySlice<DeviceMemory<float> *> &, // non-absl ok
- int, int, ScratchAllocator *>
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
+ float, const absl::Span<DeviceMemory<float> *const>, int,
+ const absl::Span<DeviceMemory<float> *const>, int, float,
+ const absl::Span<DeviceMemory<float> *const>, int, int,
+ ScratchAllocator *>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
@@ -3581,12 +3568,10 @@
Stream &Stream::ThenBlasGemmBatched(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
- uint64_t k, double alpha,
- const port::ArraySlice<DeviceMemory<double> *> &a, // non-absl ok
- int lda, const port::ArraySlice<DeviceMemory<double> *> &b, // non-absl ok
- int ldb, double beta,
- const port::ArraySlice<DeviceMemory<double> *> &c, // non-absl ok
- int ldc, int batch_count) {
+ uint64_t k, double alpha, const absl::Span<DeviceMemory<double> *const> a,
+ int lda, const absl::Span<DeviceMemory<double> *const> b, int ldb,
+ double beta, const absl::Span<DeviceMemory<double> *const> c, int ldc,
+ int batch_count) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr);
@@ -3594,23 +3579,19 @@
Stream &Stream::ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
- uint64_t k, double alpha,
- const port::ArraySlice<DeviceMemory<double> *> &a, // non-absl ok
- int lda, const port::ArraySlice<DeviceMemory<double> *> &b, // non-absl ok
- int ldb, double beta,
- const port::ArraySlice<DeviceMemory<double> *> &c, // non-absl ok
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+ uint64_t k, double alpha, const absl::Span<DeviceMemory<double> *const> a,
+ int lda, const absl::Span<DeviceMemory<double> *const> b, int ldb,
+ double beta, const absl::Span<DeviceMemory<double> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
- ThenBlasImpl<
- blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64, double,
- const port::ArraySlice<DeviceMemory<double> *> &, // non-absl ok
- int, const port::ArraySlice<DeviceMemory<double> *> &, // non-absl ok
- int, double,
- const port::ArraySlice<DeviceMemory<double> *> &, // non-absl ok
- int, int, ScratchAllocator *>
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
+ double, const absl::Span<DeviceMemory<double> *const>, int,
+ const absl::Span<DeviceMemory<double> *const>, int, double,
+ const absl::Span<DeviceMemory<double> *const>, int, int,
+ ScratchAllocator *>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
@@ -3620,15 +3601,11 @@
Stream &Stream::ThenBlasGemmBatched(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, std::complex<float> alpha,
- const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
- &a,
- int lda,
- const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
- &b,
- int ldb, std::complex<float> beta,
- const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
- &c,
- int ldc, int batch_count) {
+ const absl::Span<DeviceMemory<std::complex<float>> *const> a, int lda,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> b, int ldb,
+ std::complex<float> beta,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> c, int ldc,
+ int batch_count) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr);
@@ -3637,25 +3614,22 @@
Stream &Stream::ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, std::complex<float> alpha,
- const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
- &a,
- int lda,
- const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
- &b,
- int ldb, std::complex<float> beta,
- const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
- &c,
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+ const absl::Span<DeviceMemory<std::complex<float>> *const> a, int lda,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> b, int ldb,
+ std::complex<float> beta,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
- ThenBlasImpl<
- blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
- std::complex<float>, const DeviceMemorySlice<std::complex<float>> &, int,
- const DeviceMemorySlice<std::complex<float>> &, int, std::complex<float>,
- const DeviceMemorySlice<std::complex<float>> &, int, int,
- ScratchAllocator *>
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
+ std::complex<float>,
+ const absl::Span<DeviceMemory<std::complex<float>> *const>, int,
+ const absl::Span<DeviceMemory<std::complex<float>> *const>, int,
+ std::complex<float>,
+ const absl::Span<DeviceMemory<std::complex<float>> *const>, int,
+ int, ScratchAllocator *>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
@@ -3665,33 +3639,35 @@
Stream &Stream::ThenBlasGemmBatched(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, std::complex<double> alpha,
- const DeviceMemorySlice<std::complex<double>> &a, int lda,
- const DeviceMemorySlice<std::complex<double>> &b, int ldb,
- std::complex<double> beta, const DeviceMemorySlice<std::complex<double>> &c,
- int ldc, int batch_count) {
- return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
- b, ldb, beta, c, ldc, batch_count,
- /*scratch_allocator=*/nullptr);
-}
-
-Stream &Stream::ThenBlasGemmBatchedWithScratch(
- blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
- uint64_t k, std::complex<double> alpha,
- const DeviceMemorySlice<std::complex<double>> &a, int lda,
- const DeviceMemorySlice<std::complex<double>> &b, int ldb,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> a, int lda,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> b, int ldb,
std::complex<double> beta,
- const DeviceMemorySlice<std::complex<double>> &c, // non-absl ok
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+ const absl::Span<DeviceMemory<std::complex<double>> *const> c, int ldc,
+ int batch_count) {
+ return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
+ b, ldb, beta, c, ldc, batch_count,
+ /*scratch_allocator=*/nullptr);
+}
+
+Stream &Stream::ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
+ uint64_t k, std::complex<double> alpha,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> a, int lda,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> b, int ldb,
+ std::complex<double> beta,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
- ThenBlasImpl<
- blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
- std::complex<double>, const DeviceMemorySlice<std::complex<double>> &,
- int, const DeviceMemorySlice<std::complex<double>> &, int,
- std::complex<double>, const DeviceMemorySlice<std::complex<double>> &,
- int, int, ScratchAllocator *>
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
+ std::complex<double>,
+ const absl::Span<DeviceMemory<std::complex<double>> *const>, int,
+ const absl::Span<DeviceMemory<std::complex<double>> *const>, int,
+ std::complex<double>,
+ const absl::Span<DeviceMemory<std::complex<double>> *const>, int,
+ int, ScratchAllocator *>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
diff --git a/tensorflow/compiler/xla/stream_executor/stream.h b/tensorflow/compiler/xla/stream_executor/stream.h
index 2e4eb90..8e4500a 100644
--- a/tensorflow/compiler/xla/stream_executor/stream.h
+++ b/tensorflow/compiler/xla/stream_executor/stream.h
@@ -28,6 +28,7 @@
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/stream_executor/blas.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.h"
@@ -35,7 +36,6 @@
#include "tensorflow/compiler/xla/stream_executor/fft.h"
#include "tensorflow/compiler/xla/stream_executor/kernel.h"
#include "tensorflow/compiler/xla/stream_executor/launch_dim.h"
-#include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/compiler/xla/stream_executor/temporary_memory_manager.h"
@@ -528,13 +528,13 @@
uint64_t options);
Stream &ThenDepthConcatenate(
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float> *const> input_data,
DeviceMemory<float> *output_data);
Stream &ThenSpaceConcatenate(
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float> *const> input_data,
DeviceMemory<float> *output_data,
dnn::SpaceConcatenateMode concat_direction);
@@ -572,17 +572,16 @@
Stream &ThenElementwiseOperate(
dnn::ElementwiseOperation operation,
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float> *const> input_data,
const dnn::BatchDescriptor &output_dimensions,
DeviceMemory<float> *output_data);
Stream &ThenElementwiseOperateScaledQuantized(
dnn::ElementwiseOperation operation,
- port::ArraySlice<int> input_multiplicands, // non-absl ok
- int output_divisor,
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
- port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
+ absl::Span<const int> input_multiplicands, int output_divisor,
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float> *const> input_data,
const dnn::BatchDescriptor &output_dimensions,
DeviceMemory<float> *output_data);
@@ -609,13 +608,12 @@
dnn::QuantizedActivationMode mode,
void *host_dst, uint64_t size);
- // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
+ // Template version of ThenMemcpyD2HQuantized that takes a Span
// and uses the Quantization trait to call the generic version of
// ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
template <typename ElementType>
- Stream &ThenMemcpyD2HQuantized(
- const DeviceMemory<float> &gpu_unquantized_src,
- port::MutableArraySlice<ElementType> host_dst) {
+ Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
+ absl::Span<ElementType> host_dst) {
return ThenMemcpyD2HQuantized(
gpu_unquantized_src, Quantization<ElementType>::kModeId,
host_dst.data(), host_dst.size() * sizeof(ElementType));
@@ -630,9 +628,8 @@
// and uses the Quantization trait to call the generic version of
// ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
template <typename ElementType>
- Stream &ThenMemcpyH2DQuantized(
- port::ArraySlice<ElementType> host_src, // non-absl ok
- DeviceMemory<float> *gpu_unquantized_dst) {
+ Stream &ThenMemcpyH2DQuantized(absl::Span<const ElementType> host_src,
+ DeviceMemory<float> *gpu_unquantized_dst) {
return ThenMemcpyH2DQuantized(
host_src.data(), host_src.size() * sizeof(ElementType),
Quantization<ElementType>::kModeId, gpu_unquantized_dst);
@@ -1381,82 +1378,80 @@
return st;
}
- template <typename T>
- using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>; // non-absl ok
-
// See BlasSupport::DoBlasGemmBatched.
- Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
- uint64_t m, uint64 n, uint64_t k, float alpha,
- const DeviceMemorySlice<Eigen::half> &a, int lda,
- const DeviceMemorySlice<Eigen::half> &b, int ldb,
- float beta,
- const DeviceMemorySlice<Eigen::half> &c, int ldc,
- int batch_count);
+ Stream &ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
+ uint64_t k, float alpha,
+ const absl::Span<DeviceMemory<Eigen::half> *const> a, int lda,
+ const absl::Span<DeviceMemory<Eigen::half> *const> b, int ldb, float beta,
+ const absl::Span<DeviceMemory<Eigen::half> *const> c, int ldc,
+ int batch_count);
Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
uint64_t m, uint64 n, uint64 k, float alpha,
- const DeviceMemorySlice<float> &a, int lda,
- const DeviceMemorySlice<float> &b, int ldb,
- float beta, const DeviceMemorySlice<float> &c,
+ const absl::Span<DeviceMemory<float> *const> a,
+ int lda,
+ const absl::Span<DeviceMemory<float> *const> b,
+ int ldb, float beta,
+ const absl::Span<DeviceMemory<float> *const> c,
+ int ldc, int batch_count);
+ Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
+ uint64_t m, uint64 n, uint64 k, double alpha,
+ const absl::Span<DeviceMemory<double> *const> a,
+ int lda,
+ const absl::Span<DeviceMemory<double> *const> b,
+ int ldb, double beta,
+ const absl::Span<DeviceMemory<double> *const> c,
int ldc, int batch_count);
Stream &ThenBlasGemmBatched(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
- uint64 k, double alpha,
- const port::ArraySlice<DeviceMemory<double> *> &a, // non-absl ok
- int lda,
- const port::ArraySlice<DeviceMemory<double> *> &b, // non-absl ok
- int ldb, double beta,
- const port::ArraySlice<DeviceMemory<double> *> &c, // non-absl ok
- int ldc, int batch_count);
- Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
- uint64_t m, uint64 n, uint64_t k,
- std::complex<float> alpha,
- const DeviceMemorySlice<std::complex<float>> &a,
- int lda,
- const DeviceMemorySlice<std::complex<float>> &b,
- int ldb, std::complex<float> beta,
- const DeviceMemorySlice<std::complex<float>> &c,
- int ldc, int batch_count);
- Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
- uint64_t m, uint64 n, uint64_t k,
- std::complex<double> alpha,
- const DeviceMemorySlice<std::complex<double>> &a,
- int lda,
- const DeviceMemorySlice<std::complex<double>> &b,
- int ldb, std::complex<double> beta,
- const DeviceMemorySlice<std::complex<double>> &c,
- int ldc, int batch_count);
+ uint64_t k, std::complex<float> alpha,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> a, int lda,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> b, int ldb,
+ std::complex<float> beta,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> c, int ldc,
+ int batch_count);
+ Stream &ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
+ uint64_t k, std::complex<double> alpha,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> a, int lda,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> b, int ldb,
+ std::complex<double> beta,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> c, int ldc,
+ int batch_count);
Stream &ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
- uint64_t k, float alpha, const DeviceMemorySlice<Eigen::half> &a, int lda,
- const DeviceMemorySlice<Eigen::half> &b, int ldb, float beta,
- const DeviceMemorySlice<Eigen::half> &c, int ldc, int batch_count,
- ScratchAllocator *scratch_allocator);
+ uint64_t k, float alpha,
+ const absl::Span<DeviceMemory<Eigen::half> *const> a, int lda,
+ const absl::Span<DeviceMemory<Eigen::half> *const> b, int ldb, float beta,
+ const absl::Span<DeviceMemory<Eigen::half> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator);
Stream &ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
- uint64_t k, float alpha, const DeviceMemorySlice<float> &a, int lda,
- const DeviceMemorySlice<float> &b, int ldb, float beta,
- const DeviceMemorySlice<float> &c, int ldc, int batch_count,
- ScratchAllocator *scratch_allocator);
+ uint64_t k, float alpha, const absl::Span<DeviceMemory<float> *const> a,
+ int lda, const absl::Span<DeviceMemory<float> *const> b, int ldb,
+ float beta, const absl::Span<DeviceMemory<float> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator);
Stream &ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
- uint64_t k, double alpha, const DeviceMemorySlice<double> &a, int lda,
- const DeviceMemorySlice<double> &b, int ldb, double beta,
- const DeviceMemorySlice<double> &c, int ldc, int batch_count,
- ScratchAllocator *scratch_allocator);
+ uint64_t k, double alpha, const absl::Span<DeviceMemory<double> *const> a,
+ int lda, const absl::Span<DeviceMemory<double> *const> b, int ldb,
+ double beta, const absl::Span<DeviceMemory<double> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator);
Stream &ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, std::complex<float> alpha,
- const DeviceMemorySlice<std::complex<float>> &a, int lda,
- const DeviceMemorySlice<std::complex<float>> &b, int ldb,
- std::complex<float> beta, const DeviceMemorySlice<std::complex<float>> &c,
- int ldc, int batch_count, ScratchAllocator *scratch_allocator);
+ const absl::Span<DeviceMemory<std::complex<float>> *const> a, int lda,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> b, int ldb,
+ std::complex<float> beta,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator);
Stream &ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, std::complex<double> alpha,
- const DeviceMemorySlice<std::complex<double>> &a, int lda,
- const DeviceMemorySlice<std::complex<double>> &b, int ldb,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> a, int lda,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> b, int ldb,
std::complex<double> beta,
- const DeviceMemorySlice<std::complex<double>> &c, int ldc,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator);
template <typename InputType, typename ConstantType>
@@ -1734,7 +1729,7 @@
// slice size.
template <typename T>
Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
- port::MutableArraySlice<T> host_dst) {
+ absl::Span<T> host_dst) {
auto host_size = host_dst.size() * sizeof(T);
CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
@@ -1744,7 +1739,7 @@
// array slice. Checks that the destination size can accommodate the host
// slice size.
template <typename T>
- Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src, // non-absl ok
+ Stream &ThenMemcpyH2D(absl::Span<const T> host_src,
DeviceMemory<T> *gpu_dst) {
auto host_size = host_src.size() * sizeof(T);
CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
diff --git a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h
index 0d1b45a..29e68d6 100644
--- a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h
@@ -266,9 +266,8 @@
// array slice. Checks that the destination size can accommodate the host
// slice size.
template <class T>
- port::Status SynchronousMemcpyH2D(
- port::ArraySlice<T> host_src, // non-absl ok
- DeviceMemoryBase* device_dst) {
+ port::Status SynchronousMemcpyH2D(absl::Span<const T> host_src,
+ DeviceMemoryBase* device_dst) {
auto host_size = host_src.size() * sizeof(T);
CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
@@ -283,7 +282,7 @@
// slice size.
template <typename T>
port::Status SynchronousMemcpyD2H(const DeviceMemory<T>& device_src,
- port::MutableArraySlice<T> host_dst) {
+ absl::Span<T> host_dst) {
auto host_size = host_dst.size() * sizeof(T);
CHECK(device_src.size() == 0 || host_size >= device_src.size());
return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD
index 5899cc2..0e323b3 100644
--- a/tensorflow/stream_executor/rocm/BUILD
+++ b/tensorflow/stream_executor/rocm/BUILD
@@ -207,6 +207,7 @@
"//tensorflow/compiler/xla/stream_executor:blas",
"//tensorflow/stream_executor/platform:dso_loader",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@local_config_rocm//rocm:rocm_headers",
]),
alwayslink = True,
@@ -283,6 +284,7 @@
"//tensorflow/stream_executor/platform:dso_loader",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@local_config_rocm//rocm:rocm_headers",
]),
alwayslink = True,
diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc
index 051aa1e..b3bf473 100644
--- a/tensorflow/stream_executor/rocm/rocm_blas.cc
+++ b/tensorflow/stream_executor/rocm/rocm_blas.cc
@@ -1733,9 +1733,9 @@
port::Status ROCMBlas::DoBlasGemmBatchedInternal(
FuncT rocblas_func, Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64_t m, uint64 n, uint64 k, T alpha,
- const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
- const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
- T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
+ const absl::Span<DeviceMemory<T> *const> a_ptrs_to_wrappers, int lda,
+ const absl::Span<DeviceMemory<T> *const> b_ptrs_to_wrappers, int ldb,
+ T beta, const absl::Span<DeviceMemory<T> *const> c_ptrs_to_wrappers,
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
@@ -1827,9 +1827,9 @@
bool ROCMBlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, float alpha,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
- const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+ const absl::Span<DeviceMemory<Eigen::half> *const> a, int lda,
+ const absl::Span<DeviceMemory<Eigen::half> *const> b, int ldb, float beta,
+ const absl::Span<DeviceMemory<Eigen::half> *const> c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator) {
blas_log("DoBlasGemmBatched");
const Eigen::half alpha_half(alpha);
@@ -1849,9 +1849,9 @@
bool ROCMBlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, float alpha,
- const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
- const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
- const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
+ const absl::Span<DeviceMemory<float> *const> a_array, int lda,
+ const absl::Span<DeviceMemory<float> *const> b_array, int ldb, float beta,
+ const absl::Span<DeviceMemory<float> *const> c_array, int ldc,
int batch_count, ScratchAllocator *scratch_allocator) {
blas_log("DoBlasGemmBatched");
port::Status status = DoBlasGemmBatchedInternal(
@@ -1867,10 +1867,10 @@
bool ROCMBlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, double alpha,
- const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
- const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
- double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+ const absl::Span<DeviceMemory<double> *const> a_array, int lda,
+ const absl::Span<DeviceMemory<double> *const> b_array, int ldb, double beta,
+ const absl::Span<DeviceMemory<double> *const> c_array, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
blas_log("DoBlasGemmBatched");
port::Status status = DoBlasGemmBatchedInternal(
wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
@@ -1885,12 +1885,11 @@
bool ROCMBlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, std::complex<float> alpha,
- const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
- int lda,
- const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
- int ldb, std::complex<float> beta,
- const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
- int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+ const absl::Span<DeviceMemory<std::complex<float>> *const> a_array, int lda,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> b_array, int ldb,
+ std::complex<float> beta,
+ const absl::Span<DeviceMemory<std::complex<float>> *const> c_array, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
blas_log("DoBlasGemmBatched");
port::Status status = DoBlasGemmBatchedInternal(
wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k,
@@ -1905,11 +1904,11 @@
bool ROCMBlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
uint64_t n, uint64 k, std::complex<double> alpha,
- const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> a_array,
int lda,
- const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> b_array,
int ldb, std::complex<double> beta,
- const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
+ const absl::Span<DeviceMemory<std::complex<double>> *const> c_array,
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
blas_log("DoBlasGemmBatched");
port::Status status = DoBlasGemmBatchedInternal(
diff --git a/tensorflow/stream_executor/rocm/rocm_blas.h b/tensorflow/stream_executor/rocm/rocm_blas.h
index dd5dc87..e4fd847 100644
--- a/tensorflow/stream_executor/rocm/rocm_blas.h
+++ b/tensorflow/stream_executor/rocm/rocm_blas.h
@@ -165,9 +165,9 @@
port::Status DoBlasGemmBatchedInternal(
FuncT rocblas_func, Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64_t m, uint64 n, uint64 k, T alpha,
- const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
- const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
- T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
+ const absl::Span<DeviceMemory<T> *const> a_ptrs_to_wrappers, int lda,
+ const absl::Span<DeviceMemory<T> *const> b_ptrs_to_wrappers, int ldb,
+ T beta, const absl::Span<DeviceMemory<T> *const> c_ptrs_to_wrappers,
int ldc, int batch_count, ScratchAllocator *scratch_allocator);
// Helper function for implementing DoBlasGemmWithProfiling.
diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.cc b/tensorflow/stream_executor/rocm/rocm_dnn.cc
index a0feeeb..7b56e71 100644
--- a/tensorflow/stream_executor/rocm/rocm_dnn.cc
+++ b/tensorflow/stream_executor/rocm/rocm_dnn.cc
@@ -4417,8 +4417,8 @@
}
bool MIOpenSupport::DoDepthConcatenate(
- Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
- port::ArraySlice<const DeviceMemory<float>*> input_data,
+ Stream* stream, absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
DeviceMemory<float>* output_data) {
CHECK_EQ(input_dimensions.size(), input_data.size());
@@ -4476,8 +4476,8 @@
bool MIOpenSupport::DoElementwiseOperate(
Stream* stream, dnn::ElementwiseOperation operation,
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
- port::ArraySlice<const DeviceMemory<float>*> input_data,
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) {
LOG(FATAL) << "not yet implemented"; // TODO(leary)
diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.h b/tensorflow/stream_executor/rocm/rocm_dnn.h
index 43b486a..35b4568 100644
--- a/tensorflow/stream_executor/rocm/rocm_dnn.h
+++ b/tensorflow/stream_executor/rocm/rocm_dnn.h
@@ -472,14 +472,14 @@
ScratchAllocator* workspace_allocator = nullptr) override;
bool DoDepthConcatenate(
- Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
- port::ArraySlice<const DeviceMemory<float>*> input_data,
+ Stream* stream, absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
DeviceMemory<float>* output_data) override;
bool DoElementwiseOperate(
Stream* stream, dnn::ElementwiseOperation operation,
- port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
- port::ArraySlice<const DeviceMemory<float>*> input_data,
+ absl::Span<const dnn::BatchDescriptor> input_dimensions,
+ absl::Span<const DeviceMemory<float>* const> input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) override;