Move the CtcLossDescriptor constructor/destructor definitions back into the header (dnn.h).

Surface the scratch-memory allocation for the cuDNN CTC loss to the caller: add a DoPrepareForCtcLoss() step that queries cudnnGetCTCLossWorkspaceSize() and allocates the workspace, invoked from ThenCtcLoss() before DoCtcLoss(); DoCtcLoss() now receives the pre-allocated scratch memory instead of a ScratchAllocator.

Pass absl::Span parameters by value instead of by const reference, since a Span is already a cheap non-owning view (pointer + length).
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 75b2f9f..8ace170 100755
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1704,38 +1704,6 @@
return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
}
-port::StatusOr<DeviceMemory<uint8>> CreateCtcLossWorkspace(
- Stream* stream, const CudnnHandle& cudnn,
- const CudnnCtcLossDescriptor& ctc_loss_desc,
- const CudnnRnnStateTensorDescriptor& probs_desc,
- const CudnnRnnStateTensorDescriptor& grads_desc,
- const absl::Span<const int32>& labels_data,
- const absl::Span<const int32>& labels_lengths_data,
- const absl::Span<const int32>& input_lengths_data,
- ScratchAllocator* workspace_allocator) {
- // Query the workspace size.
- size_t workspace_size_in_bytes = 0;
-#if CUDNN_VERSION >= 7603
- RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
- /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
- /*gradientsDesc=*/grads_desc.handle(),
- /*labels=*/labels_data.data(),
- /*labelLengths=*/labels_lengths_data.data(),
- /*inputLengths=*/input_lengths_data.data(),
- /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
- /*ctcLossDesc=*/ctc_loss_desc.handle(),
- /*sizeInBytes=*/&workspace_size_in_bytes));
-#else
- return port::Status(port::error::INVALID_ARGUMENT,
- "No supported cudnnGetCTCLossWorkspaceSize when "
- "CUDNN_VERSION < 7.6.3");
-#endif
- // Allocate the workspace.
- if (workspace_size_in_bytes == 0) {
- return DeviceMemory<uint8>();
- }
- return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
-}
#endif
} // namespace
@@ -2052,22 +2020,16 @@
port::Status CudnnSupport::DoCtcLossImpl(
Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data,
- const absl::Span<const int32>& labels_data,
- const absl::Span<const int32>& labels_lengths_data,
- const absl::Span<const int32>& input_lengths_data,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
DeviceMemoryBase costs_data,
const CudnnRnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data,
const CudnnCtcLossDescriptor& ctc_loss_desc,
- ScratchAllocator* workspace_allocator) {
+ DeviceMemory<uint8> scratch_memory) {
auto cudnn = cudnn_->GetHandle(parent_, stream);
- SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
- CreateCtcLossWorkspace(stream, cudnn, ctc_loss_desc,
- probs_desc, grads_desc,
- labels_data, labels_lengths_data,
- input_lengths_data,
- workspace_allocator));
int kNumTimestamps = probs_desc.num_layers();
int kBatchSize = probs_desc.batch_size();
int kNumLabels = probs_desc.data_size();
@@ -2083,8 +2045,8 @@
/*gradients=*/grads_data.opaque(),
/*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
/*ctcLossDesc=*/ctc_loss_desc.handle(),
- /*workspace=*/workspace.opaque(),
- /*workSpaceSizeInBytes=*/workspace.size()));
+ /*workspace=*/scratch_memory.opaque(),
+ /*workSpaceSizeInBytes=*/scratch_memory.size()));
#else
return port::Status(port::error::INVALID_ARGUMENT,
"No supported cudnnCTCLoss when "
@@ -3953,18 +3915,67 @@
/*report_error=*/!output_profile_result);
}
+port::Status CudnnSupport::DoPrepareForCtcLoss(
+ Stream* stream, dnn::DataType element_type,
+ const dnn::CtcLossDescriptor &ctc_loss_desc,
+ const dnn::RnnStateTensorDescriptor &probs_desc,
+ const dnn::RnnStateTensorDescriptor &grads_desc,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
+ ScratchAllocator* scratch_allocator,
+ DeviceMemory<uint8>* scratch_memory) {
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ctc_loss_desc,
+ ToCudnnDataType(element_type));
+ const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
+ static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
+ const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
+ static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
+ // Query the workspace size.
+ size_t workspace_size_in_bytes = 0;
+#if CUDNN_VERSION >= 7603
+ RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
+ /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
+ /*gradientsDesc=*/cudnn_grads_desc.handle(),
+ /*labels=*/labels_data.data(),
+ /*labelLengths=*/labels_lengths_data.data(),
+ /*inputLengths=*/input_lengths_data.data(),
+ /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
+ /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
+ /*sizeInBytes=*/&workspace_size_in_bytes));
+#else
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "No supported cudnnGetCTCLossWorkspaceSize when "
+ "CUDNN_VERSION < 7.6.3");
+#endif
+ // Allocate the workspace.
+ if (workspace_size_in_bytes == 0) {
+ *scratch_memory = DeviceMemory<uint8>();
+ return port::Status::OK();
+ }
+ const auto scratch_or = scratch_allocator->AllocateBytes(
+ workspace_size_in_bytes);
+ if (scratch_or.ok()) {
+ *scratch_memory = scratch_or.ValueOrDie();
+ return port::Status::OK();
+ }
+ return port::InternalError(
+ "Failed to allocate scratch memory for the CuDNN CTC Loss");
+}
+
port::Status CudnnSupport::DoCtcLoss(
Stream* stream, dnn::DataType element_type,
const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemoryBase probs_data,
- const absl::Span<const int32> &labels_data,
- const absl::Span<const int32> &labels_lengths_data,
- const absl::Span<const int32> &input_lengths_data,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
DeviceMemoryBase costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemoryBase grads_data,
const dnn::CtcLossDescriptor &ctc_loss_desc,
- ScratchAllocator *workspace_allocator) {
+ DeviceMemory<uint8> scratch_memory) {
// Current cuDNN CTC Loss only supports the float datatype
if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
return port::Status(port::error::INVALID_ARGUMENT,
@@ -3980,7 +3991,7 @@
return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
labels_lengths_data, input_lengths_data, costs_data,
cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
- workspace_allocator);
+ scratch_memory);
}
bool CudnnSupport::DoTransformTensor(Stream* stream,
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 6b4eba5..bdf4166 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -567,14 +567,14 @@
Stream* stream, dnn::DataType element_type,
const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemoryBase probs_data,
- const absl::Span<const int32> &labels_data,
- const absl::Span<const int32> &labels_lengths_data,
- const absl::Span<const int32> &input_lengths_data,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
DeviceMemoryBase costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemoryBase grads_data,
const dnn::CtcLossDescriptor &ctc_loss_desc,
- ScratchAllocator *workspace_allocator) override;
+ DeviceMemory<uint8> scratch_memory) override;
bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
dnn::DataType input_type,
@@ -690,14 +690,14 @@
port::Status DoCtcLossImpl(
Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data,
- const absl::Span<const int32>& labels_data,
- const absl::Span<const int32>& labels_lengths_data,
- const absl::Span<const int32>& input_lengths_data,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
DeviceMemoryBase costs_data,
const CudnnRnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data,
const CudnnCtcLossDescriptor& ctc_loss_desc,
- ScratchAllocator* workspace_allocator);
+ DeviceMemory<uint8> scratch_memory);
private:
port::Status DoPrepareForConvolution(
@@ -712,6 +712,17 @@
ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
DeviceMemory<uint8>* scratch_memory) override;
+ port::Status DoPrepareForCtcLoss(
+ Stream* stream, dnn::DataType element_type,
+ const dnn::CtcLossDescriptor &ctc_loss_desc,
+ const dnn::RnnStateTensorDescriptor &probs_desc,
+ const dnn::RnnStateTensorDescriptor &grads_desc,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
+ ScratchAllocator* scratch_allocator,
+ DeviceMemory<uint8>* scratch_memory) override;
+
SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index c8c0201..38d6abc 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -505,12 +505,6 @@
return desc;
}
-// -- CtcLossDescriptor
-//
-CtcLossDescriptor::CtcLossDescriptor() {}
-
-CtcLossDescriptor::~CtcLossDescriptor() {}
-
// -- PoolingDescriptor
PoolingDescriptor::PoolingDescriptor(int ndims)
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 5f7f2aa..f508014 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -193,8 +193,8 @@
// Describes a CTC loss operation.
class CtcLossDescriptor {
public:
- CtcLossDescriptor();
- ~CtcLossDescriptor();
+ CtcLossDescriptor() {}
+ ~CtcLossDescriptor() {}
};
// Specifies the sequence in a RNN model.
@@ -2390,6 +2390,24 @@
return false;
}
+ template <typename ElementType>
+ port::Status PrepareForCtcLoss(
+ Stream* stream,
+ const CtcLossDescriptor &ctc_loss_desc,
+ const RnnStateTensorDescriptor &probs_desc,
+ DeviceMemory<ElementType> probs_data,
+ const RnnStateTensorDescriptor &grads_desc,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
+ ScratchAllocator *workspace_allocator,
+ DeviceMemory<uint8>* scratch_memory) {
+ return DoPrepareForCtcLoss(
+ stream, ToDataType<ElementType>::value, ctc_loss_desc, probs_desc,
+ grads_desc, labels_data, labels_lengths_data, input_lengths_data,
+ workspace_allocator, scratch_memory);
+ }
+
// Enqueue a CTC Loss operation onto the stream.
//
// Arguments:
@@ -2413,34 +2431,34 @@
// afterwards.
virtual port::Status DoCtcLoss(Stream* stream,
dnn::DataType element_type,
- const dnn::RnnStateTensorDescriptor &probs_desc,
+ const RnnStateTensorDescriptor &probs_desc,
const DeviceMemoryBase probs_data,
- const absl::Span<const int32> &labels_data,
- const absl::Span<const int32> &labels_lengths_data,
- const absl::Span<const int32> &input_lengths_data,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
DeviceMemoryBase costs_data,
- const dnn::RnnStateTensorDescriptor &grads_desc,
+ const RnnStateTensorDescriptor &grads_desc,
DeviceMemoryBase grads_data,
- const dnn::CtcLossDescriptor &ctc_loss_desc,
- ScratchAllocator *workspace_allocator) = 0;
+ const CtcLossDescriptor &ctc_loss_desc,
+ DeviceMemory<uint8> scratch_memory) = 0;
template<typename ElementType>
bool DoCtcLoss(Stream* stream,
const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemory<ElementType> &probs_data,
- const absl::Span<const int32> &labels_data,
- const absl::Span<const int32> &labels_lengths_data,
- const absl::Span<const int32> &input_lengths_data,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
DeviceMemory<ElementType> *costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemory<ElementType> *grads_data,
const dnn::CtcLossDescriptor &ctc_loss_desc,
- ScratchAllocator *workspace_allocator) {
+ DeviceMemory<uint8>* scratch_memory) {
return IsStatusOk(
DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
probs_data, labels_data, labels_lengths_data,
input_lengths_data, *costs_data, grads_desc, *grads_data,
- ctc_loss_desc, workspace_allocator),
+ ctc_loss_desc, *scratch_memory),
false);
}
@@ -2699,6 +2717,20 @@
return port::Status::OK();
}
+ virtual port::Status DoPrepareForCtcLoss(
+ Stream* stream, DataType element_type,
+ const CtcLossDescriptor &ctc_loss_desc,
+ const RnnStateTensorDescriptor &probs_desc,
+ const RnnStateTensorDescriptor &grads_desc,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
+ ScratchAllocator* scratch_allocator,
+ DeviceMemory<uint8>* scratch_memory) {
+ *scratch_memory = {};
+ return port::Status::OK();
+ }
+
SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
};
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index ed119fb..a079a79 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -5232,9 +5232,9 @@
Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemory<float> &probs_data,
- const absl::Span<const int32> &labels_data,
- const absl::Span<const int32> &labels_lengths_data,
- const absl::Span<const int32> &input_lengths_data,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
DeviceMemory<float> *costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemory<float> *grads_data,
@@ -5242,10 +5242,19 @@
ScratchAllocator *workspace_allocator) {
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- auto status = dnn->DoCtcLoss(
- this, probs_desc, probs_data, labels_data, labels_lengths_data,
- input_lengths_data, costs_data, grads_desc, grads_data, ctc_loss_desc,
- workspace_allocator);
+ DeviceMemory<uint8> scratch_memory;
+ auto status =
+ dnn->PrepareForCtcLoss(
+ this, ctc_loss_desc, probs_desc, probs_data, grads_desc,
+ labels_data, labels_lengths_data, input_lengths_data,
+ workspace_allocator, &scratch_memory)
+ .ok();
+ if (status) {
+ status = dnn->DoCtcLoss(
+ this, probs_desc, probs_data, labels_data, labels_lengths_data,
+ input_lengths_data, costs_data, grads_desc, grads_data,
+ ctc_loss_desc, &scratch_memory);
+ }
if (!status) {
SetError();
}
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index fe12908..208103b 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1917,9 +1917,9 @@
Stream &ThenCtcLoss(
const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemory<float> &probs_data,
- const absl::Span<const int32> &labels_data,
- const absl::Span<const int32> &labels_lengths_data,
- const absl::Span<const int32> &input_lengths_data,
+ absl::Span<const int> labels_data,
+ absl::Span<const int> labels_lengths_data,
+ absl::Span<const int> input_lengths_data,
DeviceMemory<float> *costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemory<float> *grads_data,