Make SparseFillEmptyRows on GPU correct for empty indices and empty output shape

PiperOrigin-RevId: 371520893
Change-Id: I1de0830bc76e29abd2f0a44410372dee3f96e5ae
diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc
index 42abe89..b9baaa0 100644
--- a/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc
@@ -180,6 +180,22 @@
     se::Stream* stream = context->op_device_context()->stream();
     if (!stream) return errors::Internal("No GPU stream available.");
 
+    if (dense_rows == 0) {
+      Tindex* output_indices;
+      T* output_values;
+      Tindex* reverse_index_map;
+      TF_RETURN_IF_ERROR(AllocateOutputs(context, N, rank, /*num_empty_rows=*/0,
+                                         &output_indices, &output_values,
+                                         &reverse_index_map));
+      if (context->output_required(kEmptyRowIndicatorOutput)) {
+        Tensor* unused = nullptr;
+        TF_RETURN_IF_ERROR(context->allocate_output(kEmptyRowIndicatorOutput,
+                                                    TensorShape({0}), &unused));
+      }
+      done();
+      return Status::OK();
+    }
+
     // The algorithm as currently implemented is summarized as follows:
     // 1) Compute elements_per_row (using GpuAtomicAdd).
     // 2) Compute input_row_ends (the end index of each row) by computing the
@@ -245,9 +261,12 @@
       return errors::Internal("Failed to initialize first_invalid_index");
     }
 
-    TF_RETURN_IF_ERROR(wrap_kernel_call(
-        CountElementsPerRowKernel<Tindex>, device, N, dense_rows, rank, indices,
-        elements_per_row, rows_are_not_ordered_gpu, first_invalid_index_gpu));
+    if (N > 0) {
+      TF_RETURN_IF_ERROR(wrap_kernel_call(
+          CountElementsPerRowKernel<Tindex>, /*device=*/device, /*size=*/N,
+          dense_rows, rank, indices, elements_per_row, rows_are_not_ordered_gpu,
+          first_invalid_index_gpu));
+    }
 
     Tensor input_row_ends_t;
     TF_RETURN_IF_ERROR(context->allocate_temp(
@@ -273,8 +292,8 @@
     }
 
     TF_RETURN_IF_ERROR(wrap_kernel_call(ComputeEmptyRowIndicatorKernel<Tindex>,
-                                        device, dense_rows, elements_per_row,
-                                        empty_row_indicator));
+                                        /*device=*/device, /*size=*/dense_rows,
+                                        elements_per_row, empty_row_indicator));
 
     // For each row, the number of empty rows up to and including that row.
     Tensor num_empty_rows_through_t;
@@ -367,19 +386,24 @@
         input_index_map = input_index_map_t.vec<Tindex>().data();
       }
 
+      if (N > 0) {
+        OP_REQUIRES_OK_ASYNC(
+            context,
+            wrap_kernel_call(ScatterInputElementsKernel<T, Tindex>,
+                             /*device=*/device, /*size=*/N, dense_rows, rank,
+                             input_index_map, indices, values,
+                             num_empty_rows_through, output_indices,
+                             output_values, reverse_index_map),
+            done);
+      }
+
       OP_REQUIRES_OK_ASYNC(
           context,
-          wrap_kernel_call(ScatterInputElementsKernel<T, Tindex>, device, N,
-                           dense_rows, rank, input_index_map, indices, values,
-                           num_empty_rows_through, output_indices,
-                           output_values, reverse_index_map),
-          done);
-      OP_REQUIRES_OK_ASYNC(
-          context,
-          wrap_kernel_call(ScatterNewElementsKernel<T, Tindex>, device,
-                           dense_rows, rank, default_value,
-                           num_empty_rows_through, input_row_ends,
-                           empty_row_indicator, output_indices, output_values),
+          wrap_kernel_call(ScatterNewElementsKernel<T, Tindex>,
+                           /*device=*/device, /*size=*/dense_rows, rank,
+                           default_value, num_empty_rows_through,
+                           input_row_ends, empty_row_indicator, output_indices,
+                           output_values),
           done);
 
       done();
@@ -428,8 +452,9 @@
     TF_RETURN_IF_ERROR(
         context->allocate_temp(index_type, TensorShape({N}), &row_indices_t));
     auto row_indices = row_indices_t.flat<Tindex>();
-    TF_RETURN_IF_ERROR(wrap_kernel_call(CopyRowIndicesKernel<Tindex>, device, N,
-                                        rank, indices, row_indices));
+    TF_RETURN_IF_ERROR(wrap_kernel_call(CopyRowIndicesKernel<Tindex>,
+                                        /*device=*/device, /*size=*/N, rank,
+                                        indices, row_indices));
     // Allocate input_index_map.
     TF_RETURN_IF_ERROR(context->allocate_temp(index_type, TensorShape({N}),
                                               input_index_map_t));
@@ -494,9 +519,9 @@
     auto visited = visited_t.vec<bool>();
     visited.device(device) = visited.constant(false);
 
-    TF_RETURN_IF_ERROR(
-        wrap_kernel_call(GatherOriginalGradValuesKernel<T, Tindex>, device, N,
-                         reverse_index_map, grad_values, d_values, visited));
+    TF_RETURN_IF_ERROR(wrap_kernel_call(
+        GatherOriginalGradValuesKernel<T, Tindex>, /*device=*/device,
+        /*size=*/N, reverse_index_map, grad_values, d_values, visited));
 
     // Now we mask out the visited values and sum the remaining ones (which
     // correspond to the empty rows in the forward input) to compute
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index a114955..76e19c3 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -602,6 +602,40 @@
       self.assertAllEqual(output.dense_shape, [2, 5])
       self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
 
+  def testEmptyIndicesTensor(self):
+    with test_util.use_gpu():
+      sp_input = sparse_tensor.SparseTensor(
+          indices=np.ones([0, 2]),
+          values=np.ones([0]),
+          dense_shape=np.array([2, 5]))
+      sp_output, empty_row_indicator = (
+          sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+
+      output, empty_row_indicator_out = self.evaluate(
+          [sp_output, empty_row_indicator])
+
+      self.assertAllEqual(output.indices, [[0, 0], [1, 0]])
+      self.assertAllEqual(output.values, [-1, -1])
+      self.assertAllEqual(output.dense_shape, [2, 5])
+      self.assertAllEqual(empty_row_indicator_out, np.ones(2).astype(np.bool))
+
+  def testEmptyOutput(self):
+    with test_util.use_gpu():
+      sp_input = sparse_tensor.SparseTensor(
+          indices=np.ones([0, 2]),
+          values=np.ones([0]),
+          dense_shape=np.array([0, 3]))
+      sp_output, empty_row_indicator = (
+          sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+
+      output, empty_row_indicator_out = self.evaluate(
+          [sp_output, empty_row_indicator])
+
+      self.assertAllEqual(output.indices, np.ones([0, 2]))
+      self.assertAllEqual(output.values, np.ones([0]))
+      self.assertAllEqual(output.dense_shape, [0, 3])
+      self.assertAllEqual(empty_row_indicator_out, [])
+
   def testInvalidIndices(self):
     with test_util.use_gpu():
       sp_input = sparse_tensor.SparseTensor(