Move the CtcLossDescriptor constructor/destructor back to the header

Surface the scratch memory allocation up to ThenCtcLoss()

Pass absl::Span by value (it is a cheap pointer-like view) instead of by const reference
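
The net effect is a prepare/execute split: ThenCtcLoss() first asks the backend
to size and allocate the cuDNN workspace, then runs the loss with that buffer.
A minimal sketch of the resulting call sequence, mirroring the stream.cc hunk
below (assumes a valid dnn::DnnSupport *dnn and already-constructed
descriptors and device buffers from the surrounding StreamExecutor code):

    DeviceMemory<uint8> scratch_memory;
    // Phase 1: query cudnnGetCTCLossWorkspaceSize and allocate the workspace
    // from the caller-provided ScratchAllocator. The buffer may legitimately
    // stay empty when cuDNN reports a zero-byte workspace.
    bool ok = dnn->PrepareForCtcLoss(
                     stream, ctc_loss_desc, probs_desc, probs_data, grads_desc,
                     labels_data, labels_lengths_data, input_lengths_data,
                     workspace_allocator, &scratch_memory)
                  .ok();
    // Phase 2: run the actual CTC loss with the pre-allocated scratch buffer.
    if (ok) {
      ok = dnn->DoCtcLoss(stream, probs_desc, probs_data, labels_data,
                          labels_lengths_data, input_lengths_data, costs_data,
                          grads_desc, grads_data, ctc_loss_desc,
                          &scratch_memory);
    }

This mirrors the existing DoPrepareForConvolution()/convolution split, and the
new virtual DoPrepareForCtcLoss() on dnn::DnnSupport defaults to a no-op that
returns an empty scratch buffer, so backends without a CTC loss implementation
keep compiling unchanged.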
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 75b2f9f..8ace170 100755
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1704,38 +1704,6 @@
   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
 }
 
-port::StatusOr<DeviceMemory<uint8>> CreateCtcLossWorkspace(
-    Stream* stream, const CudnnHandle& cudnn,
-    const CudnnCtcLossDescriptor& ctc_loss_desc,
-    const CudnnRnnStateTensorDescriptor& probs_desc,
-    const CudnnRnnStateTensorDescriptor& grads_desc,
-    const absl::Span<const int32>& labels_data,
-    const absl::Span<const int32>& labels_lengths_data,
-    const absl::Span<const int32>& input_lengths_data,
-    ScratchAllocator* workspace_allocator) {
-  // Query the workspace size.
-  size_t workspace_size_in_bytes = 0;
-#if CUDNN_VERSION >= 7603
-  RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
-      /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
-      /*gradientsDesc=*/grads_desc.handle(),
-      /*labels=*/labels_data.data(),
-      /*labelLengths=*/labels_lengths_data.data(),
-      /*inputLengths=*/input_lengths_data.data(),
-      /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
-      /*ctcLossDesc=*/ctc_loss_desc.handle(),
-      /*sizeInBytes=*/&workspace_size_in_bytes));
-#else
-  return port::Status(port::error::INVALID_ARGUMENT,
-                      "No supported cudnnGetCTCLossWorkspaceSize when "
-                      "CUDNN_VERSION < 7.6.3");
-#endif
-  // Allocate the workspace.
-  if (workspace_size_in_bytes == 0) {
-    return DeviceMemory<uint8>();
-  }
-  return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
-}
 #endif
 
 }  // namespace
@@ -2052,22 +2020,16 @@
 port::Status CudnnSupport::DoCtcLossImpl(
     Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
     const DeviceMemoryBase probs_data,
-    const absl::Span<const int32>& labels_data,
-    const absl::Span<const int32>& labels_lengths_data,
-    const absl::Span<const int32>& input_lengths_data,
+    absl::Span<const int> labels_data,
+    absl::Span<const int> labels_lengths_data,
+    absl::Span<const int> input_lengths_data,
     DeviceMemoryBase costs_data,
     const CudnnRnnStateTensorDescriptor& grads_desc,
     DeviceMemoryBase grads_data,
     const CudnnCtcLossDescriptor& ctc_loss_desc,
-    ScratchAllocator* workspace_allocator) {
+    DeviceMemory<uint8> scratch_memory) {
   auto cudnn = cudnn_->GetHandle(parent_, stream);
 
-  SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
-                      CreateCtcLossWorkspace(stream, cudnn, ctc_loss_desc,
-                                             probs_desc, grads_desc,
-                                             labels_data, labels_lengths_data,
-                                             input_lengths_data,
-                                             workspace_allocator));
   int kNumTimestamps = probs_desc.num_layers();
   int kBatchSize = probs_desc.batch_size();
   int kNumLabels = probs_desc.data_size();
@@ -2083,8 +2045,8 @@
           /*gradients=*/grads_data.opaque(),
           /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
           /*ctcLossDesc=*/ctc_loss_desc.handle(),
-          /*workspace=*/workspace.opaque(),
-          /*workSpaceSizeInBytes=*/workspace.size()));
+          /*workspace=*/scratch_memory.opaque(),
+          /*workSpaceSizeInBytes=*/scratch_memory.size()));
 #else
   return port::Status(port::error::INVALID_ARGUMENT,
                       "No supported cudnnCTCLoss when "
@@ -3953,18 +3915,67 @@
       /*report_error=*/!output_profile_result);
 }
 
+port::Status CudnnSupport::DoPrepareForCtcLoss(
+      Stream* stream, dnn::DataType element_type,
+      const dnn::CtcLossDescriptor &ctc_loss_desc,
+      const dnn::RnnStateTensorDescriptor &probs_desc,
+      const dnn::RnnStateTensorDescriptor &grads_desc,
+      absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data,
+      ScratchAllocator* scratch_allocator,
+      DeviceMemory<uint8>* scratch_memory) {
+  auto cudnn = cudnn_->GetHandle(parent_, stream);
+  CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ctc_loss_desc,
+                                             ToCudnnDataType(element_type));
+  const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
+      static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
+  const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
+      static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
+  // Query the workspace size.
+  size_t workspace_size_in_bytes = 0;
+#if CUDNN_VERSION >= 7603
+  RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
+      /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
+      /*gradientsDesc=*/cudnn_grads_desc.handle(),
+      /*labels=*/labels_data.data(),
+      /*labelLengths=*/labels_lengths_data.data(),
+      /*inputLengths=*/input_lengths_data.data(),
+      /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
+      /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
+      /*sizeInBytes=*/&workspace_size_in_bytes));
+#else
+  return port::Status(port::error::INVALID_ARGUMENT,
+                      "No supported cudnnGetCTCLossWorkspaceSize when "
+                      "CUDNN_VERSION < 7.6.3");
+#endif
+  // Allocate the workspace.
+  if (workspace_size_in_bytes == 0) {
+    *scratch_memory = DeviceMemory<uint8>();
+    return port::Status::OK();
+  }
+  const auto scratch_or = scratch_allocator->AllocateBytes(
+      workspace_size_in_bytes);
+  if (scratch_or.ok()) {
+    *scratch_memory = scratch_or.ValueOrDie();
+    return port::Status::OK();
+  }
+  return port::InternalError(
+      "Failed to allocate scratch memory for the CuDNN CTC Loss");
+}
+
 port::Status CudnnSupport::DoCtcLoss(
     Stream* stream, dnn::DataType element_type,
     const dnn::RnnStateTensorDescriptor &probs_desc,
     const DeviceMemoryBase probs_data,
-    const absl::Span<const int32> &labels_data,
-    const absl::Span<const int32> &labels_lengths_data,
-    const absl::Span<const int32> &input_lengths_data,
+    absl::Span<const int> labels_data,
+    absl::Span<const int> labels_lengths_data,
+    absl::Span<const int> input_lengths_data,
     DeviceMemoryBase costs_data,
     const dnn::RnnStateTensorDescriptor &grads_desc,
     DeviceMemoryBase grads_data,
     const dnn::CtcLossDescriptor &ctc_loss_desc,
-    ScratchAllocator *workspace_allocator) {
+    DeviceMemory<uint8> scratch_memory) {
   // Current cuDNN CTC Loss only supports the float datatype
   if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
     return port::Status(port::error::INVALID_ARGUMENT,
@@ -3980,7 +3991,7 @@
   return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
              labels_lengths_data, input_lengths_data, costs_data,
              cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
-             workspace_allocator);
+             scratch_memory);
 }
 
 bool CudnnSupport::DoTransformTensor(Stream* stream,
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 6b4eba5..bdf4166 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -567,14 +567,14 @@
       Stream* stream, dnn::DataType element_type,
       const dnn::RnnStateTensorDescriptor &probs_desc,
       const DeviceMemoryBase probs_data,
-      const absl::Span<const int32> &labels_data,
-      const absl::Span<const int32> &labels_lengths_data,
-      const absl::Span<const int32> &input_lengths_data,
+      absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data,
       DeviceMemoryBase costs_data,
       const dnn::RnnStateTensorDescriptor &grads_desc,
       DeviceMemoryBase grads_data,
       const dnn::CtcLossDescriptor &ctc_loss_desc,
-      ScratchAllocator *workspace_allocator) override;
+      DeviceMemory<uint8> scratch_memory) override;
 
   bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                          dnn::DataType input_type,
@@ -690,14 +690,14 @@
   port::Status DoCtcLossImpl(
       Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
       const DeviceMemoryBase probs_data,
-      const absl::Span<const int32>& labels_data,
-      const absl::Span<const int32>& labels_lengths_data,
-      const absl::Span<const int32>& input_lengths_data,
+      absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data,
       DeviceMemoryBase costs_data,
       const CudnnRnnStateTensorDescriptor& grads_desc,
       DeviceMemoryBase grads_data,
       const CudnnCtcLossDescriptor& ctc_loss_desc,
-      ScratchAllocator* workspace_allocator);
+      DeviceMemory<uint8> scratch_memory);
 
  private:
   port::Status DoPrepareForConvolution(
@@ -712,6 +712,17 @@
       ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
       DeviceMemory<uint8>* scratch_memory) override;
 
+  port::Status DoPrepareForCtcLoss(
+      Stream* stream, dnn::DataType element_type,
+      const dnn::CtcLossDescriptor &ctc_loss_desc,
+      const dnn::RnnStateTensorDescriptor &probs_desc,
+      const dnn::RnnStateTensorDescriptor &grads_desc,
+      absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data,
+      ScratchAllocator* scratch_allocator,
+      DeviceMemory<uint8>* scratch_memory) override;
+
   SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
 };
 
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index c8c0201..38d6abc 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -505,12 +505,6 @@
   return desc;
 }
 
-// -- CtcLossDescriptor
-//
-CtcLossDescriptor::CtcLossDescriptor() {}
-
-CtcLossDescriptor::~CtcLossDescriptor() {}
-
 // -- PoolingDescriptor
 
 PoolingDescriptor::PoolingDescriptor(int ndims)
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 5f7f2aa..f508014 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -193,8 +193,8 @@
 // Describes a CTC loss operation.
 class CtcLossDescriptor {
  public:
-  CtcLossDescriptor();
-  ~CtcLossDescriptor();
+  CtcLossDescriptor() {}
+  ~CtcLossDescriptor() {}
 };
 
 // Specifies the sequence in a RNN model.
@@ -2390,6 +2390,24 @@
     return false;
   }
 
+  template <typename ElementType>
+  port::Status PrepareForCtcLoss(
+      Stream* stream,
+      const CtcLossDescriptor &ctc_loss_desc,
+      const RnnStateTensorDescriptor &probs_desc,
+      DeviceMemory<ElementType> probs_data,
+      const RnnStateTensorDescriptor &grads_desc,
+      absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data,
+      ScratchAllocator *workspace_allocator,
+      DeviceMemory<uint8>* scratch_memory) {
+    return DoPrepareForCtcLoss(
+        stream, ToDataType<ElementType>::value, ctc_loss_desc, probs_desc,
+        grads_desc, labels_data, labels_lengths_data, input_lengths_data,
+        workspace_allocator, scratch_memory);
+  }
+
   // Enqueue a CTC Loss operation onto the stream.
   //
   // Arguments:
@@ -2413,34 +2431,34 @@
   //    afterwards.
   virtual port::Status DoCtcLoss(Stream* stream,
       dnn::DataType element_type,
-      const dnn::RnnStateTensorDescriptor &probs_desc,
+      const RnnStateTensorDescriptor &probs_desc,
       const DeviceMemoryBase probs_data,
-      const absl::Span<const int32> &labels_data,
-      const absl::Span<const int32> &labels_lengths_data,
-      const absl::Span<const int32> &input_lengths_data,
+      absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data,
       DeviceMemoryBase costs_data,
-      const dnn::RnnStateTensorDescriptor &grads_desc,
+      const RnnStateTensorDescriptor &grads_desc,
       DeviceMemoryBase grads_data,
-      const dnn::CtcLossDescriptor &ctc_loss_desc,
-      ScratchAllocator *workspace_allocator) = 0;
+      const CtcLossDescriptor &ctc_loss_desc,
+      DeviceMemory<uint8> scratch_memory) = 0;
 
   template<typename ElementType>
   bool DoCtcLoss(Stream* stream,
                  const dnn::RnnStateTensorDescriptor &probs_desc,
                  const DeviceMemory<ElementType> &probs_data,
-                 const absl::Span<const int32> &labels_data,
-                 const absl::Span<const int32> &labels_lengths_data,
-                 const absl::Span<const int32> &input_lengths_data,
+                 absl::Span<const int> labels_data,
+                 absl::Span<const int> labels_lengths_data,
+                 absl::Span<const int> input_lengths_data,
                  DeviceMemory<ElementType> *costs_data,
                  const dnn::RnnStateTensorDescriptor &grads_desc,
                  DeviceMemory<ElementType> *grads_data,
                  const dnn::CtcLossDescriptor &ctc_loss_desc,
-                 ScratchAllocator *workspace_allocator) {
+                 DeviceMemory<uint8>* scratch_memory) {
     return IsStatusOk(
         DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
                   probs_data, labels_data, labels_lengths_data,
                   input_lengths_data, *costs_data, grads_desc, *grads_data,
-                  ctc_loss_desc, workspace_allocator),
+                  ctc_loss_desc, *scratch_memory),
         false);
   }
 
@@ -2699,6 +2717,20 @@
     return port::Status::OK();
   }
 
+  virtual port::Status DoPrepareForCtcLoss(
+      Stream* stream, DataType element_type,
+      const CtcLossDescriptor &ctc_loss_desc,
+      const RnnStateTensorDescriptor &probs_desc,
+      const RnnStateTensorDescriptor &grads_desc,
+      absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data,
+      ScratchAllocator* scratch_allocator,
+      DeviceMemory<uint8>* scratch_memory) {
+    *scratch_memory = {};
+    return port::Status::OK();
+  }
+
   SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
 };
 
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index ed119fb..a079a79 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -5232,9 +5232,9 @@
 
 Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
                             const DeviceMemory<float> &probs_data,
-                            const absl::Span<const int32> &labels_data,
-                            const absl::Span<const int32> &labels_lengths_data,
-                            const absl::Span<const int32> &input_lengths_data,
+                            absl::Span<const int> labels_data,
+                            absl::Span<const int> labels_lengths_data,
+                            absl::Span<const int> input_lengths_data,
                             DeviceMemory<float> *costs_data,
                             const dnn::RnnStateTensorDescriptor &grads_desc,
                             DeviceMemory<float> *grads_data,
@@ -5242,10 +5242,19 @@
                             ScratchAllocator *workspace_allocator) {
   if (ok()) {
     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
-      auto status = dnn->DoCtcLoss(
-          this, probs_desc, probs_data, labels_data, labels_lengths_data,
-          input_lengths_data, costs_data, grads_desc, grads_data, ctc_loss_desc,
-          workspace_allocator);
+      DeviceMemory<uint8> scratch_memory;
+      auto status =
+          dnn->PrepareForCtcLoss(
+              this, ctc_loss_desc, probs_desc, probs_data, grads_desc,
+              labels_data, labels_lengths_data, input_lengths_data,
+              workspace_allocator, &scratch_memory)
+          .ok();
+      if (status) {
+        status = dnn->DoCtcLoss(
+            this, probs_desc, probs_data, labels_data, labels_lengths_data,
+            input_lengths_data, costs_data, grads_desc, grads_data,
+            ctc_loss_desc, &scratch_memory);
+      }
       if (!status) {
         SetError();
       }
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index fe12908..208103b 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1917,9 +1917,9 @@
   Stream &ThenCtcLoss(
       const dnn::RnnStateTensorDescriptor &probs_desc,
       const DeviceMemory<float> &probs_data,
-      const absl::Span<const int32> &labels_data,
-      const absl::Span<const int32> &labels_lengths_data,
-      const absl::Span<const int32> &input_lengths_data,
+      absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data,
       DeviceMemory<float> *costs_data,
       const dnn::RnnStateTensorDescriptor &grads_desc,
       DeviceMemory<float> *grads_data,