Make SparseFillEmptyRows on GPU correct for empty indices and empty output shape
PiperOrigin-RevId: 371520893
Change-Id: I1de0830bc76e29abd2f0a44410372dee3f96e5ae
diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc
index 42abe89..b9baaa0 100644
--- a/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc
@@ -180,6 +180,22 @@
se::Stream* stream = context->op_device_context()->stream();
if (!stream) return errors::Internal("No GPU stream available.");
+ if (dense_rows == 0) {
+ Tindex* output_indices;
+ T* output_values;
+ Tindex* reverse_index_map;
+ TF_RETURN_IF_ERROR(AllocateOutputs(context, N, rank, /*num_empty_rows=*/0,
+ &output_indices, &output_values,
+ &reverse_index_map));
+ if (context->output_required(kEmptyRowIndicatorOutput)) {
+ Tensor* unused = nullptr;
+ TF_RETURN_IF_ERROR(context->allocate_output(kEmptyRowIndicatorOutput,
+ TensorShape({0}), &unused));
+ }
+ done();
+ return Status::OK();
+ }
+
// The algorithm as currently implemented is summarized as follows:
// 1) Compute elements_per_row (using GpuAtomicAdd).
// 2) Compute input_row_ends (the end index of each row) by computing the
@@ -245,9 +261,12 @@
return errors::Internal("Failed to initialize first_invalid_index");
}
- TF_RETURN_IF_ERROR(wrap_kernel_call(
- CountElementsPerRowKernel<Tindex>, device, N, dense_rows, rank, indices,
- elements_per_row, rows_are_not_ordered_gpu, first_invalid_index_gpu));
+ if (N > 0) {
+ TF_RETURN_IF_ERROR(wrap_kernel_call(
+ CountElementsPerRowKernel<Tindex>, /*device=*/device, /*size=*/N,
+ dense_rows, rank, indices, elements_per_row, rows_are_not_ordered_gpu,
+ first_invalid_index_gpu));
+ }
Tensor input_row_ends_t;
TF_RETURN_IF_ERROR(context->allocate_temp(
@@ -273,8 +292,8 @@
}
TF_RETURN_IF_ERROR(wrap_kernel_call(ComputeEmptyRowIndicatorKernel<Tindex>,
- device, dense_rows, elements_per_row,
- empty_row_indicator));
+ /*device=*/device, /*size=*/dense_rows,
+ elements_per_row, empty_row_indicator));
// For each row, the number of empty rows up to and including that row.
Tensor num_empty_rows_through_t;
@@ -367,19 +386,24 @@
input_index_map = input_index_map_t.vec<Tindex>().data();
}
+ if (N > 0) {
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ wrap_kernel_call(ScatterInputElementsKernel<T, Tindex>,
+ /*device=*/device, /*size=*/N, dense_rows, rank,
+ input_index_map, indices, values,
+ num_empty_rows_through, output_indices,
+ output_values, reverse_index_map),
+ done);
+ }
+
OP_REQUIRES_OK_ASYNC(
context,
- wrap_kernel_call(ScatterInputElementsKernel<T, Tindex>, device, N,
- dense_rows, rank, input_index_map, indices, values,
- num_empty_rows_through, output_indices,
- output_values, reverse_index_map),
- done);
- OP_REQUIRES_OK_ASYNC(
- context,
- wrap_kernel_call(ScatterNewElementsKernel<T, Tindex>, device,
- dense_rows, rank, default_value,
- num_empty_rows_through, input_row_ends,
- empty_row_indicator, output_indices, output_values),
+ wrap_kernel_call(ScatterNewElementsKernel<T, Tindex>,
+ /*device=*/device, /*size=*/dense_rows, rank,
+ default_value, num_empty_rows_through,
+ input_row_ends, empty_row_indicator, output_indices,
+ output_values),
done);
done();
@@ -428,8 +452,9 @@
TF_RETURN_IF_ERROR(
context->allocate_temp(index_type, TensorShape({N}), &row_indices_t));
auto row_indices = row_indices_t.flat<Tindex>();
- TF_RETURN_IF_ERROR(wrap_kernel_call(CopyRowIndicesKernel<Tindex>, device, N,
- rank, indices, row_indices));
+ TF_RETURN_IF_ERROR(wrap_kernel_call(CopyRowIndicesKernel<Tindex>,
+ /*device=*/device, /*size=*/N, rank,
+ indices, row_indices));
// Allocate input_index_map.
TF_RETURN_IF_ERROR(context->allocate_temp(index_type, TensorShape({N}),
input_index_map_t));
@@ -494,9 +519,9 @@
auto visited = visited_t.vec<bool>();
visited.device(device) = visited.constant(false);
- TF_RETURN_IF_ERROR(
- wrap_kernel_call(GatherOriginalGradValuesKernel<T, Tindex>, device, N,
- reverse_index_map, grad_values, d_values, visited));
+ TF_RETURN_IF_ERROR(wrap_kernel_call(
+ GatherOriginalGradValuesKernel<T, Tindex>, /*device=*/device,
+ /*size=*/N, reverse_index_map, grad_values, d_values, visited));
// Now we mask out the visited values and sum the remaining ones (which
// correspond to the empty rows in the forward input) to compute
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index a114955..76e19c3 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -602,6 +602,40 @@
self.assertAllEqual(output.dense_shape, [2, 5])
self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
+ def testEmptyIndicesTensor(self):
+ with test_util.use_gpu():
+ sp_input = sparse_tensor.SparseTensor(
+ indices=np.ones([0, 2]),
+ values=np.ones([0]),
+ dense_shape=np.array([2, 5]))
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+
+ output, empty_row_indicator_out = self.evaluate(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(output.indices, [[0, 0], [1, 0]])
+ self.assertAllEqual(output.values, [-1, -1])
+ self.assertAllEqual(output.dense_shape, [2, 5])
+ self.assertAllEqual(empty_row_indicator_out, np.ones(2).astype(np.bool))
+
+ def testEmptyOutput(self):
+ with test_util.use_gpu():
+ sp_input = sparse_tensor.SparseTensor(
+ indices=np.ones([0, 2]),
+ values=np.ones([0]),
+ dense_shape=np.array([0, 3]))
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+
+ output, empty_row_indicator_out = self.evaluate(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(output.indices, np.ones([0, 2]))
+ self.assertAllEqual(output.values, np.ones([0]))
+ self.assertAllEqual(output.dense_shape, [0, 3])
+ self.assertAllEqual(empty_row_indicator_out, [])
+
def testInvalidIndices(self):
with test_util.use_gpu():
sp_input = sparse_tensor.SparseTensor(