Performance improvements for RaggedTensorToTensor

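Rewrites RaggedTensorToTensorOp::SetOutput to copy contiguous runs of values
with a single copy_array/memcpy call instead of scattering values one element
at a time, reuses preallocated output-index vectors across ragged dimensions,
and broadcasts default_value only to a single value element (using std::fill
when the default is a scalar) rather than broadcasting it over the entire
output tensor.
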
PiperOrigin-RevId: 270283647
diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
index 9199fcf..a49c7ae 100644
--- a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
@@ -44,7 +44,6 @@
 namespace {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 using ::std::vector;
-using ::tensorflow::errors::Internal;
 
 const int kShapeInputIndex = 0;
 const int kValueInputIndex = 1;
@@ -188,23 +187,22 @@
    * If first_dimension_output = 11 instead, then:
    * result = [0 100 200 300 400 500 600 700 800 900]
    */
-  vector<INDEX_TYPE> CalculateFirstParentOutputIndex(
-      INDEX_TYPE first_dimension, INDEX_TYPE output_index_multiplier,
-      INDEX_TYPE first_dimension_output) {
+  void CalculateFirstParentOutputIndex(INDEX_TYPE first_dimension,
+                                       INDEX_TYPE output_index_multiplier,
+                                       INDEX_TYPE first_dimension_output,
+                                       vector<INDEX_TYPE>* result) {
     const INDEX_TYPE min_dimension =
         std::min(first_dimension, first_dimension_output);
-    vector<INDEX_TYPE> result;
-    result.reserve(first_dimension);
+    result->reserve(first_dimension);
     int current_output_index = 0;
     for (INDEX_TYPE i = 0; i < min_dimension;
          ++i, current_output_index += output_index_multiplier) {
-      result.push_back(current_output_index);
+      result->push_back(current_output_index);
     }
     for (INDEX_TYPE i = min_dimension; i < first_dimension; ++i) {
-      result.push_back(-1);
+      result->push_back(-1);
     }
-    DCHECK_EQ(result.size(), first_dimension);
-    return result;
+    DCHECK_EQ(result->size(), first_dimension);
   }
 
   void CalculateOutputIndexRowSplit(
@@ -350,10 +348,10 @@
     OP_REQUIRES_OK(context,
                    CalculateOutputSize(first_dimension, context, &output_size));
     vector<INDEX_TYPE> multiplier;
-    multiplier.resize(output_size.size());
+    multiplier.resize(ragged_rank_ + 1);
 
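+    // multiplier[i] is the number of value elements (slices over the output
+    // dimensions after the ragged dimensions) spanned by one index step in
+    // output dimension i, so the output indices computed below are offsets
+    // in units of value elements rather than scalars.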
     multiplier[multiplier.size() - 1] = 1;
-    for (int i = output_size.size() - 2; i >= 0; --i) {
+    for (int i = multiplier.size() - 2; i >= 0; --i) {
       multiplier[i] = multiplier[i + 1] * output_size[i + 1];
     }
     // Full size of the tensor.
@@ -366,21 +364,25 @@
                    context->allocate_output(0, output_shape, &output_tensor));
     const INDEX_TYPE full_size = multiplier[0] * output_size[0];
     if (full_size > 0) {
-      vector<INDEX_TYPE> output_index = CalculateFirstParentOutputIndex(
-          first_dimension, multiplier[0], output_size[0]);
+      vector<INDEX_TYPE> output_index, new_output_index;
+      int nvals = context->input(kValueInputIndex).shape().dim_size(0);
+      output_index.reserve(nvals);
+      new_output_index.reserve(nvals);
 
+      CalculateFirstParentOutputIndex(first_dimension, multiplier[0],
+                                      output_size[0], &output_index);
       for (int i = 1; i <= ragged_rank_; ++i) {
-        vector<INDEX_TYPE> new_output_index;
         OP_REQUIRES_OK(context, CalculateOutputIndex(
                                     context, i - 1, output_index, multiplier[i],
                                     output_size[i], &new_output_index));
-        output_index = new_output_index;
+        output_index.swap(new_output_index);
+        new_output_index.clear();
       }
 
-      SetOutput(context, output_index, output_tensor);
+      SetOutput(context, ragged_rank_, output_index, output_tensor);
     }
   }
-  virtual void SetOutput(OpKernelContext* context,
+  virtual void SetOutput(OpKernelContext* context, int ragged_rank,
                          const vector<INDEX_TYPE>& output_index,
                          Tensor* output_tensor) = 0;
 
@@ -397,20 +399,17 @@
 }
 
 template <typename VALUE_TYPE, typename INDEX_TYPE>
-void copy_array(VALUE_TYPE* dst, const VALUE_TYPE* src, INDEX_TYPE size,
-                size_t bytes) {
-  memcpy(dst, src, bytes);
+void copy_array(VALUE_TYPE* dst, const VALUE_TYPE* src, INDEX_TYPE size) {
+  memcpy(dst, src, size * sizeof(VALUE_TYPE));
 }
 
 template <>
-void copy_array<string, int64>(string* dst, const string* src, int64 size,
-                               size_t bytes) {
+void copy_array<string, int64>(string* dst, const string* src, int64 size) {
   slow_copy_array(dst, src, size);
 }
 
 template <>
-void copy_array<string, int32>(string* dst, const string* src, int32 size,
-                               size_t bytes) {
+void copy_array<string, int32>(string* dst, const string* src, int32 size) {
   slow_copy_array(dst, src, size);
 }
 
@@ -419,13 +418,13 @@
 // is not TriviallyCopyable
 template <>
 void copy_array<Eigen::half, int64>(Eigen::half* dst, const Eigen::half* src,
-                                    int64 size, size_t bytes) {
+                                    int64 size) {
   slow_copy_array(dst, src, size);
 }
 
 template <>
 void copy_array<Eigen::half, int32>(Eigen::half* dst, const Eigen::half* src,
-                                    int32 size, size_t bytes) {
+                                    int32 size) {
   slow_copy_array(dst, src, size);
 }
 
@@ -435,80 +434,111 @@
   explicit RaggedTensorToTensorOp(OpKernelConstruction* context)
       : RaggedTensorToTensorBaseOp<INDEX_TYPE>(context) {}
 
-  void SetOutput(OpKernelContext* context,
+  void SetOutput(OpKernelContext* context, int ragged_rank,
                  const vector<INDEX_TYPE>& output_index,
                  Tensor* output_tensor) override {
-    typename tensorflow::TTypes<VALUE_TYPE>::Flat output_flat =
-        output_tensor->flat<VALUE_TYPE>();
-    const auto& value_tensor = context->input(kValueInputIndex);
+    // Note: it's ok to use OP_REQUIRES_OK (rather than TF_RETURN_IF_ERROR)
+    // in this function, but only because it's the last thing we do before
+    // returning from Compute().
+
+    if (output_tensor->NumElements() == 0) return;
+
+    const auto& values_tensor = context->input(kValueInputIndex);
+    const VALUE_TYPE* values_base = values_tensor.flat<VALUE_TYPE>().data();
     const auto& default_value_tensor = context->input(kDefaultValueInputIndex);
-    if (value_tensor.shape().dims() == 1) {
-      // Initialize tensor to default_value.
-      VALUE_TYPE* base_output = output_flat.data();
-      VALUE_TYPE default_value = default_value_tensor.scalar<VALUE_TYPE>()();
+    VALUE_TYPE* output_base = output_tensor->flat<VALUE_TYPE>().data();
 
-      std::fill(base_output, base_output + output_flat.size(), default_value);
-      auto values = context->input(kValueInputIndex).flat<VALUE_TYPE>();
-      int values_size = values.size();
-      OP_REQUIRES(context, values_size == output_index.size(),
-                  Internal("Values and indices must be equal"));
-      for (int i = 0; i < values_size; ++i) {
-        if (output_index[i] >= 0) {
-          output_flat(output_index[i]) = values(i);
-        }
-      }
-    } else {
-      const auto& output_shape = output_tensor->shape();
-      const auto& default_value_shape = default_value_tensor.shape();
+    TensorShape element_shape = output_tensor->shape();
+    element_shape.RemoveDimRange(0, ragged_rank + 1);
+    int value_element_size = element_shape.num_elements();
+    size_t output_index_size = output_index.size();
 
-      // Initialize tensor to default_value.
-
-      BCast bcast(BCast::FromShape(default_value_shape),
-                  BCast::FromShape(output_shape),
+    // Broadcast the default value to element_shape (i.e. to
+    // value_element_size elements).  (We can skip this if
+    // default_value_tensor.NumElements() == 1, since we use std::fill when
+    // that's true.)
+    const VALUE_TYPE* default_value =
+        default_value_tensor.flat<VALUE_TYPE>().data();
+    Tensor bcast_default;  // Temporary tensor for result of broadcast
+    if (default_value_tensor.NumElements() != value_element_size &&
+        default_value_tensor.NumElements() != 1) {
+      const auto& src_shape = default_value_tensor.shape();
+      BCast bcast(BCast::FromShape(src_shape), BCast::FromShape(element_shape),
                   /*fewer_dims_optimization=*/true);
-      OP_REQUIRES(
-          context, bcast.IsValid(),
-          errors::InvalidArgument(
-              "Incompatible shapes: ", default_value_shape.DebugString(),
-              " vs. ", default_value_shape.DebugString()));
-      OP_REQUIRES(
-          context, BCast::ToShape(bcast.output_shape()) == output_shape,
-          errors::InvalidArgument("Unable to broadcast default_value of shape ",
-                                  default_value_shape, " to tensor of shape ",
-                                  output_shape));
+      // Note: bcast should always be valid, since we rejected any incompatible
+      // shapes when we called ValidateDefaultValueShape().
+      OP_REQUIRES(context, bcast.IsValid(),
+                  errors::InvalidArgument("Error broadcasting default_value"));
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(default_value_tensor.dtype(),
+                                            element_shape, &bcast_default));
       const CPUDevice& device = context->eigen_device<CPUDevice>();
       functor::BroadcastTo<CPUDevice, VALUE_TYPE>()(
-          device, context, *output_tensor, output_shape, default_value_tensor,
-          default_value_shape, bcast);
+          device, context, bcast_default, element_shape, default_value_tensor,
+          src_shape, bcast);
+      default_value = bcast_default.flat<VALUE_TYPE>().data();
+    }
 
-      VALUE_TYPE* base_output = output_flat.data();
-      auto values = context->input(kValueInputIndex).flat<VALUE_TYPE>();
-      size_t values_size = values.size();
-      size_t output_index_size = output_index.size();
-      //  A value "element" is a group of values that are arranged together.
-      // For example, if the value shape is [3,4,5], then 20 values are in a
-      // value element.
-      int value_element_size = values_size / output_index_size;
-      int value_element_bytesize = value_element_size * sizeof(VALUE_TYPE);
-      const VALUE_TYPE* values_base = values.data();
+    // Loop through the output_index vector, finding contiguous regions that
+    // should be copied.  Once we find the end of a contiguous region, copy it
+    // and add any necessary padding (with default_value).
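+    // For example, with value_element_size = 1, nine output rows, and
+    // output_index = [0, 1, 2, 5, 6, -1, 7]: values[0:3] are copied to rows
+    // 0-2, rows 3-4 are filled with default_value, values[3:5] are copied to
+    // rows 5-6, values[5] is skipped (its output index is -1), values[6] is
+    // copied to row 7, and row 8 is padded with default_value.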
+    INDEX_TYPE src_start = 0;  // Start of contiguous region (in values)
+    INDEX_TYPE dst_start = 0;  // Start of destination region (in output)
+    INDEX_TYPE dst_end = 0;    // End (exclusive) of destination region
+    for (int src_i = 0; src_i <= output_index_size; ++src_i) {
+      // dst_i is the destination where the value at src_i should be copied.
+      INDEX_TYPE dst_i = src_i < output_index_size ? output_index[src_i] : -1;
 
-      OP_REQUIRES(context,
-                  value_tensor.shape().dim_size(0) == output_index_size,
-                  Internal("Values and indices must be equal"));
+      // If we're still in a contiguous region, then update dst_end and go to
+      // the next src_i.
+      if (dst_i == dst_end) {
+        ++dst_end;
+        continue;
+      }
 
-      OP_REQUIRES(context,
-                  values_size == output_index_size * value_element_size,
-                  Internal("Values and indices must be equal"));
-      INDEX_TYPE value_index = 0;
-      for (int i = 0; i < output_index_size;
-           ++i, value_index += value_element_size) {
-        if (output_index[i] >= 0) {
-          VALUE_TYPE* dst = base_output + output_index[i];
-          const VALUE_TYPE* src = values_base + value_index;
-          copy_array<VALUE_TYPE, INDEX_TYPE>(dst, src, value_element_size,
-                                             value_element_bytesize);
+      // We found the end of a contiguous region.  This can be because we
+      // found a gap (dst_i > dst_end), a source value that shouldn't be
+      // copied because it's out-of-bounds (dst_i == -1), or the end of the
+      // values (where dst_i is also set to -1).
+      if (dst_start < dst_end) {
+        // Copy the contiguous region.
+        const VALUE_TYPE* src = values_base + src_start * value_element_size;
+        VALUE_TYPE* dst = output_base + dst_start * value_element_size;
+        INDEX_TYPE nvals = (dst_end - dst_start) * value_element_size;
+        copy_array<VALUE_TYPE, INDEX_TYPE>(dst, src, nvals);
+      }
+
+      // Add any necessary padding (with default_value).
+      if (src_i >= output_index_size) {
+        // We reached the end of values: pad to the end of output.
+        size_t output_size = output_tensor->NumElements();
+        dst_i = output_size / value_element_size;
+      }
+      if (dst_i > dst_end) {
+        if (default_value_tensor.NumElements() == 1) {
+          std::fill(output_base + dst_end * value_element_size,
+                    output_base + dst_i * value_element_size, *default_value);
+          dst_end = dst_i;
+        } else {
+          while (dst_i > dst_end) {
+            VALUE_TYPE* dst = output_base + dst_end * value_element_size;
+            copy_array<VALUE_TYPE, INDEX_TYPE>(dst, default_value,
+                                               value_element_size);
+            ++dst_end;
+          }
         }
       }
+
+      // Update indices.
+      if (dst_i < 0) {
+        // src_i should be skipped -- leave it out of the contiguous region.
+        src_start = src_i + 1;
+        dst_start = dst_end;
+      } else {
+        // src_i should be copied -- include it in the contiguous region.
+        src_start = src_i;
+        dst_start = dst_end;
+        dst_end = dst_start + 1;
+      }
     }
   }
 };