Address minor review comments for PR 47251
- Remove use_gpu=True because it is already the default.
- Use int64 for dense_index inside the kernel to avoid integer overflow when Tindex is int32 (see the sketch after this list).
- Rewrite the reverse for-loop in the conventional decrementing style.
- Reformat the inline argument comments to the /*arg=*/ style so internal tooling picks them up.
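
A minimal C++ sketch of the overflow concern, outside the patch itself; the shape values are hypothetical and only illustrate why the flattened index needs 64 bits:

    #include <cstdint>

    // Flattening a multi-dimensional index multiplies dimension sizes together.
    // For a hypothetical dense shape of {100000, 100000}, the flattened index
    // can reach about 1e10, which overflows int32 (max ~2.1e9) but fits in int64.
    int64_t FlattenIndex(const int64_t* index, const int64_t* shape, int rank) {
      int64_t dense_index = 0;  // 64-bit accumulator, mirroring the kernel change
      for (int i = 0; i < rank; ++i) {
        dense_index = dense_index * shape[i] + index[i];
      }
      return dense_index;
    }
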
diff --git a/tensorflow/core/kernels/reshape_util_gpu.cu.cc b/tensorflow/core/kernels/reshape_util_gpu.cu.cc
index 2fc837d..80bb212 100644
--- a/tensorflow/core/kernels/reshape_util_gpu.cu.cc
+++ b/tensorflow/core/kernels/reshape_util_gpu.cu.cc
@@ -36,13 +36,13 @@
GPU_1D_KERNEL_LOOP(sparse_index, nnz) {
const Tindex* input_index = &input_indices[sparse_index * input_rank];
Tindex* output_index = &output_indices[sparse_index * output_rank];
- Tindex dense_index = 0;
+ int64 dense_index = 0; // int64 to avoid overflow if Tindex is int32
// Flatten input index from slowest- to fastest-changing dimension.
for (int i = 0; i < input_rank; ++i) {
dense_index = dense_index * input_shape[i] + input_index[i];
}
// Compute output index from fastest- to slowest-changing dimension.
- for (int i = output_rank; i-- > 0;) {
+ for (int i = output_rank - 1; i >= 0; --i) {
Tindex output_size = output_shape[i];
output_index[i] = dense_index % output_size;
dense_index /= output_size;
@@ -95,12 +95,12 @@
auto config = GetGpuLaunchConfig(nnz, device);
return GpuLaunchKernel(ReshapeSparseTensorKernel<int64>, config.block_count,
config.thread_per_block, 0, device.stream(), nnz,
- /* input_rank = */ input_rank,
- /* output_rank = */ output_rank,
- /* input_shape = */ input_shape_gpu.data(),
- /* output_shape = */ output_shape_gpu.data(),
- /* input_indices = */ input_indices.data(),
- /* output_indices = */ output_indices.data());
+ /*input_rank=*/input_rank,
+ /*output_rank=*/output_rank,
+ /*input_shape=*/input_shape_gpu.data(),
+ /*output_shape=*/output_shape_gpu.data(),
+ /*input_indices=*/input_indices.data(),
+ /*output_indices=*/output_indices.data());
}
} // namespace functor
diff --git a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py
index e0fed14..ab98c9a 100644
--- a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py
@@ -94,7 +94,7 @@
self.assertAllEqual((2, 3 * 4), sp_output.shape)
def testSameShape(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(input_val, [5, 6])
@@ -105,7 +105,7 @@
@test_util.run_deprecated_v1
def testFeedSameShape(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(sp_input, [5, 6])
@@ -117,7 +117,7 @@
@test_util.run_deprecated_v1
def testWorksWellWithTfShape(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
shape = array_ops.shape(sp_input) # tf.shape generates int32 output
@@ -130,7 +130,7 @@
@test_util.run_deprecated_v1
def testFeedSameShapeWithInferredDim(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(sp_input, [-1, 6])
@@ -142,7 +142,7 @@
@test_util.run_deprecated_v1
def testFeedNewShapeSameRank(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(sp_input, [3, 10])
@@ -156,7 +156,7 @@
@test_util.run_deprecated_v1
def testFeedNewShapeSameRankWithInferredDim(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(sp_input, [3, -1])
@@ -169,7 +169,7 @@
self.assertAllEqual(output_val.dense_shape, [3, 10])
def testUpRank(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(input_val, [2, 3, 5])
@@ -182,7 +182,7 @@
@test_util.run_deprecated_v1
def testFeedUpRank(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(sp_input, [2, 3, 5])
@@ -196,7 +196,7 @@
@test_util.run_deprecated_v1
def testFeedUpRankWithInferredDim(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(sp_input, [2, -1, 5])
@@ -210,7 +210,7 @@
@test_util.run_deprecated_v1
def testFeedDownRank(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_2x3x4()
sp_output = sparse_ops.sparse_reshape(sp_input, [6, 4])
@@ -224,7 +224,7 @@
@test_util.run_deprecated_v1
def testFeedDownRankWithInferredDim(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_2x3x4()
sp_output = sparse_ops.sparse_reshape(sp_input, [6, -1])
@@ -238,7 +238,7 @@
@test_util.run_deprecated_v1
def testFeedMultipleInferredDims(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(sp_input, [4, -1, -1])
@@ -254,7 +254,7 @@
@test_util.run_deprecated_v1
def testFeedMismatchedSizes(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(sp_input, [4, 7])
@@ -264,7 +264,7 @@
@test_util.run_deprecated_v1
def testFeedMismatchedSizesWithInferredDim(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
sp_output = sparse_ops.sparse_reshape(sp_input, [4, -1])
@@ -273,7 +273,7 @@
@test_util.run_deprecated_v1
def testFeedPartialShapes(self):
- with self.session(use_gpu=True):
+ with self.session():
# Incorporate new rank into shape information if known
sp_input = self._SparseTensorPlaceholder()
sp_output = sparse_ops.sparse_reshape(sp_input, [2, 3, 5])
@@ -299,7 +299,7 @@
@test_util.run_deprecated_v1
def testFeedDenseReshapeSemantics(self):
- with self.session(use_gpu=True) as sess:
+ with self.session() as sess:
# Compute a random rank-5 initial shape and new shape, randomly sparsify
# it, and check that the output of SparseReshape has the same semantics
# as a dense reshape.